├── LICENSE.md
├── README.md
├── architecture.png
├── atari
│   ├── LICENSE
│   ├── conda_env.yml
│   ├── create_dataset.py
│   ├── fixed_replay_buffer.py
│   ├── mingpt
│   │   ├── __init__.py
│   │   ├── model_atari.py
│   │   ├── trainer_atari.py
│   │   └── utils.py
│   ├── readme-atari.md
│   ├── run.sh
│   └── run_dt_atari.py
└── gym
    ├── conda_env.yml
    ├── data
    │   └── download_d4rl_datasets.py
    ├── decision_transformer
    │   ├── envs
    │   │   ├── assets
    │   │   │   └── reacher_2d.xml
    │   │   └── reacher_2d.py
    │   ├── evaluation
    │   │   └── evaluate_episodes.py
    │   ├── models
    │   │   ├── decision_transformer.py
    │   │   ├── mlp_bc.py
    │   │   ├── model.py
    │   │   └── trajectory_gpt2.py
    │   └── training
    │       ├── act_trainer.py
    │       ├── seq_trainer.py
    │       └── trainer.py
    ├── experiment.py
    └── readme-gym.md
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Decision Transformer (Decision Transformer: Reinforcement Learning via Sequence Modeling) Authors (https://arxiv.org/abs/2106.01345)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Decision Transformer
3 |
4 | Lili Chen\*, Kevin Lu\*, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas†, and Igor Mordatch†
5 |
6 | \*equal contribution, †equal advising
7 |
8 | A link to our paper can be found on [arXiv](https://arxiv.org/abs/2106.01345).
9 |
10 | ## Overview
11 |
12 | Official codebase for [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://sites.google.com/berkeley.edu/decision-transformer).
13 | Contains scripts to reproduce experiments.
14 |
15 | 
16 |
17 | ## Instructions
18 |
19 | We provide code in two sub-directories: `atari` containing code for Atari experiments and `gym` containing code for OpenAI Gym experiments.
20 | See corresponding READMEs in each folder for instructions; scripts should be run from the respective directories.
21 | It may be necessary to add the respective directories to your PYTHONPATH.
22 |
23 | ## Citation
24 |
25 | Please cite our paper as:
26 |
27 | ```
28 | @article{chen2021decisiontransformer,
29 | title={Decision Transformer: Reinforcement Learning via Sequence Modeling},
30 | author={Lili Chen and Kevin Lu and Aravind Rajeswaran and Kimin Lee and Aditya Grover and Michael Laskin and Pieter Abbeel and Aravind Srinivas and Igor Mordatch},
31 | journal={arXiv preprint arXiv:2106.01345},
32 | year={2021}
33 | }
34 | ```
35 |
36 | Note: this is not an official Google or Facebook product.
37 |
38 | ## License
39 |
40 | MIT
41 |
--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kzl/decision-transformer/e2d82e68f330c00f763507b3b01d774740bee53f/architecture.png
--------------------------------------------------------------------------------
/atari/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/atari/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: decision-transformer-atari
2 | channels:
3 | - pytorch
4 | dependencies:
5 | - python=3.7.9
6 | - pytorch=1.2
7 | - cudatoolkit=10.
8 | - numpy
9 | - psutil
10 | - opencv
11 | - pip
12 | - pip:
13 |   - atari-py
14 |   - pyprind
15 |   - tensorflow-gpu>=1.13
16 |   - absl-py
17 |   - atari-py
18 |   - gin-config
19 |   - gym
20 |   - tqdm
21 |   - blosc
22 |   - git+https://github.com/google/dopamine.git
23 |
--------------------------------------------------------------------------------
/atari/create_dataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import logging
3 | # make deterministic
4 | from mingpt.utils import set_seed
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 | import math
10 | from torch.utils.data import Dataset
11 | from mingpt.model_atari import GPT, GPTConfig
12 | from mingpt.trainer_atari import Trainer, TrainerConfig
13 | from mingpt.utils import sample
14 | from collections import deque
15 | import random
16 | import torch
17 | import pickle
18 | import blosc
19 | import argparse
20 | from fixed_replay_buffer import FixedReplayBuffer
21 |
22 | def create_dataset(num_buffers, num_steps, game, data_dir_prefix, trajectories_per_buffer):
23 | # -- load data from memory (make more efficient)
24 | obss = []
25 | actions = []
26 | returns = [0]
27 | done_idxs = []
28 | stepwise_returns = []
29 |
30 | transitions_per_buffer = np.zeros(50, dtype=int)
31 | num_trajectories = 0
32 | while len(obss) < num_steps:
33 | buffer_num = np.random.choice(np.arange(50 - num_buffers, 50), 1)[0]
34 | i = transitions_per_buffer[buffer_num]
35 | print('loading from buffer %d which has %d already loaded' % (buffer_num, i))
36 | frb = FixedReplayBuffer(
37 | data_dir=data_dir_prefix + game + '/1/replay_logs',
38 | replay_suffix=buffer_num,
39 | observation_shape=(84, 84),
40 | stack_size=4,
41 | update_horizon=1,
42 | gamma=0.99,
43 | observation_dtype=np.uint8,
44 | batch_size=32,
45 | replay_capacity=100000)
46 | if frb._loaded_buffers:
47 | done = False
48 | curr_num_transitions = len(obss)
49 | trajectories_to_load = trajectories_per_buffer
50 | while not done:
51 | states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i])
52 | states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84)
53 | obss += [states]
54 | actions += [ac[0]]
55 | stepwise_returns += [ret[0]]
56 | if terminal[0]:
57 | done_idxs += [len(obss)]
58 | returns += [0]
59 | if trajectories_to_load == 0:
60 | done = True
61 | else:
62 | trajectories_to_load -= 1
63 | returns[-1] += ret[0]
64 | i += 1
65 | if i >= 100000:
66 | obss = obss[:curr_num_transitions]
67 | actions = actions[:curr_num_transitions]
68 | stepwise_returns = stepwise_returns[:curr_num_transitions]
69 | returns[-1] = 0
70 | i = transitions_per_buffer[buffer_num]
71 | done = True
72 | num_trajectories += (trajectories_per_buffer - trajectories_to_load)
73 | transitions_per_buffer[buffer_num] = i
74 | print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories))
75 |
76 | actions = np.array(actions)
77 | returns = np.array(returns)
78 | stepwise_returns = np.array(stepwise_returns)
79 | done_idxs = np.array(done_idxs)
80 |
81 | # -- create reward-to-go dataset
82 | start_index = 0
83 | rtg = np.zeros_like(stepwise_returns)
84 | for i in done_idxs:
85 | i = int(i)
86 | curr_traj_returns = stepwise_returns[start_index:i]
87 | for j in range(i-1, start_index-1, -1): # start from i-1
88 | rtg_j = curr_traj_returns[j-start_index:i-start_index]
89 | rtg[j] = sum(rtg_j)
90 | start_index = i
91 | print('max rtg is %d' % max(rtg))
92 |
93 | # -- create timestep dataset
94 | start_index = 0
95 | timesteps = np.zeros(len(actions)+1, dtype=int)
96 | for i in done_idxs:
97 | i = int(i)
98 | timesteps[start_index:i+1] = np.arange(i+1 - start_index)
99 | start_index = i+1
100 | print('max timestep is %d' % max(timesteps))
101 |
102 | return obss, actions, returns, done_idxs, rtg, timesteps
103 |
--------------------------------------------------------------------------------
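
The reward-to-go loop in create_dataset.py computes, for each timestep, the sum of rewards from that step to the end of its trajectory. A minimal self-contained sketch of the same computation on toy data (illustration only, not repo code):

```
import numpy as np

# Two toy trajectories: steps 0-2 and steps 3-4.
stepwise_returns = np.array([1., 0., 2., 0., 1.])
done_idxs = [3, 5]

rtg = np.zeros_like(stepwise_returns)
start_index = 0
for done_idx in done_idxs:
    for j in range(done_idx - 1, start_index - 1, -1):
        # reward-to-go at step j = rewards from j to the end of the trajectory
        rtg[j] = stepwise_returns[j:done_idx].sum()
    start_index = done_idx

print(rtg)  # [3. 2. 2. 1. 1.]
```
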
/atari/fixed_replay_buffer.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/google-research/batch_rl/blob/master/batch_rl/fixed_replay/replay_memory/fixed_replay_buffer.py
2 |
3 | import collections
4 | from concurrent import futures
5 | from dopamine.replay_memory import circular_replay_buffer
6 | import numpy as np
7 | import tensorflow.compat.v1 as tf
8 | import gin
9 |
10 | gfile = tf.gfile
11 |
12 | STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX
13 |
14 | class FixedReplayBuffer(object):
15 | """Object composed of a list of OutofGraphReplayBuffers."""
16 |
17 | def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
18 | """Initialize the FixedReplayBuffer class.
19 | Args:
20 | data_dir: str, log directory from which to load the replay buffer.
21 | replay_suffix: int, If not None, then only load the replay buffer
22 | corresponding to the specific suffix in data directory.
23 | *args: Arbitrary extra arguments.
24 | **kwargs: Arbitrary keyword arguments.
25 | """
26 | self._args = args
27 | self._kwargs = kwargs
28 | self._data_dir = data_dir
29 | self._loaded_buffers = False
30 | self.add_count = np.array(0)
31 | self._replay_suffix = replay_suffix
32 | if not self._loaded_buffers:
33 | if replay_suffix is not None:
34 | assert replay_suffix >= 0, 'Please pass a non-negative replay suffix'
35 | self.load_single_buffer(replay_suffix)
36 | else:
37 | self._load_replay_buffers(num_buffers=50)
38 |
39 | def load_single_buffer(self, suffix):
40 | """Load a single replay buffer."""
41 | replay_buffer = self._load_buffer(suffix)
42 | if replay_buffer is not None:
43 | self._replay_buffers = [replay_buffer]
44 | self.add_count = replay_buffer.add_count
45 | self._num_replay_buffers = 1
46 | self._loaded_buffers = True
47 |
48 | def _load_buffer(self, suffix):
49 | """Loads a OutOfGraphReplayBuffer replay buffer."""
50 | try:
51 | # pytype: disable=attribute-error
52 | replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(
53 | *self._args, **self._kwargs)
54 | replay_buffer.load(self._data_dir, suffix)
55 | tf.logging.info('Loaded replay buffer ckpt {} from {}'.format(
56 | suffix, self._data_dir))
57 | # pytype: enable=attribute-error
58 | return replay_buffer
59 | except tf.errors.NotFoundError:
60 | return None
61 |
62 | def _load_replay_buffers(self, num_buffers=None):
63 | """Loads multiple checkpoints into a list of replay buffers."""
64 | if not self._loaded_buffers: # pytype: disable=attribute-error
65 | ckpts = gfile.ListDirectory(self._data_dir) # pytype: disable=attribute-error
66 | # Assumes that the checkpoints are saved in a format CKPT_NAME.{SUFFIX}.gz
67 | ckpt_counters = collections.Counter(
68 | [name.split('.')[-2] for name in ckpts])
69 | # Should contain the files for add_count, action, observation, reward,
70 | # terminal and invalid_range
71 | ckpt_suffixes = [x for x in ckpt_counters if ckpt_counters[x] in [6, 7]]
72 | if num_buffers is not None:
73 | ckpt_suffixes = np.random.choice(
74 | ckpt_suffixes, num_buffers, replace=False)
75 | self._replay_buffers = []
76 | # Load the replay buffers in parallel
77 | with futures.ThreadPoolExecutor(
78 | max_workers=num_buffers) as thread_pool_executor:
79 | replay_futures = [thread_pool_executor.submit(
80 | self._load_buffer, suffix) for suffix in ckpt_suffixes]
81 | for f in replay_futures:
82 | replay_buffer = f.result()
83 | if replay_buffer is not None:
84 | self._replay_buffers.append(replay_buffer)
85 | self.add_count = max(replay_buffer.add_count, self.add_count)
86 | self._num_replay_buffers = len(self._replay_buffers)
87 | if self._num_replay_buffers:
88 | self._loaded_buffers = True
89 |
90 | def get_transition_elements(self):
91 | return self._replay_buffers[0].get_transition_elements()
92 |
93 | def sample_transition_batch(self, batch_size=None, indices=None):
94 | buffer_index = np.random.randint(self._num_replay_buffers)
95 | return self._replay_buffers[buffer_index].sample_transition_batch(
96 | batch_size=batch_size, indices=indices)
97 |
98 | def load(self, *args, **kwargs): # pylint: disable=unused-argument
99 | pass
100 |
101 | def reload_buffer(self, num_buffers=None):
102 | self._loaded_buffers = False
103 | self._load_replay_buffers(num_buffers)
104 |
105 | def save(self, *args, **kwargs): # pylint: disable=unused-argument
106 | pass
107 |
108 | def add(self, *args, **kwargs): # pylint: disable=unused-argument
109 | pass
--------------------------------------------------------------------------------
/atari/mingpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kzl/decision-transformer/e2d82e68f330c00f763507b3b01d774740bee53f/atari/mingpt/__init__.py
--------------------------------------------------------------------------------
/atari/mingpt/model_atari.py:
--------------------------------------------------------------------------------
1 | """
2 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5 |
6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7 |
8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9 | """
10 |
11 | """
12 | GPT model:
13 | - the initial stem consists of a combination of token encoding and a positional encoding
14 | - the meat of it is a uniform sequence of Transformer blocks
15 | - each Transformer block is a sequential combination of a self-attention block and a 1-hidden-layer MLP block
16 | - all blocks feed into a central residual pathway similar to resnets
17 | - the final decoder is a linear projection into a vanilla Softmax classifier
18 | """
19 |
20 | import math
21 | import logging
22 |
23 | import torch
24 | import torch.nn as nn
25 | from torch.nn import functional as F
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | import numpy as np
30 |
31 | class GELU(nn.Module):
32 | def forward(self, input):
33 | return F.gelu(input)
34 |
35 | class GPTConfig:
36 | """ base GPT config, params common to all GPT versions """
37 | embd_pdrop = 0.1
38 | resid_pdrop = 0.1
39 | attn_pdrop = 0.1
40 |
41 | def __init__(self, vocab_size, block_size, **kwargs):
42 | self.vocab_size = vocab_size
43 | self.block_size = block_size
44 | for k,v in kwargs.items():
45 | setattr(self, k, v)
46 |
47 | class GPT1Config(GPTConfig):
48 | """ GPT-1 like network roughly 125M params """
49 | n_layer = 12
50 | n_head = 12
51 | n_embd = 768
52 |
53 | class CausalSelfAttention(nn.Module):
54 | """
55 | A vanilla multi-head masked self-attention layer with a projection at the end.
56 | It is possible to use torch.nn.MultiheadAttention here but I am including an
57 | explicit implementation here to show that there is nothing too scary here.
58 | """
59 |
60 | def __init__(self, config):
61 | super().__init__()
62 | assert config.n_embd % config.n_head == 0
63 | # key, query, value projections for all heads
64 | self.key = nn.Linear(config.n_embd, config.n_embd)
65 | self.query = nn.Linear(config.n_embd, config.n_embd)
66 | self.value = nn.Linear(config.n_embd, config.n_embd)
67 | # regularization
68 | self.attn_drop = nn.Dropout(config.attn_pdrop)
69 | self.resid_drop = nn.Dropout(config.resid_pdrop)
70 | # output projection
71 | self.proj = nn.Linear(config.n_embd, config.n_embd)
72 | # causal mask to ensure that attention is only applied to the left in the input sequence
73 | # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
74 | # .view(1, 1, config.block_size, config.block_size))
75 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size + 1, config.block_size + 1))
76 | .view(1, 1, config.block_size + 1, config.block_size + 1))
77 | self.n_head = config.n_head
78 |
79 | def forward(self, x, layer_past=None):
80 | B, T, C = x.size()
81 |
82 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
83 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
84 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
85 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
86 |
87 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
88 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
89 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
90 | att = F.softmax(att, dim=-1)
91 | att = self.attn_drop(att)
92 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
93 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
94 |
95 | # output projection
96 | y = self.resid_drop(self.proj(y))
97 | return y
98 |
99 | class Block(nn.Module):
100 | """ an unassuming Transformer block """
101 |
102 | def __init__(self, config):
103 | super().__init__()
104 | self.ln1 = nn.LayerNorm(config.n_embd)
105 | self.ln2 = nn.LayerNorm(config.n_embd)
106 | self.attn = CausalSelfAttention(config)
107 | self.mlp = nn.Sequential(
108 | nn.Linear(config.n_embd, 4 * config.n_embd),
109 | GELU(),
110 | nn.Linear(4 * config.n_embd, config.n_embd),
111 | nn.Dropout(config.resid_pdrop),
112 | )
113 |
114 | def forward(self, x):
115 | x = x + self.attn(self.ln1(x))
116 | x = x + self.mlp(self.ln2(x))
117 | return x
118 |
119 | class GPT(nn.Module):
120 | """ the full GPT language model, with a context size of block_size """
121 |
122 | def __init__(self, config):
123 | super().__init__()
124 |
125 | self.config = config
126 |
127 | self.model_type = config.model_type
128 |
129 | # input embedding stem
130 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
131 | # self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
132 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd))
133 | self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep+1, config.n_embd))
134 | self.drop = nn.Dropout(config.embd_pdrop)
135 |
136 | # transformer
137 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
138 | # decoder head
139 | self.ln_f = nn.LayerNorm(config.n_embd)
140 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
141 |
142 | self.block_size = config.block_size
143 | self.apply(self._init_weights)
144 |
145 |
146 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
147 |
148 |
149 | self.state_encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(),
150 | nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(),
151 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(),
152 | nn.Flatten(), nn.Linear(3136, config.n_embd), nn.Tanh())
153 |
154 | self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh())
155 |
156 | self.action_embeddings = nn.Sequential(nn.Embedding(config.vocab_size, config.n_embd), nn.Tanh())
157 | nn.init.normal_(self.action_embeddings[0].weight, mean=0.0, std=0.02)
158 |
159 | def get_block_size(self):
160 | return self.block_size
161 |
162 | def _init_weights(self, module):
163 | if isinstance(module, (nn.Linear, nn.Embedding)):
164 | module.weight.data.normal_(mean=0.0, std=0.02)
165 | if isinstance(module, nn.Linear) and module.bias is not None:
166 | module.bias.data.zero_()
167 | elif isinstance(module, nn.LayerNorm):
168 | module.bias.data.zero_()
169 | module.weight.data.fill_(1.0)
170 |
171 | def configure_optimizers(self, train_config):
172 | """
173 | This long function is unfortunately doing something very simple and is being very defensive:
174 | We are separating out all parameters of the model into two buckets: those that will experience
175 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
176 | We are then returning the PyTorch optimizer object.
177 | """
178 |
179 | # separate out all parameters to those that will and won't experience regularizing weight decay
180 | decay = set()
181 | no_decay = set()
182 | # whitelist_weight_modules = (torch.nn.Linear, )
183 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
184 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
185 | for mn, m in self.named_modules():
186 | for pn, p in m.named_parameters():
187 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
188 |
189 | if pn.endswith('bias'):
190 | # all biases will not be decayed
191 | no_decay.add(fpn)
192 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
193 | # weights of whitelist modules will be weight decayed
194 | decay.add(fpn)
195 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
196 | # weights of blacklist modules will NOT be weight decayed
197 | no_decay.add(fpn)
198 |
199 | # special case the position embedding parameter in the root GPT module as not decayed
200 | no_decay.add('pos_emb')
201 | no_decay.add('global_pos_emb')
202 |
203 | # validate that we considered every parameter
204 | param_dict = {pn: p for pn, p in self.named_parameters()}
205 | inter_params = decay & no_decay
206 | union_params = decay | no_decay
207 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
208 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
209 | % (str(param_dict.keys() - union_params), )
210 |
211 | # create the pytorch optimizer object
212 | optim_groups = [
213 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
214 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
215 | ]
216 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
217 | return optimizer
218 |
219 | # state, action, and return
220 | def forward(self, states, actions, targets=None, rtgs=None, timesteps=None):
221 | # states: (batch, block_size, 4*84*84)
222 | # actions: (batch, block_size, 1)
223 | # targets: (batch, block_size, 1)
224 | # rtgs: (batch, block_size, 1)
225 | # timesteps: (batch, 1, 1)
226 |
227 | state_embeddings = self.state_encoder(states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()) # (batch * block_size, n_embd)
228 | state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd)
229 |
230 | if actions is not None and self.model_type == 'reward_conditioned':
231 | rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
232 | action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd)
233 |
234 | token_embeddings = torch.zeros((states.shape[0], states.shape[1]*3 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device)
235 | token_embeddings[:,::3,:] = rtg_embeddings
236 | token_embeddings[:,1::3,:] = state_embeddings
237 | token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:]
238 | elif actions is None and self.model_type == 'reward_conditioned': # only happens at very first timestep of evaluation
239 | rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
240 |
241 | token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2, self.config.n_embd), dtype=torch.float32, device=state_embeddings.device)
242 | token_embeddings[:,::2,:] = rtg_embeddings # really just [:,0,:]
243 | token_embeddings[:,1::2,:] = state_embeddings # really just [:,1,:]
244 | elif actions is not None and self.model_type == 'naive':
245 | action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd)
246 |
247 | token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device)
248 | token_embeddings[:,::2,:] = state_embeddings
249 | token_embeddings[:,1::2,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:]
250 | elif actions is None and self.model_type == 'naive': # only happens at very first timestep of evaluation
251 | token_embeddings = state_embeddings
252 | else:
253 | raise NotImplementedError()
254 |
255 | batch_size = states.shape[0]
256 | all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd
257 |
258 | position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :]
259 |
260 | x = self.drop(token_embeddings + position_embeddings)
261 | x = self.blocks(x)
262 | x = self.ln_f(x)
263 | logits = self.head(x)
264 |
265 | if actions is not None and self.model_type == 'reward_conditioned':
266 | logits = logits[:, 1::3, :] # only keep predictions from state_embeddings
267 | elif actions is None and self.model_type == 'reward_conditioned':
268 | logits = logits[:, 1:, :]
269 | elif actions is not None and self.model_type == 'naive':
270 | logits = logits[:, ::2, :] # only keep predictions from state_embeddings
271 | elif actions is None and self.model_type == 'naive':
272 | logits = logits # for completeness
273 | else:
274 | raise NotImplementedError()
275 |
276 | # if we are given some desired targets also calculate the loss
277 | loss = None
278 | if targets is not None:
279 | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
280 |
281 | return logits, loss
282 |
--------------------------------------------------------------------------------
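
For reference, a minimal standalone sketch (toy tensors, not repo code) of how GPT.forward in model_atari.py interleaves return-to-go, state, and action embeddings for the 'reward_conditioned' model type, and where the action predictions are read off:

```
import torch

batch, context, n_embd = 2, 4, 8
rtg_emb = torch.randn(batch, context, n_embd)
state_emb = torch.randn(batch, context, n_embd)
action_emb = torch.randn(batch, context, n_embd)

# Interleave as (R_1, s_1, a_1, R_2, s_2, a_2, ...), mirroring the training case.
token_embeddings = torch.zeros(batch, context * 3, n_embd)
token_embeddings[:, 0::3] = rtg_emb
token_embeddings[:, 1::3] = state_emb
token_embeddings[:, 2::3] = action_emb

print(token_embeddings.shape)  # torch.Size([2, 12, 8])
# After the transformer, the reward-conditioned branch keeps only the logits at
# the state positions, i.e. logits[:, 1::3, :], as the action predictions.
```
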
/atari/mingpt/trainer_atari.py:
--------------------------------------------------------------------------------
1 | """
2 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5 |
6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7 |
8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9 | """
10 |
11 | """
12 | Simple training loop; Boilerplate that could apply to any arbitrary neural network,
13 | so nothing in this file really has anything to do with GPT specifically.
14 | """
15 |
16 | import math
17 | import logging
18 |
19 | from tqdm import tqdm
20 | import numpy as np
21 |
22 | import torch
23 | import torch.optim as optim
24 | from torch.optim.lr_scheduler import LambdaLR
25 | from torch.utils.data.dataloader import DataLoader
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | from mingpt.utils import sample
30 | import atari_py
31 | from collections import deque
32 | import random
33 | import cv2
34 | import torch
35 | from PIL import Image
36 |
37 | class TrainerConfig:
38 | # optimization parameters
39 | max_epochs = 10
40 | batch_size = 64
41 | learning_rate = 3e-4
42 | betas = (0.9, 0.95)
43 | grad_norm_clip = 1.0
44 | weight_decay = 0.1 # only applied on matmul weights
45 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original
46 | lr_decay = False
47 | warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
48 | final_tokens = 260e9 # (at what point we reach 10% of original LR)
49 | # checkpoint settings
50 | ckpt_path = None
51 | num_workers = 0 # for DataLoader
52 |
53 | def __init__(self, **kwargs):
54 | for k,v in kwargs.items():
55 | setattr(self, k, v)
56 |
57 | class Trainer:
58 |
59 | def __init__(self, model, train_dataset, test_dataset, config):
60 | self.model = model
61 | self.train_dataset = train_dataset
62 | self.test_dataset = test_dataset
63 | self.config = config
64 |
65 | # take over whatever gpus are on the system
66 | self.device = 'cpu'
67 | if torch.cuda.is_available():
68 | self.device = torch.cuda.current_device()
69 | self.model = torch.nn.DataParallel(self.model).to(self.device)
70 |
71 | def save_checkpoint(self):
72 | # DataParallel wrappers keep raw model object in .module attribute
73 | raw_model = self.model.module if hasattr(self.model, "module") else self.model
74 | logger.info("saving %s", self.config.ckpt_path)
75 | # torch.save(raw_model.state_dict(), self.config.ckpt_path)
76 |
77 | def train(self):
78 | model, config = self.model, self.config
79 | raw_model = model.module if hasattr(self.model, "module") else model
80 | optimizer = raw_model.configure_optimizers(config)
81 |
82 | def run_epoch(split, epoch_num=0):
83 | is_train = split == 'train'
84 | model.train(is_train)
85 | data = self.train_dataset if is_train else self.test_dataset
86 | loader = DataLoader(data, shuffle=True, pin_memory=True,
87 | batch_size=config.batch_size,
88 | num_workers=config.num_workers)
89 |
90 | losses = []
91 | pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
92 | for it, (x, y, r, t) in pbar:
93 |
94 | # place data on the correct device
95 | x = x.to(self.device)
96 | y = y.to(self.device)
97 | r = r.to(self.device)
98 | t = t.to(self.device)
99 |
100 | # forward the model
101 | with torch.set_grad_enabled(is_train):
102 | # logits, loss = model(x, y, r)
103 | logits, loss = model(x, y, y, r, t)
104 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
105 | losses.append(loss.item())
106 |
107 | if is_train:
108 |
109 | # backprop and update the parameters
110 | model.zero_grad()
111 | loss.backward()
112 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
113 | optimizer.step()
114 |
115 | # decay the learning rate based on our progress
116 | if config.lr_decay:
117 | self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
118 | if self.tokens < config.warmup_tokens:
119 | # linear warmup
120 | lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
121 | else:
122 | # cosine learning rate decay
123 | progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
124 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
125 | lr = config.learning_rate * lr_mult
126 | for param_group in optimizer.param_groups:
127 | param_group['lr'] = lr
128 | else:
129 | lr = config.learning_rate
130 |
131 | # report progress
132 | pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")
133 |
134 | if not is_train:
135 | test_loss = float(np.mean(losses))
136 | logger.info("test loss: %f", test_loss)
137 | return test_loss
138 |
139 | # best_loss = float('inf')
140 |
141 | best_return = -float('inf')
142 |
143 | self.tokens = 0 # counter used for learning rate decay
144 |
145 | for epoch in range(config.max_epochs):
146 |
147 | run_epoch('train', epoch_num=epoch)
148 | # if self.test_dataset is not None:
149 | # test_loss = run_epoch('test')
150 |
151 | # # supports early stopping based on the test loss, or just save always if no test set is provided
152 | # good_model = self.test_dataset is None or test_loss < best_loss
153 | # if self.config.ckpt_path is not None and good_model:
154 | # best_loss = test_loss
155 | # self.save_checkpoint()
156 |
157 | # -- pass in target returns
158 | if self.config.model_type == 'naive':
159 | eval_return = self.get_returns(0)
160 | elif self.config.model_type == 'reward_conditioned':
161 | if self.config.game == 'Breakout':
162 | eval_return = self.get_returns(90)
163 | elif self.config.game == 'Seaquest':
164 | eval_return = self.get_returns(1150)
165 | elif self.config.game == 'Qbert':
166 | eval_return = self.get_returns(14000)
167 | elif self.config.game == 'Pong':
168 | eval_return = self.get_returns(20)
169 | else:
170 | raise NotImplementedError()
171 | else:
172 | raise NotImplementedError()
173 |
174 | def get_returns(self, ret):
175 | self.model.train(False)
176 | args=Args(self.config.game.lower(), self.config.seed)
177 | env = Env(args)
178 | env.eval()
179 |
180 | T_rewards, T_Qs = [], []
181 | done = True
182 | for i in range(10):
183 | state = env.reset()
184 | state = state.type(torch.float32).to(self.device).unsqueeze(0).unsqueeze(0)
185 | rtgs = [ret]
186 | # first state is from env, first rtg is target return, and first timestep is 0
187 | sampled_action = sample(self.model.module, state, 1, temperature=1.0, sample=True, actions=None,
188 | rtgs=torch.tensor(rtgs, dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1),
189 | timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(self.device))
190 |
191 | j = 0
192 | all_states = state
193 | actions = []
194 | while True:
195 | if done:
196 | state, reward_sum, done = env.reset(), 0, False
197 | action = sampled_action.cpu().numpy()[0,-1]
198 | actions += [sampled_action]
199 | state, reward, done = env.step(action)
200 | reward_sum += reward
201 | j += 1
202 |
203 | if done:
204 | T_rewards.append(reward_sum)
205 | break
206 |
207 | state = state.unsqueeze(0).unsqueeze(0).to(self.device)
208 |
209 | all_states = torch.cat([all_states, state], dim=0)
210 |
211 | rtgs += [rtgs[-1] - reward]
212 | # all_states has all previous states and rtgs has all previous rtgs (will be cut to block_size in utils.sample)
213 | # timestep is just current timestep
214 | sampled_action = sample(self.model.module, all_states.unsqueeze(0), 1, temperature=1.0, sample=True,
215 | actions=torch.tensor(actions, dtype=torch.long).to(self.device).unsqueeze(1).unsqueeze(0),
216 | rtgs=torch.tensor(rtgs, dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1),
217 | timesteps=(min(j, self.config.max_timestep) * torch.ones((1, 1, 1), dtype=torch.int64).to(self.device)))
218 | env.close()
219 | eval_return = sum(T_rewards)/10.
220 | print("target return: %d, eval return: %d" % (ret, eval_return))
221 | self.model.train(True)
222 | return eval_return
223 |
224 |
225 | class Env():
226 | def __init__(self, args):
227 | self.device = args.device
228 | self.ale = atari_py.ALEInterface()
229 | self.ale.setInt('random_seed', args.seed)
230 | self.ale.setInt('max_num_frames_per_episode', args.max_episode_length)
231 | self.ale.setFloat('repeat_action_probability', 0) # Disable sticky actions
232 | self.ale.setInt('frame_skip', 0)
233 | self.ale.setBool('color_averaging', False)
234 | self.ale.loadROM(atari_py.get_game_path(args.game)) # ROM loading must be done after setting options
235 | actions = self.ale.getMinimalActionSet()
236 | self.actions = dict([i, e] for i, e in zip(range(len(actions)), actions))
237 | self.lives = 0 # Life counter (used in DeepMind training)
238 | self.life_termination = False # Used to check if resetting only from loss of life
239 | self.window = args.history_length # Number of frames to concatenate
240 | self.state_buffer = deque([], maxlen=args.history_length)
241 | self.training = True # Consistent with model training mode
242 |
243 | def _get_state(self):
244 | state = cv2.resize(self.ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_LINEAR)
245 | return torch.tensor(state, dtype=torch.float32, device=self.device).div_(255)
246 |
247 | def _reset_buffer(self):
248 | for _ in range(self.window):
249 | self.state_buffer.append(torch.zeros(84, 84, device=self.device))
250 |
251 | def reset(self):
252 | if self.life_termination:
253 | self.life_termination = False # Reset flag
254 | self.ale.act(0) # Use a no-op after loss of life
255 | else:
256 | # Reset internals
257 | self._reset_buffer()
258 | self.ale.reset_game()
259 | # Perform up to 30 random no-ops before starting
260 | for _ in range(random.randrange(30)):
261 | self.ale.act(0) # Assumes raw action 0 is always no-op
262 | if self.ale.game_over():
263 | self.ale.reset_game()
264 | # Process and return "initial" state
265 | observation = self._get_state()
266 | self.state_buffer.append(observation)
267 | self.lives = self.ale.lives()
268 | return torch.stack(list(self.state_buffer), 0)
269 |
270 | def step(self, action):
271 | # Repeat action 4 times, max pool over last 2 frames
272 | frame_buffer = torch.zeros(2, 84, 84, device=self.device)
273 | reward, done = 0, False
274 | for t in range(4):
275 | reward += self.ale.act(self.actions.get(action))
276 | if t == 2:
277 | frame_buffer[0] = self._get_state()
278 | elif t == 3:
279 | frame_buffer[1] = self._get_state()
280 | done = self.ale.game_over()
281 | if done:
282 | break
283 | observation = frame_buffer.max(0)[0]
284 | self.state_buffer.append(observation)
285 | # Detect loss of life as terminal in training mode
286 | if self.training:
287 | lives = self.ale.lives()
288 | if lives < self.lives and lives > 0: # Lives > 0 for Q*bert
289 | self.life_termination = not done # Only set flag when not truly done
290 | done = True
291 | self.lives = lives
292 | # Return state, reward, done
293 | return torch.stack(list(self.state_buffer), 0), reward, done
294 |
295 | # Uses loss of life as terminal signal
296 | def train(self):
297 | self.training = True
298 |
299 | # Uses standard terminal signal
300 | def eval(self):
301 | self.training = False
302 |
303 | def action_space(self):
304 | return len(self.actions)
305 |
306 | def render(self):
307 | cv2.imshow('screen', self.ale.getScreenRGB()[:, :, ::-1])
308 | cv2.waitKey(1)
309 |
310 | def close(self):
311 | cv2.destroyAllWindows()
312 |
313 | class Args:
314 | def __init__(self, game, seed):
315 | self.device = torch.device('cuda')
316 | self.seed = seed
317 | self.max_episode_length = 108e3
318 | self.game = game
319 | self.history_length = 4
320 |
--------------------------------------------------------------------------------
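
The core of Trainer.get_returns in trainer_atari.py is a return-conditioned rollout: start from a target return-to-go and subtract each observed reward before sampling the next action. A minimal runnable sketch with a toy environment (the real loop drives Env and mingpt.utils.sample; the stand-ins below are assumptions for illustration):

```
import random

def toy_env_step():
    # Stand-in for Env.step: returns (reward, done).
    return random.choice([0, 1]), random.random() < 0.05

target_return = 90            # e.g. the Breakout target used in Trainer.train
rtgs = [target_return]
episode_return, done = 0, False
while not done:
    # The real code samples an action here with mingpt.utils.sample(...,
    # rtgs=rtgs, timesteps=...); the toy environment ignores the action.
    reward, done = toy_env_step()
    episode_return += reward
    rtgs.append(rtgs[-1] - reward)   # decrement the remaining desired return

print(episode_return, rtgs[-1])
```
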
/atari/mingpt/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5 |
6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
7 |
8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
9 | """
10 |
11 | import random
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | from torch.nn import functional as F
16 |
17 | def set_seed(seed):
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 |
23 | def top_k_logits(logits, k):
24 | v, ix = torch.topk(logits, k)
25 | out = logits.clone()
26 | out[out < v[:, [-1]]] = -float('Inf')
27 | return out
28 |
29 | @torch.no_grad()
30 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None, actions=None, rtgs=None, timesteps=None):
31 | """
32 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
33 | the sequence, feeding the predictions back into the model each time. Clearly the sampling
34 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window
35 | of block_size, unlike an RNN that has an infinite context window.
36 | """
37 | block_size = model.get_block_size()
38 | model.eval()
39 | for k in range(steps):
40 | # x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
41 | x_cond = x if x.size(1) <= block_size//3 else x[:, -block_size//3:] # crop context if needed
42 | if actions is not None:
43 | actions = actions if actions.size(1) <= block_size//3 else actions[:, -block_size//3:] # crop context if needed
44 | rtgs = rtgs if rtgs.size(1) <= block_size//3 else rtgs[:, -block_size//3:] # crop context if needed
45 | logits, _ = model(x_cond, actions=actions, targets=None, rtgs=rtgs, timesteps=timesteps)
46 | # pluck the logits at the final step and scale by temperature
47 | logits = logits[:, -1, :] / temperature
48 | # optionally crop probabilities to only the top k options
49 | if top_k is not None:
50 | logits = top_k_logits(logits, top_k)
51 | # apply softmax to convert to probabilities
52 | probs = F.softmax(logits, dim=-1)
53 | # sample from the distribution or take the most likely
54 | if sample:
55 | ix = torch.multinomial(probs, num_samples=1)
56 | else:
57 | _, ix = torch.topk(probs, k=1, dim=-1)
58 | # append to the sequence and continue
59 | # x = torch.cat((x, ix), dim=1)
60 | x = ix
61 |
62 | return x
63 |
--------------------------------------------------------------------------------
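
A small illustration (toy logits, not repo code) of what top_k_logits in utils.py does: logits outside the top k are set to -inf, so softmax assigns them zero probability:

```
import torch
from torch.nn import functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
v, _ = torch.topk(logits, k=2)
masked = logits.clone()
masked[masked < v[:, [-1]]] = -float('Inf')

print(F.softmax(masked, dim=-1))  # only the two largest logits keep any mass
```
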
/atari/readme-atari.md:
--------------------------------------------------------------------------------
1 |
2 | # Atari
3 |
4 | We build our Atari implementation on top of [minGPT](https://github.com/karpathy/minGPT) and benchmark our results on the [DQN-replay](https://github.com/google-research/batch_rl) dataset.
5 |
6 | ## Installation
7 |
8 | Dependencies can be installed with the following command:
9 |
10 | ```
11 | conda env create -f conda_env.yml
12 | ```
13 |
14 | ## Downloading datasets
15 |
16 | Create a directory for the dataset and download the dataset using [gsutil](https://cloud.google.com/storage/docs/gsutil_install#install). Replace `[DIRECTORY_NAME]` and `[GAME_NAME]` accordingly (e.g., `./dqn_replay` for `[DIRECTORY_NAME]` and `Breakout` for `[GAME_NAME]`):
17 | ```
18 | mkdir [DIRECTORY_NAME]
19 | gsutil -m cp -R gs://atari-replay-datasets/dqn/[GAME_NAME] [DIRECTORY_NAME]
20 | ```
21 |
22 | ## Example usage
23 |
24 | Scripts to reproduce our Decision Transformer results can be found in `run.sh`.
25 |
26 | ```
27 | python run_dt_atari.py --seed 123 --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 --data_dir_prefix [DIRECTORY_NAME]
28 | ```
29 |
--------------------------------------------------------------------------------
/atari/run.sh:
--------------------------------------------------------------------------------
1 | # Decision Transformer (DT)
2 | for seed in 123 231 312
3 | do
4 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128
5 | done
6 |
7 | for seed in 123 231 312
8 | do
9 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Qbert' --batch_size 128
10 | done
11 |
12 | for seed in 123 231 312
13 | do
14 | python run_dt_atari.py --seed $seed --context_length 50 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512
15 | done
16 |
17 | for seed in 123 231 312
18 | do
19 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Seaquest' --batch_size 128
20 | done
21 |
22 | # Behavior Cloning (BC)
23 | for seed in 123 231 312
24 | do
25 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128
26 | done
27 |
28 | for seed in 123 231 312
29 | do
30 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Qbert' --batch_size 128
31 | done
32 |
33 | for seed in 123 231 312
34 | do
35 | python run_dt_atari.py --seed $seed --context_length 50 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512
36 | done
37 |
38 | for seed in 123 231 312
39 | do
40 | python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Seaquest' --batch_size 128
41 | done
--------------------------------------------------------------------------------
/atari/run_dt_atari.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import logging
3 | # make deterministic
4 | from mingpt.utils import set_seed
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 | import math
10 | from torch.utils.data import Dataset
11 | from mingpt.model_atari import GPT, GPTConfig
12 | from mingpt.trainer_atari import Trainer, TrainerConfig
13 | from mingpt.utils import sample
14 | from collections import deque
15 | import random
16 | import torch
17 | import pickle
18 | import blosc
19 | import argparse
20 | from create_dataset import create_dataset
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--seed', type=int, default=123)
24 | parser.add_argument('--context_length', type=int, default=30)
25 | parser.add_argument('--epochs', type=int, default=5)
26 | parser.add_argument('--model_type', type=str, default='reward_conditioned')
27 | parser.add_argument('--num_steps', type=int, default=500000)
28 | parser.add_argument('--num_buffers', type=int, default=50)
29 | parser.add_argument('--game', type=str, default='Breakout')
30 | parser.add_argument('--batch_size', type=int, default=128)
31 | #
32 | parser.add_argument('--trajectories_per_buffer', type=int, default=10, help='Number of trajectories to sample from each of the buffers.')
33 | parser.add_argument('--data_dir_prefix', type=str, default='./dqn_replay/')
34 | args = parser.parse_args()
35 |
36 | set_seed(args.seed)
37 |
38 | class StateActionReturnDataset(Dataset):
39 |
40 | def __init__(self, data, block_size, actions, done_idxs, rtgs, timesteps):
41 | self.block_size = block_size
42 | self.vocab_size = max(actions) + 1
43 | self.data = data
44 | self.actions = actions
45 | self.done_idxs = done_idxs
46 | self.rtgs = rtgs
47 | self.timesteps = timesteps
48 |
49 | def __len__(self):
50 | return len(self.data) - self.block_size
51 |
52 | def __getitem__(self, idx):
53 | block_size = self.block_size // 3
54 | done_idx = idx + block_size
55 | for i in self.done_idxs:
56 | if i > idx: # first done_idx greater than idx
57 | done_idx = min(int(i), done_idx)
58 | break
59 | idx = done_idx - block_size
60 | states = torch.tensor(np.array(self.data[idx:done_idx]), dtype=torch.float32).reshape(block_size, -1) # (block_size, 4*84*84)
61 | states = states / 255.
62 | actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1)
63 | rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1)
64 | timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1)
65 |
66 | return states, actions, rtgs, timesteps
67 |
68 | obss, actions, returns, done_idxs, rtgs, timesteps = create_dataset(args.num_buffers, args.num_steps, args.game, args.data_dir_prefix, args.trajectories_per_buffer)
69 |
70 | # set up logging
71 | logging.basicConfig(
72 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
73 | datefmt="%m/%d/%Y %H:%M:%S",
74 | level=logging.INFO,
75 | )
76 |
77 | train_dataset = StateActionReturnDataset(obss, args.context_length*3, actions, done_idxs, rtgs, timesteps)
78 |
79 | mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
80 | n_layer=6, n_head=8, n_embd=128, model_type=args.model_type, max_timestep=max(timesteps))
81 | model = GPT(mconf)
82 |
83 | # initialize a trainer instance and kick off training
84 | epochs = args.epochs
85 | tconf = TrainerConfig(max_epochs=epochs, batch_size=args.batch_size, learning_rate=6e-4,
86 | lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*args.context_length*3,
87 | num_workers=4, seed=args.seed, model_type=args.model_type, game=args.game, max_timestep=max(timesteps))
88 | trainer = Trainer(model, train_dataset, None, tconf)
89 |
90 | trainer.train()
91 |
--------------------------------------------------------------------------------
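
StateActionReturnDataset.__getitem__ shifts every sampled window so that it never crosses an episode boundary. A minimal sketch of that index realignment, using hypothetical episode end indices:

```
block = 30                    # context_length, i.e. block_size // 3 in the code
done_idxs = [70, 160, 240]    # hypothetical episode end indices

def window(idx):
    done_idx = idx + block
    for i in done_idxs:
        if i > idx:           # first episode end after idx
            done_idx = min(i, done_idx)
            break
    idx = done_idx - block    # shift the start so the window ends at done_idx
    return idx, done_idx

print(window(10))  # (10, 40): fits inside the first episode
print(window(55))  # (40, 70): pulled back to end exactly at the episode boundary
```
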
/gym/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: decision-transformer-gym
2 | channels:
3 | - pytorch
4 | dependencies:
5 | - python=3.8.5
6 | - anaconda
7 | - cudatoolkit=10.
8 | - numpy
9 | - pip
10 | - pip:
11 |   - gym==0.18.3
12 |   - mujoco-py==2.0.2.13
13 |   - numpy==1.20.3
14 |   - torch==1.8.1
15 |   - transformers==4.5.1
16 |   - wandb==0.9.1
17 |
--------------------------------------------------------------------------------
/gym/data/download_d4rl_datasets.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 |
4 | import collections
5 | import pickle
6 |
7 | import d4rl
8 |
9 |
10 | datasets = []
11 |
12 | for env_name in ['halfcheetah', 'hopper', 'walker2d']:
13 | for dataset_type in ['medium', 'medium-replay', 'expert']:
14 | name = f'{env_name}-{dataset_type}-v2'
15 | env = gym.make(name)
16 | dataset = env.get_dataset()
17 |
18 | N = dataset['rewards'].shape[0]
19 | data_ = collections.defaultdict(list)
20 |
21 | use_timeouts = False
22 | if 'timeouts' in dataset:
23 | use_timeouts = True
24 |
25 | episode_step = 0
26 | paths = []
27 | for i in range(N):
28 | done_bool = bool(dataset['terminals'][i])
29 | if use_timeouts:
30 | final_timestep = dataset['timeouts'][i]
31 | else:
32 | final_timestep = (episode_step == 1000-1)
33 | for k in ['observations', 'next_observations', 'actions', 'rewards', 'terminals']:
34 | data_[k].append(dataset[k][i])
35 | if done_bool or final_timestep:
36 | episode_step = 0
37 | episode_data = {}
38 | for k in data_:
39 | episode_data[k] = np.array(data_[k])
40 | paths.append(episode_data)
41 | data_ = collections.defaultdict(list)
42 | episode_step += 1
43 |
44 | returns = np.array([np.sum(p['rewards']) for p in paths])
45 | num_samples = np.sum([p['rewards'].shape[0] for p in paths])
46 | print(f'Number of samples collected: {num_samples}')
47 | print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}')
48 |
49 | with open(f'{name}.pkl', 'wb') as f:
50 | pickle.dump(paths, f)
51 |
--------------------------------------------------------------------------------
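
download_d4rl_datasets.py writes one pickle per environment/dataset pair, containing a list of per-episode dictionaries. A minimal sketch of reading one back (assumes the hopper-medium download above has been run in the same directory):

```
import pickle

with open('hopper-medium-v2.pkl', 'rb') as f:
    paths = pickle.load(f)

# Each entry holds the per-step arrays collected above for one episode:
# observations, next_observations, actions, rewards, terminals.
print(len(paths), list(paths[0].keys()))
```
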
/gym/decision_transformer/envs/assets/reacher_2d.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/gym/decision_transformer/envs/reacher_2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from gym.envs.mujoco import mujoco_env
4 |
5 | import os
6 |
7 |
8 | class Reacher2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
9 |
10 | def __init__(self):
11 | self.fingertip_sid = 0
12 | self.target_bid = 0
13 | curr_dir = os.path.dirname(os.path.abspath(__file__))
14 | mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/reacher_2d.xml', 15)
15 | self.fingertip_sid = self.sim.model.site_name2id('fingertip')
16 | self.target_bid = self.sim.model.body_name2id('target')
17 | utils.EzPickle.__init__(self)
18 |
19 | def step(self, action):
20 | action = np.clip(action, -1.0, 1.0)
21 | self.do_simulation(action, self.frame_skip)
22 | tip = self.data.site_xpos[self.fingertip_sid][:2]
23 | tar = self.data.body_xpos[self.target_bid][:2]
24 | dist = np.sum(np.abs(tip - tar))
25 | reward_dist = 0. # - 0.1 * dist
26 | reward_ctrl = 0.0
27 | reward_bonus = 1.0 if dist < 0.1 else 0.0
28 | reward = reward_bonus + reward_ctrl + reward_dist
29 | done = False
30 | ob = self._get_obs()
31 | return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl, reward_bonus=reward_bonus)
32 |
33 | def _get_obs(self):
34 | theta = self.data.qpos.ravel()
35 | tip = self.data.site_xpos[self.fingertip_sid][:2]
36 | tar = self.data.body_xpos[self.target_bid][:2]
37 | return np.concatenate([
38 | # self.data.qpos.flat,
39 | np.sin(theta),
40 | np.cos(theta),
41 | self.dt * self.data.qvel.ravel(),
42 | tip,
43 | tar,
44 | tip-tar,
45 | ])
46 |
47 | def reset_model(self):
48 | # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
49 | # qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
50 | qpos = self.np_random.uniform(low=-2.0, high=2.0, size=self.model.nq)
51 | qvel = self.init_qvel * 0.0
52 | while True:
53 | self.goal = self.np_random.uniform(low=-1.5, high=1.5, size=2)
54 | if np.linalg.norm(self.goal) <= 1.0 and np.linalg.norm(self.goal) >= 0.5:
55 | break
56 | self.set_state(qpos, qvel)
57 | self.model.body_pos[self.target_bid][:2] = self.goal
58 | self.sim.forward()
59 | return self._get_obs()
60 |
61 | def viewer_setup(self):
62 | self.viewer.cam.distance = self.model.stat.extent * 5.0
63 |
--------------------------------------------------------------------------------
/gym/decision_transformer/evaluation/evaluate_episodes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def evaluate_episode(
6 | env,
7 | state_dim,
8 | act_dim,
9 | model,
10 | max_ep_len=1000,
11 | device='cuda',
12 | target_return=None,
13 | mode='normal',
14 | state_mean=0.,
15 | state_std=1.,
16 | ):
17 |
18 | model.eval()
19 | model.to(device=device)
20 |
21 | state_mean = torch.from_numpy(state_mean).to(device=device)
22 | state_std = torch.from_numpy(state_std).to(device=device)
23 |
24 | state = env.reset()
25 |
26 | # we keep all the histories on the device
27 | # note that the latest action and reward will be "padding"
28 | states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
29 | actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
30 | rewards = torch.zeros(0, device=device, dtype=torch.float32)
31 | target_return = torch.tensor(target_return, device=device, dtype=torch.float32)
32 | sim_states = []
33 |
34 | episode_return, episode_length = 0, 0
35 | for t in range(max_ep_len):
36 |
37 | # add padding
38 | actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
39 | rewards = torch.cat([rewards, torch.zeros(1, device=device)])
40 |
41 | action = model.get_action(
42 | (states.to(dtype=torch.float32) - state_mean) / state_std,
43 | actions.to(dtype=torch.float32),
44 | rewards.to(dtype=torch.float32),
45 | target_return=target_return,
46 | )
47 | actions[-1] = action
48 | action = action.detach().cpu().numpy()
49 |
50 | state, reward, done, _ = env.step(action)
51 |
52 | cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
53 | states = torch.cat([states, cur_state], dim=0)
54 | rewards[-1] = reward
55 |
56 | episode_return += reward
57 | episode_length += 1
58 |
59 | if done:
60 | break
61 |
62 | return episode_return, episode_length
63 |
64 |
65 | def evaluate_episode_rtg(
66 | env,
67 | state_dim,
68 | act_dim,
69 | model,
70 | max_ep_len=1000,
71 | scale=1000.,
72 | state_mean=0.,
73 | state_std=1.,
74 | device='cuda',
75 | target_return=None,
76 | mode='normal',
77 | ):
78 |
79 | model.eval()
80 | model.to(device=device)
81 |
82 | state_mean = torch.from_numpy(state_mean).to(device=device)
83 | state_std = torch.from_numpy(state_std).to(device=device)
84 |
85 | state = env.reset()
86 | if mode == 'noise':
87 | state = state + np.random.normal(0, 0.1, size=state.shape)
88 |
89 | # we keep all the histories on the device
90 | # note that the latest action and reward will be "padding"
91 | states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
92 | actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
93 | rewards = torch.zeros(0, device=device, dtype=torch.float32)
94 |
95 | ep_return = target_return
96 | target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1)
97 | timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
98 |
99 | sim_states = []
100 |
101 | episode_return, episode_length = 0, 0
102 | for t in range(max_ep_len):
103 |
104 | # add padding
105 | actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
106 | rewards = torch.cat([rewards, torch.zeros(1, device=device)])
107 |
108 | action = model.get_action(
109 | (states.to(dtype=torch.float32) - state_mean) / state_std,
110 | actions.to(dtype=torch.float32),
111 | rewards.to(dtype=torch.float32),
112 | target_return.to(dtype=torch.float32),
113 | timesteps.to(dtype=torch.long),
114 | )
115 | actions[-1] = action
116 | action = action.detach().cpu().numpy()
117 |
118 | state, reward, done, _ = env.step(action)
119 |
120 | cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
121 | states = torch.cat([states, cur_state], dim=0)
122 | rewards[-1] = reward
123 |
124 | if mode != 'delayed':
125 | pred_return = target_return[0,-1] - (reward/scale)
126 | else:
127 | pred_return = target_return[0,-1]
128 | target_return = torch.cat(
129 | [target_return, pred_return.reshape(1, 1)], dim=1)
130 | timesteps = torch.cat(
131 | [timesteps,
132 | torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1)
133 |
134 | episode_return += reward
135 | episode_length += 1
136 |
137 | if done:
138 | break
139 |
140 | return episode_return, episode_length
141 |
--------------------------------------------------------------------------------
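A smoke-test sketch of `evaluate_episode_rtg`. The Hopper environment, model sizes, and normalization statistics below are illustrative stand-ins for what `experiment.py` sets up; a real evaluation would use a trained model and dataset statistics:

```python
import gym
import numpy as np
import torch

from decision_transformer.evaluation.evaluate_episodes import evaluate_episode_rtg
from decision_transformer.models.decision_transformer import DecisionTransformer

env = gym.make('Hopper-v3')
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Untrained model with illustrative sizes, just to exercise the evaluation loop.
model = DecisionTransformer(
    state_dim=state_dim, act_dim=act_dim, hidden_size=128, max_length=20,
    max_ep_len=1000, n_layer=3, n_head=1, n_inner=512,
    activation_function='relu', resid_pdrop=0.1, attn_pdrop=0.1)

# Placeholder normalization statistics; experiment.py computes these from the offline dataset.
state_mean = np.zeros(state_dim, dtype=np.float32)
state_std = np.ones(state_dim, dtype=np.float32)

with torch.no_grad():
    ret, length = evaluate_episode_rtg(
        env, state_dim, act_dim, model,
        max_ep_len=1000,
        scale=1000.,                 # rewards/returns are divided by this scale
        target_return=3600 / 1000.,  # Hopper conditioning target used in experiment.py
        state_mean=state_mean,
        state_std=state_std,
        device='cpu',
    )
print(f'return: {ret:.1f}, length: {length}')
```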
/gym/decision_transformer/models/decision_transformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | import transformers
6 |
7 | from decision_transformer.models.model import TrajectoryModel
8 | from decision_transformer.models.trajectory_gpt2 import GPT2Model
9 |
10 |
11 | class DecisionTransformer(TrajectoryModel):
12 |
13 | """
14 | This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...)
15 | """
16 |
17 | def __init__(
18 | self,
19 | state_dim,
20 | act_dim,
21 | hidden_size,
22 | max_length=None,
23 | max_ep_len=4096,
24 | action_tanh=True,
25 | **kwargs
26 | ):
27 | super().__init__(state_dim, act_dim, max_length=max_length)
28 |
29 | self.hidden_size = hidden_size
30 | config = transformers.GPT2Config(
31 | vocab_size=1, # doesn't matter -- we don't use the vocab
32 | n_embd=hidden_size,
33 | **kwargs
34 | )
35 |
36 | # note: the only difference between this GPT2Model and the default Huggingface version
37 | # is that the positional embeddings are removed (since we'll add those ourselves)
38 | self.transformer = GPT2Model(config)
39 |
40 | self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
41 | self.embed_return = torch.nn.Linear(1, hidden_size)
42 | self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
43 | self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)
44 |
45 | self.embed_ln = nn.LayerNorm(hidden_size)
46 |
47 | # note: we don't predict states or returns for the paper
48 | self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
49 | self.predict_action = nn.Sequential(
50 | *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
51 | )
52 | self.predict_return = torch.nn.Linear(hidden_size, 1)
53 |
54 | def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
55 |
56 | batch_size, seq_length = states.shape[0], states.shape[1]
57 |
58 | if attention_mask is None:
59 | # attention mask for GPT: 1 if can be attended to, 0 if not
60 |             attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long, device=states.device)
61 |
62 | # embed each modality with a different head
63 | state_embeddings = self.embed_state(states)
64 | action_embeddings = self.embed_action(actions)
65 | returns_embeddings = self.embed_return(returns_to_go)
66 | time_embeddings = self.embed_timestep(timesteps)
67 |
68 | # time embeddings are treated similar to positional embeddings
69 | state_embeddings = state_embeddings + time_embeddings
70 | action_embeddings = action_embeddings + time_embeddings
71 | returns_embeddings = returns_embeddings + time_embeddings
72 |
73 | # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
74 |         # which works nicely in an autoregressive sense since states predict actions
75 | stacked_inputs = torch.stack(
76 | (returns_embeddings, state_embeddings, action_embeddings), dim=1
77 | ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size)
78 | stacked_inputs = self.embed_ln(stacked_inputs)
79 |
80 | # to make the attention mask fit the stacked inputs, have to stack it as well
81 | stacked_attention_mask = torch.stack(
82 | (attention_mask, attention_mask, attention_mask), dim=1
83 | ).permute(0, 2, 1).reshape(batch_size, 3*seq_length)
84 |
85 | # we feed in the input embeddings (not word indices as in NLP) to the model
86 | transformer_outputs = self.transformer(
87 | inputs_embeds=stacked_inputs,
88 | attention_mask=stacked_attention_mask,
89 | )
90 | x = transformer_outputs['last_hidden_state']
91 |
92 | # reshape x so that the second dimension corresponds to the original
93 | # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
94 | x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
95 |
96 | # get predictions
97 | return_preds = self.predict_return(x[:,2]) # predict next return given state and action
98 | state_preds = self.predict_state(x[:,2]) # predict next state given state and action
99 | action_preds = self.predict_action(x[:,1]) # predict next action given state
100 |
101 | return state_preds, action_preds, return_preds
102 |
103 | def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
104 | # we don't care about the past rewards in this model
105 |
106 | states = states.reshape(1, -1, self.state_dim)
107 | actions = actions.reshape(1, -1, self.act_dim)
108 | returns_to_go = returns_to_go.reshape(1, -1, 1)
109 | timesteps = timesteps.reshape(1, -1)
110 |
111 | if self.max_length is not None:
112 | states = states[:,-self.max_length:]
113 | actions = actions[:,-self.max_length:]
114 | returns_to_go = returns_to_go[:,-self.max_length:]
115 | timesteps = timesteps[:,-self.max_length:]
116 |
117 | # pad all tokens to sequence length
118 | attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
119 | attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
120 | states = torch.cat(
121 | [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
122 | dim=1).to(dtype=torch.float32)
123 | actions = torch.cat(
124 | [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim),
125 | device=actions.device), actions],
126 | dim=1).to(dtype=torch.float32)
127 | returns_to_go = torch.cat(
128 | [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go],
129 | dim=1).to(dtype=torch.float32)
130 | timesteps = torch.cat(
131 | [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps],
132 | dim=1
133 | ).to(dtype=torch.long)
134 | else:
135 | attention_mask = None
136 |
137 | _, action_preds, return_preds = self.forward(
138 | states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)
139 |
140 | return action_preds[0,-1]
141 |
--------------------------------------------------------------------------------
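A rough sketch of the training-time interface: the model takes per-timestep returns-to-go, states, and actions, internally interleaves them into a 3·K token sequence, and returns per-modality predictions. Shapes and hyperparameters below are illustrative only:

```python
import torch

from decision_transformer.models.decision_transformer import DecisionTransformer

state_dim, act_dim, K = 11, 3, 20  # illustrative sizes; experiment.py sets the real ones
model = DecisionTransformer(
    state_dim=state_dim, act_dim=act_dim, hidden_size=128, max_length=K,
    max_ep_len=1000, n_layer=3, n_head=1, n_inner=512,
    activation_function='relu', resid_pdrop=0.1, attn_pdrop=0.1)

batch_size = 4
states = torch.randn(batch_size, K, state_dim)
actions = torch.randn(batch_size, K, act_dim)
rewards = torch.randn(batch_size, K, 1)      # accepted for interface consistency, unused by forward
returns_to_go = torch.randn(batch_size, K, 1)
timesteps = torch.randint(0, 1000, (batch_size, K))
mask = torch.ones(batch_size, K, dtype=torch.long)

state_preds, action_preds, return_preds = model(
    states, actions, rewards, returns_to_go, timesteps, attention_mask=mask)
print(action_preds.shape)  # torch.Size([4, 20, 3]) -- one action prediction per state token
```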
/gym/decision_transformer/models/mlp_bc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from decision_transformer.models.model import TrajectoryModel
6 |
7 |
8 | class MLPBCModel(TrajectoryModel):
9 |
10 | """
11 | Simple MLP that predicts next action a from past states s.
12 | """
13 |
14 | def __init__(self, state_dim, act_dim, hidden_size, n_layer, dropout=0.1, max_length=1, **kwargs):
15 | super().__init__(state_dim, act_dim)
16 |
17 | self.hidden_size = hidden_size
18 | self.max_length = max_length
19 |
20 | layers = [nn.Linear(max_length*self.state_dim, hidden_size)]
21 | for _ in range(n_layer-1):
22 | layers.extend([
23 | nn.ReLU(),
24 | nn.Dropout(dropout),
25 | nn.Linear(hidden_size, hidden_size)
26 | ])
27 | layers.extend([
28 | nn.ReLU(),
29 | nn.Dropout(dropout),
30 | nn.Linear(hidden_size, self.act_dim),
31 | nn.Tanh(),
32 | ])
33 |
34 | self.model = nn.Sequential(*layers)
35 |
36 | def forward(self, states, actions, rewards, attention_mask=None, target_return=None):
37 |
38 | states = states[:,-self.max_length:].reshape(states.shape[0], -1) # concat states
39 | actions = self.model(states).reshape(states.shape[0], 1, self.act_dim)
40 |
41 | return None, actions, None
42 |
43 | def get_action(self, states, actions, rewards, **kwargs):
44 | states = states.reshape(1, -1, self.state_dim)
45 | if states.shape[1] < self.max_length:
46 | states = torch.cat(
47 | [torch.zeros((1, self.max_length-states.shape[1], self.state_dim),
48 | dtype=torch.float32, device=states.device), states], dim=1)
49 | states = states.to(dtype=torch.float32)
50 | _, actions, _ = self.forward(states, None, None, **kwargs)
51 | return actions[0,-1]
52 |
--------------------------------------------------------------------------------
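A small usage sketch for the behavior-cloning baseline; sizes are illustrative. `get_action` zero-pads histories shorter than `max_length` before flattening them into a single MLP input:

```python
import torch

from decision_transformer.models.mlp_bc import MLPBCModel

state_dim, act_dim = 11, 3  # illustrative sizes
model = MLPBCModel(state_dim=state_dim, act_dim=act_dim,
                   hidden_size=256, n_layer=3, dropout=0.1, max_length=5)
model.eval()

states = torch.randn(2, state_dim)           # only two states observed so far; zero-padded to length 5
action = model.get_action(states, None, None)
print(action.shape)                          # torch.Size([3])
```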
/gym/decision_transformer/models/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class TrajectoryModel(nn.Module):
7 |
8 | def __init__(self, state_dim, act_dim, max_length=None):
9 | super().__init__()
10 |
11 | self.state_dim = state_dim
12 | self.act_dim = act_dim
13 | self.max_length = max_length
14 |
15 | def forward(self, states, actions, rewards, masks=None, attention_mask=None):
16 | # "masked" tokens or unspecified inputs can be passed in as None
17 | return None, None, None
18 |
19 | def get_action(self, states, actions, rewards, **kwargs):
20 | # these will come as tensors on the correct device
21 | return torch.zeros_like(actions[-1])
22 |
--------------------------------------------------------------------------------
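Both models above derive from this interface; a new trajectory model only has to provide `forward` (returning state, action, and return predictions) and `get_action`. A toy subclass, purely for illustration:

```python
import torch

from decision_transformer.models.model import TrajectoryModel


class ConstantActionModel(TrajectoryModel):
    """Toy subclass that always predicts zero actions; illustrates the interface only."""

    def forward(self, states, actions, rewards, masks=None, attention_mask=None):
        batch_size, seq_len = states.shape[0], states.shape[1]
        action_preds = torch.zeros(batch_size, seq_len, self.act_dim, device=states.device)
        return None, action_preds, None

    def get_action(self, states, actions, rewards, **kwargs):
        return torch.zeros(self.act_dim, device=states.device)


model = ConstantActionModel(state_dim=11, act_dim=3)
print(model.get_action(torch.randn(5, 11), None, None).shape)  # torch.Size([3])
```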
/gym/decision_transformer/models/trajectory_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch OpenAI GPT-2 model."""
17 |
18 | import os
19 | from dataclasses import dataclass
20 | from typing import List, Optional, Tuple
21 |
22 | import torch
23 | import torch.nn as nn
24 | from torch.nn import CrossEntropyLoss, MSELoss
25 |
26 | from transformers.activations import ACT2FN
27 | from transformers.file_utils import (
28 | ModelOutput,
29 | add_code_sample_docstrings,
30 | add_start_docstrings,
31 | add_start_docstrings_to_model_forward,
32 | replace_return_docstrings,
33 | )
34 | from transformers.modeling_outputs import (
35 | BaseModelOutputWithPastAndCrossAttentions,
36 | )
37 | from transformers.modeling_utils import (
38 | Conv1D,
39 | PreTrainedModel,
40 | SequenceSummary,
41 | find_pruneable_heads_and_indices,
42 | prune_conv1d_layer,
43 | )
44 | from transformers.utils import logging
45 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
46 | from transformers.models.gpt2.configuration_gpt2 import GPT2Config
47 |
48 | logger = logging.get_logger(__name__)
49 |
50 | _CONFIG_FOR_DOC = "GPT2Config"
51 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
52 |
53 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
54 | "gpt2",
55 | "gpt2-medium",
56 | "gpt2-large",
57 | "gpt2-xl",
58 | "distilgpt2",
59 | # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
60 | ]
61 |
62 |
63 | def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
64 | """Load tf checkpoints in a pytorch model"""
65 | try:
66 | import re
67 |
68 | import tensorflow as tf
69 | except ImportError:
70 | logger.error(
71 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
72 | "https://www.tensorflow.org/install/ for installation instructions."
73 | )
74 | raise
75 | tf_path = os.path.abspath(gpt2_checkpoint_path)
76 | logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
77 | # Load weights from TF model
78 | init_vars = tf.train.list_variables(tf_path)
79 | names = []
80 | arrays = []
81 | for name, shape in init_vars:
82 | logger.info("Loading TF weight {} with shape {}".format(name, shape))
83 | array = tf.train.load_variable(tf_path, name)
84 | names.append(name)
85 | arrays.append(array.squeeze())
86 |
87 | for name, array in zip(names, arrays):
88 | name = name[6:] # skip "model/"
89 | name = name.split("/")
90 | pointer = model
91 | for m_name in name:
92 | if re.fullmatch(r"[A-Za-z]+\d+", m_name):
93 | scope_names = re.split(r"(\d+)", m_name)
94 | else:
95 | scope_names = [m_name]
96 | if scope_names[0] == "w" or scope_names[0] == "g":
97 | pointer = getattr(pointer, "weight")
98 | elif scope_names[0] == "b":
99 | pointer = getattr(pointer, "bias")
100 | elif scope_names[0] == "wpe" or scope_names[0] == "wte":
101 | pointer = getattr(pointer, scope_names[0])
102 | pointer = getattr(pointer, "weight")
103 | else:
104 | pointer = getattr(pointer, scope_names[0])
105 | if len(scope_names) >= 2:
106 | num = int(scope_names[1])
107 | pointer = pointer[num]
108 | try:
109 | assert (
110 | pointer.shape == array.shape
111 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
112 | except AssertionError as e:
113 | e.args += (pointer.shape, array.shape)
114 | raise
115 | logger.info("Initialize PyTorch weight {}".format(name))
116 | pointer.data = torch.from_numpy(array)
117 | return model
118 |
119 |
120 | class Attention(nn.Module):
121 | def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
122 | super().__init__()
123 |
124 | n_state = nx # in Attention: n_state=768 (nx=n_embd)
125 | # [switch nx => n_state from Block to Attention to keep identical to TF implem]
126 | assert n_state % config.n_head == 0
127 | self.register_buffer(
128 | "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
129 | )
130 | self.register_buffer("masked_bias", torch.tensor(-1e4))
131 | self.n_head = config.n_head
132 | self.split_size = n_state
133 | self.scale = scale
134 | self.is_cross_attention = is_cross_attention
135 | if self.is_cross_attention:
136 | self.c_attn = Conv1D(2 * n_state, nx)
137 | self.q_attn = Conv1D(n_state, nx)
138 | else:
139 | self.c_attn = Conv1D(3 * n_state, nx)
140 | self.c_proj = Conv1D(n_state, nx)
141 | self.attn_dropout = nn.Dropout(config.attn_pdrop)
142 | self.resid_dropout = nn.Dropout(config.resid_pdrop)
143 | self.pruned_heads = set()
144 |
145 | def prune_heads(self, heads):
146 | if len(heads) == 0:
147 | return
148 | heads, index = find_pruneable_heads_and_indices(
149 | heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
150 | )
151 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
152 |
153 | # Prune conv1d layers
154 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
155 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
156 |
157 | # Update hyper params
158 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
159 | self.n_head = self.n_head - len(heads)
160 | self.pruned_heads = self.pruned_heads.union(heads)
161 |
162 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
163 | w = torch.matmul(q, k)
164 | if self.scale:
165 | w = w / (float(v.size(-1)) ** 0.5)
166 | nd, ns = w.size(-2), w.size(-1)
167 |
168 | if not self.is_cross_attention:
169 |             # only the "normal" attention layer implements the causal mask
170 | mask = self.bias[:, :, ns - nd: ns, :ns]
171 | w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
172 |
173 | if attention_mask is not None:
174 | # Apply the attention mask
175 | w = w + attention_mask
176 |
177 | w = nn.Softmax(dim=-1)(w)
178 | w = self.attn_dropout(w)
179 |
180 | # Mask heads if we want to
181 | if head_mask is not None:
182 | w = w * head_mask
183 |
184 | outputs = [torch.matmul(w, v)]
185 | if output_attentions:
186 | outputs.append(w)
187 | return outputs
188 |
189 | def merge_heads(self, x):
190 | x = x.permute(0, 2, 1, 3).contiguous()
191 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
192 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
193 |
194 | def split_heads(self, x, k=False):
195 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
196 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
197 | if k:
198 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
199 | else:
200 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
201 |
202 | def forward(
203 | self,
204 | hidden_states,
205 | layer_past=None,
206 | attention_mask=None,
207 | head_mask=None,
208 | encoder_hidden_states=None,
209 | encoder_attention_mask=None,
210 | use_cache=False,
211 | output_attentions=False,
212 | ):
213 | if encoder_hidden_states is not None:
214 | assert hasattr(
215 | self, "q_attn"
216 | ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
217 | query = self.q_attn(hidden_states)
218 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
219 | attention_mask = encoder_attention_mask
220 | else:
221 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
222 |
223 | query = self.split_heads(query)
224 | key = self.split_heads(key, k=True)
225 | value = self.split_heads(value)
226 | if layer_past is not None:
227 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
228 | key = torch.cat((past_key, key), dim=-1)
229 | value = torch.cat((past_value, value), dim=-2)
230 |
231 | if use_cache is True:
232 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
233 | else:
234 | present = (None,)
235 |
236 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
237 | a = attn_outputs[0]
238 |
239 | a = self.merge_heads(a)
240 | a = self.c_proj(a)
241 | a = self.resid_dropout(a)
242 |
243 | outputs = [a, present] + attn_outputs[1:]
244 | return outputs # a, present, (attentions)
245 |
246 |
247 | class MLP(nn.Module):
248 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
249 | super().__init__()
250 | nx = config.n_embd
251 | self.c_fc = Conv1D(n_state, nx)
252 | self.c_proj = Conv1D(nx, n_state)
253 | self.act = ACT2FN[config.activation_function]
254 | self.dropout = nn.Dropout(config.resid_pdrop)
255 |
256 | def forward(self, x):
257 | h = self.act(self.c_fc(x))
258 | h2 = self.c_proj(h)
259 | return self.dropout(h2)
260 |
261 |
262 | class AdapterMLP(nn.Module):
263 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
264 | super().__init__()
265 | nx = config.n_embd
266 | self.c_fc = Conv1D(n_state, nx)
267 | self.c_proj = Conv1D(nx, n_state)
268 | self.act = ACT2FN[config.activation_function]
269 | self.dropout = nn.Dropout(config.resid_pdrop)
270 |
271 | def forward(self, x):
272 | h = self.act(self.c_fc(x))
273 | h2 = self.c_proj(h)
274 | return self.dropout(h2)
275 |
276 |
277 | class Block(nn.Module):
278 | def __init__(self, n_ctx, config, scale=False):
279 | super().__init__()
280 | hidden_size = config.n_embd
281 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
282 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
283 | self.attn = Attention(hidden_size, n_ctx, config, scale)
284 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
285 | # self.adapter_ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
286 | if config.add_cross_attention:
287 | self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
288 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
289 | self.mlp = MLP(inner_dim, config)
290 | # self.adapter_mlp = AdapterMLP(512, config) # ADAPTER
291 |
292 | def forward(
293 | self,
294 | hidden_states,
295 | layer_past=None,
296 | attention_mask=None,
297 | head_mask=None,
298 | encoder_hidden_states=None,
299 | encoder_attention_mask=None,
300 | use_cache=False,
301 | output_attentions=False,
302 | ):
303 | attn_outputs = self.attn(
304 | self.ln_1(hidden_states),
305 | layer_past=layer_past,
306 | attention_mask=attention_mask,
307 | head_mask=head_mask,
308 | use_cache=use_cache,
309 | output_attentions=output_attentions,
310 | )
311 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
312 | outputs = attn_outputs[1:]
313 | # residual connection
314 | hidden_states = attn_output + hidden_states
315 |
316 | if encoder_hidden_states is not None:
317 | # add one self-attention block for cross-attention
318 | assert hasattr(
319 | self, "crossattention"
320 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
321 | cross_attn_outputs = self.crossattention(
322 | self.ln_cross_attn(hidden_states),
323 | attention_mask=attention_mask,
324 | head_mask=head_mask,
325 | encoder_hidden_states=encoder_hidden_states,
326 | encoder_attention_mask=encoder_attention_mask,
327 | output_attentions=output_attentions,
328 | )
329 | attn_output = cross_attn_outputs[0]
330 | # residual connection
331 | hidden_states = hidden_states + attn_output
332 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
333 |
334 | feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
335 | # residual connection
336 | hidden_states = hidden_states + feed_forward_hidden_states
337 | # hidden_states = hidden_states + self.adapter_ln(self.adapter_mlp(hidden_states))
338 |
339 | outputs = [hidden_states] + outputs
340 | return outputs # hidden_states, present, (attentions, cross_attentions)
341 |
342 |
343 | class GPT2PreTrainedModel(PreTrainedModel):
344 | """
345 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
346 | models.
347 | """
348 |
349 | config_class = GPT2Config
350 | load_tf_weights = load_tf_weights_in_gpt2
351 | base_model_prefix = "transformer"
352 |
353 | def __init__(self, *inputs, **kwargs):
354 | super().__init__(*inputs, **kwargs)
355 |
356 | def _init_weights(self, module):
357 | """Initialize the weights."""
358 | if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
359 | # Slightly different from the TF version which uses truncated_normal for initialization
360 | # cf https://github.com/pytorch/pytorch/pull/5617
361 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
362 | if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
363 | module.bias.data.zero_()
364 | elif isinstance(module, nn.LayerNorm):
365 | module.bias.data.zero_()
366 | module.weight.data.fill_(1.0)
367 | # module.weight.data.fill_(.01) # KL: Adapter change
368 |
369 |
370 | @dataclass
371 | class GPT2DoubleHeadsModelOutput(ModelOutput):
372 | """
373 | Base class for outputs of models predicting if two sentences are consecutive or not.
374 | Args:
375 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
376 | Language modeling loss.
377 | mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
378 | Multiple choice classification loss.
379 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
380 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
381 | mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
382 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
383 | past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
384 | List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
385 | batch_size, num_heads, sequence_length, embed_size_per_head)`).
386 | Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
387 | :obj:`past_key_values` input) to speed up sequential decoding.
388 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
389 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
390 | of shape :obj:`(batch_size, sequence_length, hidden_size)`.
391 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
392 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
393 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
394 | sequence_length, sequence_length)`.
395 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
396 | heads.
397 | """
398 |
399 | loss: Optional[torch.FloatTensor] = None
400 | mc_loss: Optional[torch.FloatTensor] = None
401 | logits: torch.FloatTensor = None
402 | mc_logits: torch.FloatTensor = None
403 | past_key_values: Optional[List[torch.FloatTensor]] = None
404 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
405 | attentions: Optional[Tuple[torch.FloatTensor]] = None
406 |
407 |
408 | GPT2_START_DOCSTRING = r"""
409 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
410 |     methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
411 | pruning heads etc.)
412 |     This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
413 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
414 | general usage and behavior.
415 | Parameters:
416 | config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
417 | Initializing with a config file does not load the weights associated with the model, only the
418 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
419 | weights.
420 | """
421 |
422 | GPT2_INPUTS_DOCSTRING = r"""
423 | Args:
424 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
425 | :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
426 | ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
427 | sequence tokens in the vocabulary.
428 | If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
429 | passed as ``input_ids``.
430 | Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
431 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
432 | details.
433 | `What are input IDs? <../glossary.html#input-ids>`__
434 | past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
435 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
436 | :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
437 | have their past given to this model should not be passed as ``input_ids`` as they have already been
438 | computed.
439 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
440 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
441 | - 1 for tokens that are **not masked**,
442 | - 0 for tokens that are **masked**.
443 | `What are attention masks? <../glossary.html#attention-mask>`__
444 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
445 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
446 | 1]``:
447 | - 0 corresponds to a `sentence A` token,
448 | - 1 corresponds to a `sentence B` token.
449 | `What are token type IDs? <../glossary.html#token-type-ids>`_
450 | position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
451 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
452 | config.max_position_embeddings - 1]``.
453 | `What are position IDs? <../glossary.html#position-ids>`_
454 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
455 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
456 | - 1 indicates the head is **not masked**,
457 | - 0 indicates the head is **masked**.
458 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
459 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
460 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
461 | vectors than the model's internal embedding lookup matrix.
462 | If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
463 | :obj:`past_key_values`).
464 | use_cache (:obj:`bool`, `optional`):
465 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
466 | decoding (see :obj:`past_key_values`).
467 | output_attentions (:obj:`bool`, `optional`):
468 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
469 | tensors for more detail.
470 | output_hidden_states (:obj:`bool`, `optional`):
471 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
472 | more detail.
473 | return_dict (:obj:`bool`, `optional`):
474 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
475 | """
476 | PARALLELIZE_DOCSTRING = r"""
477 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
478 | it will evenly distribute blocks across all devices.
479 | Args:
480 | device_map (:obj:`Dict[int, list]`, optional, defaults to None):
481 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
482 | automatically mapped to the first device (for esoteric reasons). That means that the first device should
483 | have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
484 | following number of attention modules:
485 | - gpt2: 12
486 | - gpt2-medium: 24
487 | - gpt2-large: 36
488 | - gpt2-xl: 48
489 | Example::
490 | # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
491 | model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
492 | device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
493 | 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
494 | 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
495 | 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
496 | model.parallelize(device_map)
497 | """
498 | DEPARALLELIZE_DOCSTRING = r"""
499 | Moves the model to cpu from a model parallel state.
500 | Example::
501 | # On a 4 GPU machine with gpt2-large:
502 | model = GPT2LMHeadModel.from_pretrained('gpt2-large')
503 | device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
504 | 1: [8, 9, 10, 11, 12, 13, 14, 15],
505 | 2: [16, 17, 18, 19, 20, 21, 22, 23],
506 | 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
507 | model.parallelize(device_map) # Splits the model across several devices
508 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
509 | """
510 |
511 |
512 | @add_start_docstrings(
513 | "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
514 | GPT2_START_DOCSTRING,
515 | )
516 | class GPT2Model(GPT2PreTrainedModel):
517 | def __init__(self, config):
518 | super().__init__(config)
519 |
520 | self.wte = nn.Embedding(config.vocab_size, config.n_embd)
521 | # self.wpe = nn.Embedding(config.n_positions, config.n_embd)
522 | self.drop = nn.Dropout(config.embd_pdrop)
523 | self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
524 | self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
525 |
526 | self.init_weights()
527 | # Model parallel
528 | self.model_parallel = False
529 | self.device_map = None
530 |
531 | self.use_layers = None
532 |
533 | def set_layers(self, num_layers):
534 |         if num_layers is not None:
535 |             assert 1 <= num_layers <= len(self.h)
536 |             num_layers -= 1
537 | self.use_layers = num_layers
538 |
539 | @add_start_docstrings(PARALLELIZE_DOCSTRING)
540 | def parallelize(self, device_map=None):
541 | # Check validity of device_map
542 | self.device_map = (
543 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
544 | )
545 | assert_device_map(self.device_map, len(self.h))
546 | self.model_parallel = True
547 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
548 | self.last_device = "cuda:" + str(max(self.device_map.keys()))
549 | self.wte = self.wte.to(self.first_device)
550 |         # self.wpe = self.wpe.to(self.first_device)  # positional embeddings (wpe) are removed in this variant
551 | # Load onto devices
552 | for k, v in self.device_map.items():
553 | for block in v:
554 | cuda_device = "cuda:" + str(k)
555 | self.h[block] = self.h[block].to(cuda_device)
556 | # ln_f to last
557 | self.ln_f = self.ln_f.to(self.last_device)
558 |
559 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
560 | def deparallelize(self):
561 | self.model_parallel = False
562 | self.device_map = None
563 | self.first_device = "cpu"
564 | self.last_device = "cpu"
565 | self.wte = self.wte.to("cpu")
566 | self.wpe = self.wpe.to("cpu")
567 | for index in range(len(self.h)):
568 | self.h[index] = self.h[index].to("cpu")
569 | self.ln_f = self.ln_f.to("cpu")
570 | torch.cuda.empty_cache()
571 |
572 | def get_input_embeddings(self):
573 | return self.wte
574 |
575 | def set_input_embeddings(self, new_embeddings):
576 | self.wte = new_embeddings
577 |
578 | def _prune_heads(self, heads_to_prune):
579 | """
580 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
581 | """
582 | for layer, heads in heads_to_prune.items():
583 | self.h[layer].attn.prune_heads(heads)
584 |
585 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
586 | @add_code_sample_docstrings(
587 | tokenizer_class=_TOKENIZER_FOR_DOC,
588 | checkpoint="gpt2",
589 | output_type=BaseModelOutputWithPastAndCrossAttentions,
590 | config_class=_CONFIG_FOR_DOC,
591 | )
592 | def forward(
593 | self,
594 | input_ids=None,
595 | past_key_values=None,
596 | attention_mask=None,
597 | token_type_ids=None,
598 | position_ids=None,
599 | head_mask=None,
600 | inputs_embeds=None,
601 | encoder_hidden_states=None,
602 | encoder_attention_mask=None,
603 | use_cache=None,
604 | output_attentions=None,
605 | output_hidden_states=None,
606 | return_dict=None,
607 | ):
608 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
609 | output_hidden_states = (
610 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
611 | )
612 | use_cache = use_cache if use_cache is not None else self.config.use_cache
613 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
614 |
615 | if input_ids is not None and inputs_embeds is not None:
616 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
617 | elif input_ids is not None:
618 | input_shape = input_ids.size()
619 | input_ids = input_ids.view(-1, input_shape[-1])
620 | batch_size = input_ids.shape[0]
621 | elif inputs_embeds is not None:
622 | input_shape = inputs_embeds.size()[:-1]
623 | batch_size = inputs_embeds.shape[0]
624 | else:
625 | raise ValueError("You have to specify either input_ids or inputs_embeds")
626 |
627 | if token_type_ids is not None:
628 | token_type_ids = token_type_ids.view(-1, input_shape[-1])
629 | if position_ids is not None:
630 | position_ids = position_ids.view(-1, input_shape[-1])
631 |
632 | if past_key_values is None:
633 | past_length = 0
634 | past_key_values = [None] * len(self.h)
635 | else:
636 | past_length = past_key_values[0][0].size(-2)
637 | if position_ids is None:
638 | device = input_ids.device if input_ids is not None else inputs_embeds.device
639 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
640 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
641 |
642 | # Attention mask.
643 | if attention_mask is not None:
644 | assert batch_size > 0, "batch_size has to be defined and > 0"
645 | attention_mask = attention_mask.view(batch_size, -1)
646 | # We create a 3D attention mask from a 2D tensor mask.
647 | # Sizes are [batch_size, 1, 1, to_seq_length]
648 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
649 |             # this attention mask is simpler than the triangular masking of causal attention
650 |             # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
651 | attention_mask = attention_mask[:, None, None, :]
652 |
653 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
654 | # masked positions, this operation will create a tensor which is 0.0 for
655 | # positions we want to attend and -10000.0 for masked positions.
656 | # Since we are adding it to the raw scores before the softmax, this is
657 | # effectively the same as removing these entirely.
658 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
659 | attention_mask = (1.0 - attention_mask) * -10000.0
660 |
661 |         # If a 2D or 3D attention mask is provided for the cross-attention,
662 |         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
663 | if self.config.add_cross_attention and encoder_hidden_states is not None:
664 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
665 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
666 | if encoder_attention_mask is None:
667 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
668 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
669 | else:
670 | encoder_attention_mask = None
671 |
672 | # Prepare head mask if needed
673 | # 1.0 in head_mask indicate we keep the head
674 | # attention_probs has shape bsz x n_heads x N x N
675 | # head_mask has shape n_layer x batch x n_heads x N x N
676 | head_mask = self.get_head_mask(head_mask, self.config.n_layer)
677 |
678 | if inputs_embeds is None:
679 | inputs_embeds = self.wte(input_ids)
680 | # position_embeds = self.wpe(position_ids)
681 | hidden_states = inputs_embeds # + position_embeds
682 |
683 | if token_type_ids is not None:
684 | token_type_embeds = self.wte(token_type_ids)
685 | hidden_states = hidden_states + token_type_embeds
686 |
687 | hidden_states = self.drop(hidden_states)
688 |
689 | output_shape = input_shape + (hidden_states.size(-1),)
690 |
691 | presents = () if use_cache else None
692 | all_self_attentions = () if output_attentions else None
693 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
694 | all_hidden_states = () if output_hidden_states else None
695 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
696 |
697 | if self.use_layers is not None and i >= self.use_layers:
698 | break
699 |
700 | # Model parallel
701 | if self.model_parallel:
702 | torch.cuda.set_device(hidden_states.device)
703 | # Ensure layer_past is on same device as hidden_states (might not be correct)
704 | if layer_past is not None:
705 | layer_past = layer_past.to(hidden_states.device)
706 | # Ensure that attention_mask is always on the same device as hidden_states
707 | if attention_mask is not None:
708 | attention_mask = attention_mask.to(hidden_states.device)
709 | if isinstance(head_mask, torch.Tensor):
710 | head_mask = head_mask.to(hidden_states.device)
711 | if output_hidden_states:
712 | all_hidden_states = all_hidden_states + (hidden_states,)
713 |
714 | if getattr(self.config, "gradient_checkpointing", False):
715 |
716 | def create_custom_forward(module):
717 | def custom_forward(*inputs):
718 | # checkpointing only works with tuple returns, not with lists
719 | return tuple(output for output in module(*inputs, use_cache, output_attentions))
720 |
721 | return custom_forward
722 |
723 | outputs = torch.utils.checkpoint.checkpoint(
724 | create_custom_forward(block),
725 | hidden_states,
726 | layer_past,
727 | attention_mask,
728 | head_mask[i],
729 | encoder_hidden_states,
730 | encoder_attention_mask,
731 | )
732 | else:
733 | outputs = block(
734 | hidden_states,
735 | layer_past=layer_past,
736 | attention_mask=attention_mask,
737 | head_mask=head_mask[i],
738 | encoder_hidden_states=encoder_hidden_states,
739 | encoder_attention_mask=encoder_attention_mask,
740 | use_cache=use_cache,
741 | output_attentions=output_attentions,
742 | )
743 |
744 | hidden_states, present = outputs[:2]
745 | if use_cache is True:
746 | presents = presents + (present,)
747 |
748 | if output_attentions:
749 | all_self_attentions = all_self_attentions + (outputs[2],)
750 | if self.config.add_cross_attention:
751 | all_cross_attentions = all_cross_attentions + (outputs[3],)
752 |
753 | # Model Parallel: If it's the last layer for that device, put things on the next device
754 | if self.model_parallel:
755 | for k, v in self.device_map.items():
756 | if i == v[-1] and "cuda:" + str(k) != self.last_device:
757 | hidden_states = hidden_states.to("cuda:" + str(k + 1))
758 |
759 | hidden_states = self.ln_f(hidden_states)
760 |
761 | hidden_states = hidden_states.view(*output_shape)
762 | # Add last hidden state
763 | if output_hidden_states:
764 | all_hidden_states = all_hidden_states + (hidden_states,)
765 |
766 | if not return_dict:
767 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
768 |
769 | return BaseModelOutputWithPastAndCrossAttentions(
770 | last_hidden_state=hidden_states,
771 | past_key_values=presents,
772 | hidden_states=all_hidden_states,
773 | attentions=all_self_attentions,
774 | cross_attentions=all_cross_attentions,
775 | )
776 |
--------------------------------------------------------------------------------
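As the note in `decision_transformer.py` says, this file is the stock Hugging Face GPT-2 implementation with the learned positional embeddings (`wpe`) disabled, so position information must be supplied through the input embeddings themselves. A sketch of using it directly with `inputs_embeds` (sizes are illustrative):

```python
import torch
import transformers

from decision_transformer.models.trajectory_gpt2 import GPT2Model

# Small illustrative config; DecisionTransformer builds one like this with vocab_size=1.
config = transformers.GPT2Config(vocab_size=1, n_embd=64, n_layer=2, n_head=1)
gpt2 = GPT2Model(config)

batch_size, seq_len = 4, 60  # e.g. 3 * K interleaved (return, state, action) tokens
inputs_embeds = torch.randn(batch_size, seq_len, 64)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

out = gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
print(out['last_hidden_state'].shape)  # torch.Size([4, 60, 64]); no positional embeddings were added
```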
/gym/decision_transformer/training/act_trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from decision_transformer.training.trainer import Trainer
5 |
6 |
7 | class ActTrainer(Trainer):
8 |
9 | def train_step(self):
10 | states, actions, rewards, dones, rtg, _, attention_mask = self.get_batch(self.batch_size)
11 | state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards)
12 |
13 | state_preds, action_preds, reward_preds = self.model.forward(
14 | states, actions, rewards, attention_mask=attention_mask, target_return=rtg[:,0],
15 | )
16 |
17 | act_dim = action_preds.shape[2]
18 | action_preds = action_preds.reshape(-1, act_dim)
19 | action_target = action_target[:,-1].reshape(-1, act_dim)
20 |
21 | loss = self.loss_fn(
22 | state_preds, action_preds, reward_preds,
23 | state_target, action_target, reward_target,
24 | )
25 | self.optimizer.zero_grad()
26 | loss.backward()
27 | self.optimizer.step()
28 |
29 | return loss.detach().cpu().item()
30 |
--------------------------------------------------------------------------------
/gym/decision_transformer/training/seq_trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from decision_transformer.training.trainer import Trainer
5 |
6 |
7 | class SequenceTrainer(Trainer):
8 |
9 | def train_step(self):
10 | states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size)
11 | action_target = torch.clone(actions)
12 |
13 | state_preds, action_preds, reward_preds = self.model.forward(
14 | states, actions, rewards, rtg[:,:-1], timesteps, attention_mask=attention_mask,
15 | )
16 |
17 | act_dim = action_preds.shape[2]
18 | action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]
19 | action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]
20 |
21 | loss = self.loss_fn(
22 | None, action_preds, None,
23 | None, action_target, None,
24 | )
25 |
26 | self.optimizer.zero_grad()
27 | loss.backward()
28 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
29 | self.optimizer.step()
30 |
31 | with torch.no_grad():
32 | self.diagnostics['training/action_error'] = torch.mean((action_preds-action_target)**2).detach().cpu().item()
33 |
34 | return loss.detach().cpu().item()
35 |
--------------------------------------------------------------------------------
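`SequenceTrainer` passes six arguments to `loss_fn` but supplies `None` for everything except the action terms, so the loss only needs to compare predicted and target actions. A sketch of a compatible loss function (mean-squared action error, an assumption about how the continuous-control setup in `experiment.py` is wired):

```python
import torch

# Signature expected by Trainer/SequenceTrainer:
# loss_fn(state_preds, action_preds, reward_preds, state_target, action_target, reward_target)
def action_mse_loss(s_hat, a_hat, r_hat, s, a, r):
    return torch.mean((a_hat - a) ** 2)
```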
/gym/decision_transformer/training/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | import time
5 |
6 |
7 | class Trainer:
8 |
9 | def __init__(self, model, optimizer, batch_size, get_batch, loss_fn, scheduler=None, eval_fns=None):
10 | self.model = model
11 | self.optimizer = optimizer
12 | self.batch_size = batch_size
13 | self.get_batch = get_batch
14 | self.loss_fn = loss_fn
15 | self.scheduler = scheduler
16 | self.eval_fns = [] if eval_fns is None else eval_fns
17 | self.diagnostics = dict()
18 |
19 | self.start_time = time.time()
20 |
21 | def train_iteration(self, num_steps, iter_num=0, print_logs=False):
22 |
23 | train_losses = []
24 | logs = dict()
25 |
26 | train_start = time.time()
27 |
28 | self.model.train()
29 | for _ in range(num_steps):
30 | train_loss = self.train_step()
31 | train_losses.append(train_loss)
32 | if self.scheduler is not None:
33 | self.scheduler.step()
34 |
35 | logs['time/training'] = time.time() - train_start
36 |
37 | eval_start = time.time()
38 |
39 | self.model.eval()
40 | for eval_fn in self.eval_fns:
41 | outputs = eval_fn(self.model)
42 | for k, v in outputs.items():
43 | logs[f'evaluation/{k}'] = v
44 |
45 | logs['time/total'] = time.time() - self.start_time
46 | logs['time/evaluation'] = time.time() - eval_start
47 | logs['training/train_loss_mean'] = np.mean(train_losses)
48 | logs['training/train_loss_std'] = np.std(train_losses)
49 |
50 | for k in self.diagnostics:
51 | logs[k] = self.diagnostics[k]
52 |
53 | if print_logs:
54 | print('=' * 80)
55 | print(f'Iteration {iter_num}')
56 | for k, v in logs.items():
57 | print(f'{k}: {v}')
58 |
59 | return logs
60 |
61 | def train_step(self):
62 | states, actions, rewards, dones, attention_mask, returns = self.get_batch(self.batch_size)
63 | state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards)
64 |
65 | state_preds, action_preds, reward_preds = self.model.forward(
66 | states, actions, rewards, masks=None, attention_mask=attention_mask, target_return=returns,
67 | )
68 |
69 | # note: currently indexing & masking is not fully correct
70 | loss = self.loss_fn(
71 | state_preds, action_preds, reward_preds,
72 | state_target[:,1:], action_target, reward_target[:,1:],
73 | )
74 | self.optimizer.zero_grad()
75 | loss.backward()
76 | self.optimizer.step()
77 |
78 | return loss.detach().cpu().item()
79 |
--------------------------------------------------------------------------------
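A self-contained sketch of wiring the pieces above together. The random `get_batch` below is a stand-in for the trajectory sampler that `experiment.py` defines, and all sizes and hyperparameters are illustrative:

```python
import torch

from decision_transformer.models.decision_transformer import DecisionTransformer
from decision_transformer.training.seq_trainer import SequenceTrainer

# Toy dimensions, chosen only to make the sketch self-contained.
state_dim, act_dim, K, batch_size = 11, 3, 20, 8

model = DecisionTransformer(
    state_dim=state_dim, act_dim=act_dim, hidden_size=64, max_length=K,
    max_ep_len=1000, n_layer=2, n_head=1, n_inner=256,
    activation_function='relu', resid_pdrop=0.1, attn_pdrop=0.1)

def get_batch(batch_size=batch_size):
    # Random stand-in for the dataset sampler in experiment.py.
    s = torch.randn(batch_size, K, state_dim)
    a = torch.randn(batch_size, K, act_dim)
    r = torch.randn(batch_size, K, 1)
    d = torch.zeros(batch_size, K, dtype=torch.long)
    rtg = torch.randn(batch_size, K + 1, 1)      # one extra step, since train_step uses rtg[:, :-1]
    timesteps = torch.randint(0, 1000, (batch_size, K))
    mask = torch.ones(batch_size, K, dtype=torch.long)
    return s, a, r, d, rtg, timesteps, mask

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda steps: min((steps + 1) / 10000, 1))  # linear warmup, as an example

trainer = SequenceTrainer(
    model=model,
    optimizer=optimizer,
    batch_size=batch_size,
    get_batch=get_batch,
    loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2),
    scheduler=scheduler,
)
logs = trainer.train_iteration(num_steps=5, iter_num=1, print_logs=True)
```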
/gym/experiment.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import torch
4 | import wandb
5 |
6 | import argparse
7 | import pickle
8 | import random
9 | import sys
10 |
11 | from decision_transformer.evaluation.evaluate_episodes import evaluate_episode, evaluate_episode_rtg
12 | from decision_transformer.models.decision_transformer import DecisionTransformer
13 | from decision_transformer.models.mlp_bc import MLPBCModel
14 | from decision_transformer.training.act_trainer import ActTrainer
15 | from decision_transformer.training.seq_trainer import SequenceTrainer
16 |
17 |
18 | def discount_cumsum(x, gamma):
19 | discount_cumsum = np.zeros_like(x)
20 | discount_cumsum[-1] = x[-1]
21 | for t in reversed(range(x.shape[0]-1)):
22 | discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1]
23 | return discount_cumsum
24 |
25 |
26 | def experiment(
27 | exp_prefix,
28 | variant,
29 | ):
30 | device = variant.get('device', 'cuda')
31 | log_to_wandb = variant.get('log_to_wandb', False)
32 |
33 | env_name, dataset = variant['env'], variant['dataset']
34 | model_type = variant['model_type']
35 | group_name = f'{exp_prefix}-{env_name}-{dataset}'
36 | exp_prefix = f'{group_name}-{random.randint(int(1e5), int(1e6) - 1)}'
37 |
38 | if env_name == 'hopper':
39 | env = gym.make('Hopper-v3')
40 | max_ep_len = 1000
41 | env_targets = [3600, 1800] # evaluation conditioning targets
42 | scale = 1000. # normalization for rewards/returns
43 | elif env_name == 'halfcheetah':
44 | env = gym.make('HalfCheetah-v3')
45 | max_ep_len = 1000
46 | env_targets = [12000, 6000]
47 | scale = 1000.
48 | elif env_name == 'walker2d':
49 | env = gym.make('Walker2d-v3')
50 | max_ep_len = 1000
51 | env_targets = [5000, 2500]
52 | scale = 1000.
53 | elif env_name == 'reacher2d':
54 | from decision_transformer.envs.reacher_2d import Reacher2dEnv
55 | env = Reacher2dEnv()
56 | max_ep_len = 100
57 | env_targets = [76, 40]
58 | scale = 10.
59 | else:
60 | raise NotImplementedError
61 |
62 | if model_type == 'bc':
63 | env_targets = env_targets[:1] # since BC ignores target, no need for different evaluations
64 |
65 | state_dim = env.observation_space.shape[0]
66 | act_dim = env.action_space.shape[0]
67 |
68 | # load dataset
69 | dataset_path = f'data/{env_name}-{dataset}-v2.pkl'
70 | with open(dataset_path, 'rb') as f:
71 | trajectories = pickle.load(f)
72 |
73 | # save all path information into separate lists
74 | mode = variant.get('mode', 'normal')
75 | states, traj_lens, returns = [], [], []
76 | for path in trajectories:
77 | if mode == 'delayed': # delayed: all rewards moved to end of trajectory
78 | path['rewards'][-1] = path['rewards'].sum()
79 | path['rewards'][:-1] = 0.
80 | states.append(path['observations'])
81 | traj_lens.append(len(path['observations']))
82 | returns.append(path['rewards'].sum())
83 | traj_lens, returns = np.array(traj_lens), np.array(returns)
84 |
85 | # used for input normalization
86 | states = np.concatenate(states, axis=0)
87 | state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
88 |
89 | num_timesteps = sum(traj_lens)
90 |
91 | print('=' * 50)
92 | print(f'Starting new experiment: {env_name} {dataset}')
93 | print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
94 | print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
95 | print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
96 | print('=' * 50)
97 |
98 | K = variant['K']
99 | batch_size = variant['batch_size']
100 | num_eval_episodes = variant['num_eval_episodes']
101 | pct_traj = variant.get('pct_traj', 1.)
102 |
103 | # only train on top pct_traj trajectories (for %BC experiment)
104 | num_timesteps = max(int(pct_traj*num_timesteps), 1)
105 | sorted_inds = np.argsort(returns) # lowest to highest
106 | num_trajectories = 1
107 | timesteps = traj_lens[sorted_inds[-1]]
108 | ind = len(trajectories) - 2
109 | while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
110 | timesteps += traj_lens[sorted_inds[ind]]
111 | num_trajectories += 1
112 | ind -= 1
113 | sorted_inds = sorted_inds[-num_trajectories:]
114 |
115 | # used to reweight sampling so we sample according to timesteps instead of trajectories
116 | p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])
117 |
118 | def get_batch(batch_size=256, max_len=K):
119 | batch_inds = np.random.choice(
120 | np.arange(num_trajectories),
121 | size=batch_size,
122 | replace=True,
123 | p=p_sample, # reweights so we sample according to timesteps
124 | )
125 |
126 | s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
127 | for i in range(batch_size):
128 | traj = trajectories[int(sorted_inds[batch_inds[i]])]
129 | si = random.randint(0, traj['rewards'].shape[0] - 1)
130 |
131 | # get sequences from dataset
132 | s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim))
133 | a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))
134 | r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))
135 | if 'terminals' in traj:
136 | d.append(traj['terminals'][si:si + max_len].reshape(1, -1))
137 | else:
138 | d.append(traj['dones'][si:si + max_len].reshape(1, -1))
139 | timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
140 | timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len-1 # padding cutoff
141 | rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1))
142 | if rtg[-1].shape[1] <= s[-1].shape[1]:
143 | rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)
144 |
145 | # padding and state + reward normalization
146 | tlen = s[-1].shape[1]
147 | s[-1] = np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1)
148 | s[-1] = (s[-1] - state_mean) / state_std
149 | a[-1] = np.concatenate([np.ones((1, max_len - tlen, act_dim)) * -10., a[-1]], axis=1)  # pad actions with a -10 placeholder (padded steps are masked out)
150 | r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1)
151 | d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1)  # pad done flags with a sentinel value of 2
152 | rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1) / scale
153 | timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)
154 | mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))
155 |
156 | s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
157 | a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
158 | r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device)
159 | d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device)
160 | rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
161 | timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
162 | mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)
163 |
164 | return s, a, r, d, rtg, timesteps, mask
165 |
166 | def eval_episodes(target_rew):
167 | def fn(model):
168 | returns, lengths = [], []
169 | for _ in range(num_eval_episodes):
170 | with torch.no_grad():
171 | if model_type == 'dt':
172 | ret, length = evaluate_episode_rtg(
173 | env,
174 | state_dim,
175 | act_dim,
176 | model,
177 | max_ep_len=max_ep_len,
178 | scale=scale,
179 | target_return=target_rew/scale,
180 | mode=mode,
181 | state_mean=state_mean,
182 | state_std=state_std,
183 | device=device,
184 | )
185 | else:
186 | ret, length = evaluate_episode(
187 | env,
188 | state_dim,
189 | act_dim,
190 | model,
191 | max_ep_len=max_ep_len,
192 | target_return=target_rew/scale,
193 | mode=mode,
194 | state_mean=state_mean,
195 | state_std=state_std,
196 | device=device,
197 | )
198 | returns.append(ret)
199 | lengths.append(length)
200 | return {
201 | f'target_{target_rew}_return_mean': np.mean(returns),
202 | f'target_{target_rew}_return_std': np.std(returns),
203 | f'target_{target_rew}_length_mean': np.mean(lengths),
204 | f'target_{target_rew}_length_std': np.std(lengths),
205 | }
206 | return fn
207 |
208 | if model_type == 'dt':
209 | model = DecisionTransformer(
210 | state_dim=state_dim,
211 | act_dim=act_dim,
212 | max_length=K,
213 | max_ep_len=max_ep_len,
214 | hidden_size=variant['embed_dim'],
215 | n_layer=variant['n_layer'],
216 | n_head=variant['n_head'],
217 | n_inner=4*variant['embed_dim'],
218 | activation_function=variant['activation_function'],
219 | n_positions=1024,
220 | resid_pdrop=variant['dropout'],
221 | attn_pdrop=variant['dropout'],
222 | )
223 | elif model_type == 'bc':
224 | model = MLPBCModel(
225 | state_dim=state_dim,
226 | act_dim=act_dim,
227 | max_length=K,
228 | hidden_size=variant['embed_dim'],
229 | n_layer=variant['n_layer'],
230 | )
231 | else:
232 | raise NotImplementedError
233 |
234 | model = model.to(device=device)
235 |
236 | warmup_steps = variant['warmup_steps']
237 | optimizer = torch.optim.AdamW(
238 | model.parameters(),
239 | lr=variant['learning_rate'],
240 | weight_decay=variant['weight_decay'],
241 | )
242 | scheduler = torch.optim.lr_scheduler.LambdaLR(
243 | optimizer,
244 | lambda steps: min((steps+1)/warmup_steps, 1)
245 | )
246 |
247 | if model_type == 'dt':
248 | trainer = SequenceTrainer(
249 | model=model,
250 | optimizer=optimizer,
251 | batch_size=batch_size,
252 | get_batch=get_batch,
253 | scheduler=scheduler,
254 | loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a)**2),
255 | eval_fns=[eval_episodes(tar) for tar in env_targets],
256 | )
257 | elif model_type == 'bc':
258 | trainer = ActTrainer(
259 | model=model,
260 | optimizer=optimizer,
261 | batch_size=batch_size,
262 | get_batch=get_batch,
263 | scheduler=scheduler,
264 | loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a)**2),
265 | eval_fns=[eval_episodes(tar) for tar in env_targets],
266 | )
267 |
268 | if log_to_wandb:
269 | wandb.init(
270 | name=exp_prefix,
271 | group=group_name,
272 | project='decision-transformer',
273 | config=variant
274 | )
275 | # wandb.watch(model) # wandb has some bug
276 |
277 | for iter in range(variant['max_iters']):
278 | outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=iter+1, print_logs=True)
279 | if log_to_wandb:
280 | wandb.log(outputs)
281 |
282 |
283 | if __name__ == '__main__':
284 | parser = argparse.ArgumentParser()
285 | parser.add_argument('--env', type=str, default='hopper')
286 | parser.add_argument('--dataset', type=str, default='medium') # medium, medium-replay, medium-expert, expert
287 | parser.add_argument('--mode', type=str, default='normal') # normal for standard setting, delayed for sparse
288 | parser.add_argument('--K', type=int, default=20)
289 | parser.add_argument('--pct_traj', type=float, default=1.)
290 | parser.add_argument('--batch_size', type=int, default=64)
291 | parser.add_argument('--model_type', type=str, default='dt') # dt for decision transformer, bc for behavior cloning
292 | parser.add_argument('--embed_dim', type=int, default=128)
293 | parser.add_argument('--n_layer', type=int, default=3)
294 | parser.add_argument('--n_head', type=int, default=1)
295 | parser.add_argument('--activation_function', type=str, default='relu')
296 | parser.add_argument('--dropout', type=float, default=0.1)
297 | parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
298 | parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
299 | parser.add_argument('--warmup_steps', type=int, default=10000)
300 | parser.add_argument('--num_eval_episodes', type=int, default=100)
301 | parser.add_argument('--max_iters', type=int, default=10)
302 | parser.add_argument('--num_steps_per_iter', type=int, default=10000)
303 | parser.add_argument('--device', type=str, default='cuda')
304 | parser.add_argument('--log_to_wandb', '-w', type=bool, default=False)  # note: with type=bool, any non-empty value (even 'False') parses as True
305 |
306 | args = parser.parse_args()
307 |
308 | experiment('gym-experiment', variant=vars(args))
309 |
--------------------------------------------------------------------------------
/gym/readme-gym.md:
--------------------------------------------------------------------------------
1 |
2 | # OpenAI Gym
3 |
4 | ## Installation
5 |
6 | Experiments require MuJoCo.
7 | Follow the instructions in the [mujoco-py repo](https://github.com/openai/mujoco-py) to install it.
8 | Then, dependencies can be installed with the following command:
9 |
10 | ```
11 | conda env create -f conda_env.yml
12 | ```
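
The created environment then needs to be activated before running any scripts. The environment name is whatever `conda_env.yml` specifies; the placeholder below stands in for it:

```
conda activate <env-name-from-conda_env.yml>
```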
13 |
14 | ## Downloading datasets
15 |
16 | Datasets are stored in the `data` directory.
17 | Install the [D4RL repo](https://github.com/rail-berkeley/d4rl), following the instructions there.
18 | Then, run the following script to download the datasets and save them in our format:
19 |
20 | ```
21 | python download_d4rl_datasets.py
22 | ```
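
`experiment.py` later loads these files as pickled lists of trajectory dictionaries, so a quick check along the following lines (a sketch; the path assumes the hopper `medium` dataset was downloaded) confirms the conversion worked:

```
import pickle

# load one converted dataset and inspect its structure
with open('data/hopper-medium-v2.pkl', 'rb') as f:
    trajectories = pickle.load(f)

print(f'{len(trajectories)} trajectories loaded')
print(trajectories[0].keys())  # expect observations, actions, rewards, and terminals (or dones)
```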
23 |
24 | ## Example usage
25 |
26 | Experiments can be reproduced with the following:
27 |
28 | ```
29 | python experiment.py --env hopper --dataset medium --model_type dt
30 | ```
31 |
32 | Adding `-w True` will log results to Weights & Biases.
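
The remaining flags parsed in `experiment.py` can be combined in the same way. For example (illustrative commands, not from the original README), the first line below runs behavior cloning on the top 10% of trajectories (%BC) and the second runs Decision Transformer in the delayed (sparse) reward setting:

```
python experiment.py --env hopper --dataset medium --model_type bc --pct_traj 0.1
python experiment.py --env hopper --dataset medium --model_type dt --mode delayed
```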
33 |
--------------------------------------------------------------------------------