├── .gitignore ├── README.md ├── rl2 ├── __init__.py ├── agents │ ├── __init__.py │ ├── architectures │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ └── normalization.py │ │ ├── gru.py │ │ ├── lstm.py │ │ ├── snail.py │ │ └── transformer.py │ ├── heads │ │ ├── __init__.py │ │ ├── policy_heads.py │ │ └── value_heads.py │ ├── integration │ │ ├── __init__.py │ │ ├── policy_net.py │ │ └── value_net.py │ └── preprocessing │ │ ├── __init__.py │ │ ├── common.py │ │ ├── tabular.py │ │ └── vision.py ├── algos │ ├── __init__.py │ ├── common.py │ └── ppo.py ├── envs │ ├── __init__.py │ ├── abstract.py │ ├── bandit_env.py │ └── mdp_env.py └── utils │ ├── __init__.py │ ├── checkpoint_util.py │ ├── comm_util.py │ ├── constants.py │ ├── optim_util.py │ └── stat_util.py ├── setup.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Misc.: manually added 2 | *~ 3 | .idea/ 4 | checkpoints/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RL^2: Fast Reinforcement Learning via Slow Reinforcement Learning 2 | 3 | This repo contains implementations of the algorithms, architectures, and environments from Duan et al., 2016 - ['RL^2: Fast Reinforcement Learning via Slow Reinforcement Learning'](https://arxiv.org/pdf/1611.02779.pdf), and Mishra et al., 2017 - ['A Simple Neural Attentive Meta-Learner'](https://arxiv.org/pdf/1707.03141.pdf). 4 | 5 | It has also recently been redesigned to facilitate rapid prototyping of new stateful architectures for memory-based meta-reinforcement learning agents. 6 | 7 | ## Background 8 | 9 | The main idea of RL^2 is that a reinforcement learning agent with memory can be trained on a distribution of environments, 10 | and can thereby learn an algorithm to effectively transition from exploring these environments to exploiting them. 11 | 12 | In fact, the RL^2 training curriculum effectively trains an agent to behave as if it possesses a probabilistic model of the possible environments it is acting in. 13 | 14 | The theoretical background of RL^2 is discussed by Ortega et al., 2019, and a concise treatment can be found in [my blog post](https://lucaslingle.wordpress.com/2021/10/07/on-memory-based-meta-reinforcement-learning/). 15 | 16 | ## Getting Started 17 | 18 | Install the following system dependencies: 19 | #### Ubuntu 20 | ```bash 21 | sudo apt-get update 22 | sudo apt-get install -y cmake openmpi-bin openmpi-doc libopenmpi-dev 23 | ``` 24 | 25 | #### Mac OS X 26 | Installation of the system packages on Mac requires [Homebrew](https://brew.sh). With Homebrew installed, run the following: 27 | ```bash 28 | brew install cmake openmpi 29 | ``` 30 | 31 | #### Everyone 32 | Once the system dependencies have been installed, it's time to install the Python dependencies. 33 | Install the conda package manager from https://docs.conda.io/en/latest/miniconda.html 34 | 35 | Then run 36 | ```bash 37 | conda create --name pytorch_rl2 python=3.8.1 38 | conda activate pytorch_rl2 39 | git clone https://github.com/lucaslingle/pytorch_rl2 40 | cd pytorch_rl2 41 | pip install -e . 42 | ``` 43 | 44 | ## Usage 45 | 46 | ### Training 47 | To train with the default settings, simply run: 48 | ```bash 49 | mpirun -np 8 python -m train 50 | ``` 51 | 52 | This will launch 8 parallel processes, each running the ```train.py``` script. These processes each generate meta-episodes separately and then synchronously train on the collected experience in a data-parallel manner, with gradient information and model parameters synchronized across processes using mpi4py. 53 | 54 | To see additional configuration options, run ```python train.py --help```. 
Among other options, we support various architectures including GRU, LSTM, SNAIL, and Transformer models. 55 | 56 | ### Checkpoints 57 | By default, checkpoints are saved to ```./checkpoints/defaults```. To pick a different checkpoint directory during training, 58 | you can set the ```--checkpoint_dir``` flag, and to pick a different checkpoint name, you can set the 59 | ```--model_name``` flag. 60 | 61 | ## Reproducing the Papers 62 | 63 | Our implementations closely matched or slightly exceeded the published performance of RL^2 GRU (Duan et al., 2016) and RL^2 SNAIL (Mishra et al., 2017) in every setting we tested. 64 | 65 | In the tables below, ```n``` is the number of episodes per meta-episode, and ```k``` is the number of actions. 66 | Following Duan et al., 2016 and Mishra et al., 2017, in our tabular MDP experiments, all MDPs have 10 states and 5 actions, and the episode length is 10. 67 | 68 | ### Bandit case: 69 | 70 | | Setup | Random | Gittins | TS | OTS | UCB1 | eps-Greedy | Greedy | RL^2 GRU (paper) | RL^2 GRU (ours) | RL^2 SNAIL (paper) | RL^2 SNAIL (ours) | 71 | | ---------- | ------ | ------- | ----- | ----- | ----- | ---------- | ------ | ---------------- | --------------- | ------------------ | ------------------ | 72 | | n=10,k=5 | 5.0 | 6.6 | 5.7 | 6.5 | 6.7 | 6.6 | 6.6 | 6.7 | 6.7 | 6.6 | 6.8 | 73 | | n=10,k=10 | 5.0 | 6.6 | 5.5 | 6.2 | 6.7 | 6.6 | 6.6 | 6.7 | | 6.7 | | 74 | | n=10,k=50 | 5.1 | 6.5 | 5.2 | 5.5 | 6.6 | 6.5 | 6.5 | 6.8 | | 6.7 | | 75 | | n=100,k=5 | 49.9 | 78.3 | 74.7 | 77.9 | 78.0 | 75.4 | 74.8 | 78.7 | 78.7 | 79.1 | 78.5 | 76 | | n=100,k=10 | 49.9 | 82.8 | 76.7 | 81.4 | 82.4 | 77.4 | 77.1 | 83.5 | | 83.5 | | 77 | | n=100,k=50 | 49.8 | 85.2 | 64.5 | 67.7 | 84.3 | 78.3 | 78.0 | 84.9 | | 85.1 | | 78 | | n=500,k=5 | 249.8 | 405.8 | 402.0 | 406.7 | 405.8 | 388.2 | 380.6 | 401.6 | | 408.1 | | 79 | | n=500,k=10 | 249.0 | 437.8 | 429.5 | 438.9 | 437.1 | 408.0 | 395.0 | 432.5 | | 432.4 | | 80 | | n=500,k=50 | 249.6 | 463.7 | 427.2 | 437.6 | 457.6 | 413.6 | 402.8 | 438.9 | | 442.6 | | 81 | 82 | ### MDP case: 83 | 84 | | Setup | Random | PSRL | OPSRL | UCRL2 | BEB | eps-Greedy | Greedy | RL^2 GRU (paper) | RL^2 GRU (ours) | RL^2 SNAIL (paper) | RL^2 SNAIL (ours) | 85 | | ---------- | ------ | ------ | ------ | ------ | ------ | ---------- | ------ | ---------------- | --------------- | ------------------ | ------------------ | 86 | | n=10 | 100.1 | 138.1 | 144.1 | 146.6 | 150.2 | 132.8 | 134.8 | 156.2 | 157.3 | 159.1 | 160.1 | 87 | | n=25 | 250.2 | 408.8 | 425.2 | 424.1 | 427.8 | 377.3 | 368.8 | 445.7 | | 447.2 | | 88 | | n=50 | 499.7 | 904.4 | 930.7 | 918.9 | 917.8 | 823.3 | 769.3 | 936.1 | | 942.3 | | 89 | | n=75 | 749.9 | 1417.1 | 1449.2 | 1427.6 | 1422.6 | 1293.9 | 1172.9 | 1428.8 | | 1447.5 | | 90 | | n=100 | 999.4 | 1939.5 | 1973.9 | 1942.1 | 1935.1 | 1778.2 | 1578.5 | 1913.7 | | 1953.1 | | 91 | 92 | To perform policy optimization, we used PPO. We used layer norm instead of weight norm, and we report peak performance over training. Our performance statistics are averaged over 1000 meta-episodes. 93 | 94 | In all cases, for training we used a configuration where the total number of observations per policy improvement phase was equal to 240,000. This is comparable to the 250,000 used in prior works. 95 | The per-process batch size was 60 trajectories. There were 8 processes. There were 8 PPO optimization epochs per policy improvement phase. 
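(As a rough sanity check on these figures: assuming meta-episodes of 500 timesteps, e.g. the n=500 bandit setting or the n=50 MDP setting with episodes of length 10, one policy improvement phase consumes 8 processes × 60 trajectories × 500 timesteps = 240,000 observations.)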
96 | 97 | All other hyperparameters were set to their default values in the ```train.py``` script, except for the SNAIL experiments, where we used ```--num_features=32``` due to the skip-connections. 98 | -------------------------------------------------------------------------------- /rl2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/__init__.py -------------------------------------------------------------------------------- /rl2/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/agents/__init__.py -------------------------------------------------------------------------------- /rl2/agents/architectures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/agents/architectures/__init__.py -------------------------------------------------------------------------------- /rl2/agents/architectures/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/agents/architectures/common/__init__.py -------------------------------------------------------------------------------- /rl2/agents/architectures/common/attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements attention operations for RL^2 agents 3 | """ 4 | 5 | import functools 6 | 7 | import torch as tc 8 | 9 | 10 | @functools.lru_cache 11 | def sinusoidal_embeddings(src_len, d_model, reverse=False): 12 | pos_seq = tc.arange(src_len) 13 | inv_freq = 1 / (10000 ** (tc.arange(0, d_model, 2) / d_model)) 14 | sinusoid_input = pos_seq.view(-1, 1) * inv_freq.view(1, -1) 15 | pos_emb = tc.cat((tc.sin(sinusoid_input), tc.cos(sinusoid_input)), dim=-1) 16 | if reverse: 17 | pos_emb = tc.flip(pos_emb, dims=(0,)) 18 | return pos_emb 19 | 20 | 21 | @functools.lru_cache 22 | def get_mask(dest_len, src_len): 23 | i = tc.arange(dest_len).view(dest_len, 1) 24 | j = tc.arange(src_len).view(1, src_len) 25 | m = i >= j - (src_len - dest_len) 26 | return m.int() 27 | 28 | 29 | def masked_self_attention(qs, ks, vs, use_mask=True): 30 | scores = tc.bmm(qs, ks.permute(0, 2, 1)) 31 | scores /= qs.shape[-1] ** 0.5 32 | 33 | if use_mask: 34 | mask = get_mask(dest_len=qs.shape[1], src_len=ks.shape[1]) 35 | mask = mask.view(1, *mask.shape) 36 | scores = scores * mask - 1e10 * (1 - mask) 37 | 38 | ws = tc.nn.Softmax(dim=-1)(scores) 39 | output = tc.bmm(ws, vs) 40 | return output 41 | 42 | 43 | def rel_shift(inputs): 44 | # inputs should be a 3d tensor with shape [B, T2, T1+T2] 45 | # this function implements the part of the shift from Dai et al., Appdx B 46 | input_shape = inputs.shape 47 | zp = tc.zeros(size=(input_shape[0], input_shape[1], 1), dtype=tc.float32) 48 | inputs = tc.cat((zp, inputs), dim=2) 49 | inputs = tc.reshape( 50 | inputs, [input_shape[0], input_shape[2]+1, input_shape[1]]) 51 | inputs = inputs[:, 1:, :] 52 | inputs = tc.reshape(inputs, input_shape) 53 | return inputs 54 | 55 | 56 | def relative_masked_self_attention(qs, ks, vs, rs, u_, v_, use_mask=True): 57 | ac_qs = qs + u_.unsqueeze(1) 58 | bd_qs = qs + 
v_.unsqueeze(1) 59 | ac = tc.bmm(ac_qs, ks.permute(0, 2, 1)) 60 | bd = tc.bmm(bd_qs, rs.permute(0, 2, 1)) 61 | bd = rel_shift(bd) 62 | 63 | bd = bd[:, :, 0:ks.shape[1]] # this is a no-op unless prev row attn is used 64 | scores = ac + bd 65 | scores /= qs.shape[-1] ** 0.5 66 | 67 | if use_mask: 68 | mask = get_mask(dest_len=qs.shape[1], src_len=ks.shape[1]) 69 | mask = mask.view(1, *mask.shape) 70 | scores = scores * mask - 1e10 * (1 - mask) 71 | 72 | ws = tc.nn.Softmax(dim=-1)(scores) 73 | output = tc.bmm(ws, vs) 74 | return output 75 | 76 | 77 | class MultiheadSelfAttention(tc.nn.Module): 78 | def __init__( 79 | self, 80 | input_dim, 81 | num_heads, 82 | num_head_features, 83 | position_encoding_style, 84 | attention_style, 85 | row_len=None 86 | ): 87 | assert position_encoding_style in ['abs', 'rel'] 88 | assert attention_style in ['full', 'row', 'previous_row', 'column'] 89 | assert attention_style == 'full' or row_len is not None 90 | 91 | super().__init__() 92 | self._input_dim = input_dim 93 | self._num_heads = num_heads 94 | self._num_head_features = num_head_features 95 | self._position_encoding_style = position_encoding_style 96 | self._attention_style = attention_style 97 | self._row_len = row_len 98 | 99 | self._qkv_linear = tc.nn.Linear( 100 | in_features=self._input_dim, 101 | out_features=(self._num_heads * self._num_head_features * 3), 102 | bias=False) 103 | tc.nn.init.xavier_normal_(self._qkv_linear.weight) 104 | 105 | if self._position_encoding_style == 'rel': 106 | self._r_linear = tc.nn.Linear( 107 | in_features=self._input_dim, 108 | out_features=(self._num_heads * self._num_head_features), 109 | bias=False) 110 | tc.nn.init.xavier_normal_(self._r_linear.weight) 111 | self._u = tc.nn.Parameter( 112 | tc.zeros(size=(self._num_heads * self._num_head_features,), 113 | dtype=tc.float32)) 114 | self._v = tc.nn.Parameter( 115 | tc.zeros(size=(self._num_heads * self._num_head_features,), 116 | dtype=tc.float32)) 117 | 118 | def attn_preop(self, qs, ks, vs, sampling): 119 | assert type(qs) == type(ks) == type(vs) 120 | assert (sampling and type(qs) == list) or \ 121 | (not sampling and type(qs) == tc.Tensor) 122 | 123 | if self._attention_style == 'full': 124 | if sampling: 125 | qs = tc.stack(qs, dim=1) 126 | ks = tc.stack(ks, dim=1) 127 | vs = tc.stack(vs, dim=1) 128 | return qs, ks, vs, qs.shape[0] 129 | else: 130 | return qs, ks, vs, qs.shape[0] 131 | 132 | if self._attention_style == 'row': 133 | if sampling: 134 | assert len(qs) == 1 135 | row_idx = (len(ks)-1) // self._row_len 136 | row_flat_idx = row_idx * self._row_len 137 | ks = ks[row_flat_idx:] # get relevant row 138 | vs = vs[row_flat_idx:] 139 | qs = tc.stack(qs, dim=1) 140 | ks = tc.stack(ks, dim=1) 141 | vs = tc.stack(vs, dim=1) 142 | return qs, ks, vs, qs.shape[0] 143 | else: 144 | assert qs.shape[1] == ks.shape[1] == vs.shape[1] 145 | assert qs.shape[1] % self._row_len == 0 146 | qs = tc.reshape(qs, [-1, self._row_len, qs.shape[-1]]) 147 | ks = tc.reshape(ks, [-1, self._row_len, ks.shape[-1]]) 148 | vs = tc.reshape(vs, [-1, self._row_len, vs.shape[-1]]) 149 | return qs, ks, vs, qs.shape[0] 150 | 151 | if self._attention_style == 'previous_row': 152 | if sampling: 153 | assert len(qs) == 1 154 | row_idx = (len(ks)-1) // self._row_len 155 | if row_idx > 0: 156 | prev_row_flat_idx = (row_idx - 1) * self._row_len 157 | ks = ks[prev_row_flat_idx:prev_row_flat_idx+self._row_len] 158 | vs = vs[prev_row_flat_idx:prev_row_flat_idx+self._row_len] 159 | qs = tc.stack(qs, dim=1) 160 | ks = tc.stack(ks, dim=1) 161 | vs 
= tc.stack(vs, dim=1) 162 | return qs, ks, vs, qs.shape[0] 163 | else: 164 | qs = tc.stack(qs, dim=1) 165 | prev_row_shape = [qs.shape[0], self._row_len, qs.shape[2]] 166 | ks = tc.zeros(size=prev_row_shape, dtype=tc.float32) 167 | vs = tc.zeros(size=prev_row_shape, dtype=tc.float32) 168 | return qs, ks, vs, qs.shape[0] 169 | else: 170 | assert qs.shape[1] == ks.shape[1] == vs.shape[1] 171 | assert qs.shape[1] % self._row_len == 0 172 | n_rows = qs.shape[1] // self._row_len 173 | qs = tc.reshape(qs, [-1, n_rows, self._row_len, qs.shape[-1]]) 174 | ks = tc.reshape(ks, [-1, n_rows, self._row_len, ks.shape[-1]]) 175 | vs = tc.reshape(vs, [-1, n_rows, self._row_len, vs.shape[-1]]) 176 | ks = tc.nn.functional.pad(ks[:,:-1,:,:], (0,0,0,0,1,0)) 177 | vs = tc.nn.functional.pad(vs[:,:-1,:,:], (0,0,0,0,1,0)) 178 | qs = tc.reshape(qs, [-1, self._row_len, qs.shape[-1]]) 179 | ks = tc.reshape(ks, [-1, self._row_len, ks.shape[-1]]) 180 | vs = tc.reshape(vs, [-1, self._row_len, vs.shape[-1]]) 181 | return qs, ks, vs, qs.shape[0] 182 | 183 | if self._attention_style == 'column': 184 | if sampling: 185 | assert len(qs) == 1 186 | column_flat_idx = (len(ks)-1) % self._row_len 187 | ks = ks[column_flat_idx::self._row_len] # get relevant column 188 | vs = vs[column_flat_idx::self._row_len] 189 | qs = tc.stack(qs, dim=1) 190 | ks = tc.stack(ks, dim=1) 191 | vs = tc.stack(vs, dim=1) 192 | return qs, ks, vs, qs.shape[0] 193 | else: 194 | assert qs.shape[1] == ks.shape[1] == vs.shape[1] 195 | assert qs.shape[1] % self._row_len == 0 196 | n_rows = qs.shape[1] // self._row_len 197 | qs = tc.reshape(qs, [-1, n_rows, self._row_len, qs.shape[-1]]) 198 | ks = tc.reshape(ks, [-1, n_rows, self._row_len, ks.shape[-1]]) 199 | vs = tc.reshape(vs, [-1, n_rows, self._row_len, vs.shape[-1]]) 200 | qs = qs.permute(0, 2, 1, 3) 201 | ks = ks.permute(0, 2, 1, 3) 202 | vs = vs.permute(0, 2, 1, 3) 203 | qs = tc.reshape(qs, [-1, n_rows, qs.shape[-1]]) 204 | ks = tc.reshape(ks, [-1, n_rows, ks.shape[-1]]) 205 | vs = tc.reshape(vs, [-1, n_rows, vs.shape[-1]]) 206 | return qs, ks, vs, qs.shape[0] 207 | 208 | raise NotImplementedError 209 | 210 | def attn_postop(self, attn_out, input_len, sampling): 211 | if self._attention_style == 'full': 212 | return attn_out 213 | 214 | assert input_len % self._row_len == 0 or sampling 215 | 216 | if self._attention_style == 'row': 217 | if sampling: 218 | return attn_out 219 | else: 220 | attn_out = tc.reshape(attn_out, [-1, input_len, attn_out.shape[-1]]) 221 | return attn_out 222 | 223 | if self._attention_style == 'previous_row': 224 | if sampling: 225 | return attn_out 226 | else: 227 | attn_out = tc.reshape(attn_out, [-1, input_len, attn_out.shape[-1]]) 228 | return attn_out 229 | 230 | if self._attention_style == 'column': 231 | if sampling: 232 | return attn_out 233 | else: 234 | n_rows = input_len // self._row_len 235 | transposed_block_shape = [-1, self._row_len, n_rows, attn_out.shape[-1]] 236 | attn_out = tc.reshape(attn_out, transposed_block_shape) 237 | attn_out = attn_out.permute(0, 2, 1, 3) 238 | attn_out = tc.reshape(attn_out, [-1, input_len, attn_out.shape[-1]]) 239 | return attn_out 240 | 241 | raise NotImplementedError 242 | 243 | def split_heads(self, inputs): 244 | return tc.cat(tc.chunk(inputs, self._num_heads, dim=-1), dim=0) 245 | 246 | def merge_heads(self, inputs): 247 | return tc.cat(tc.chunk(inputs, self._num_heads, dim=0), dim=-1) 248 | 249 | def forward(self, inputs, past_kvs=None): 250 | """ 251 | Args: 252 | inputs: present input tensor with shape [B, T2, I] 253 | 
past_kvs: optional past kvs 254 | 255 | Returns: 256 | output tensor and new kvs 257 | """ 258 | assert inputs.shape[-1] == self._input_dim 259 | sampling = (inputs.shape[1] == 1) 260 | use_mask = (self._attention_style != 'previous_row') 261 | 262 | qkv = self._qkv_linear(inputs) 263 | qs, ks, vs = tc.chunk(qkv, 3, dim=-1) 264 | 265 | if sampling: 266 | # unbind for memory-efficient append op 267 | qs, ks, vs = map(lambda x: [x.squeeze(1)], [qs, ks, vs]) 268 | if past_kvs is not None: 269 | past_ks, past_vs = past_kvs 270 | past_ks.extend(ks) 271 | past_vs.extend(vs) 272 | ks = past_ks 273 | vs = past_vs 274 | new_kvs = (ks, vs) 275 | else: 276 | if past_kvs is not None: 277 | past_ks, past_vs = past_kvs 278 | ks = tc.cat((past_ks, ks), dim=1) 279 | vs = tc.cat((past_vs, vs), dim=1) 280 | new_kvs = (ks, vs) 281 | 282 | qs, ks, vs, bsp = self.attn_preop(qs, ks, vs, sampling) # [B', ..., H*F] 283 | qs, ks, vs = map(self.split_heads, [qs, ks, vs]) # [B'*H, ..., F] 284 | 285 | if self._position_encoding_style == 'rel': 286 | batch_size, src_len, d_model = bsp, ks.shape[1], inputs.shape[-1] 287 | max_len = src_len 288 | if self._attention_style == 'previous_row': 289 | max_len += qs.shape[1] 290 | r_mat = sinusoidal_embeddings(max_len, d_model, reverse=True) # [M, I] 291 | rs = self._r_linear(r_mat) # [M, H*F] 292 | 293 | rs = tc.tile(rs.unsqueeze(0), [batch_size, 1, 1]) # [B', M, H*F] 294 | u_ = tc.tile(self._u.unsqueeze(0), [batch_size, 1]) # [B', H*F] 295 | v_ = tc.tile(self._v.unsqueeze(0), [batch_size, 1]) # [B', H*F] 296 | rs, u_, v_ = map(self.split_heads, [rs, u_, v_]) # [B'*H, ..., F] 297 | 298 | attn_output = relative_masked_self_attention( 299 | qs, ks, vs, rs, u_, v_, use_mask=use_mask) # [B'*H, T2', F] 300 | else: 301 | attn_output = masked_self_attention( 302 | qs, ks, vs, use_mask=use_mask) # [B'*H, T2', F] 303 | 304 | attn_output = self.merge_heads(attn_output) # [B', T2', H*F] 305 | attn_output = self.attn_postop( 306 | attn_output, 307 | input_len=inputs.shape[1], 308 | sampling=sampling) # [B, T2, H*F] 309 | 310 | return attn_output, new_kvs 311 | -------------------------------------------------------------------------------- /rl2/agents/architectures/common/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements normalization layers for RL^2 agents. 3 | """ 4 | 5 | import torch as tc 6 | 7 | 8 | class LayerNorm(tc.nn.Module): 9 | """ 10 | Layer Normalization. 11 | """ 12 | def __init__(self, units): 13 | super().__init__() 14 | self._units = units 15 | self._g = tc.nn.Parameter(tc.ones(units, device='cpu')) 16 | self._b = tc.nn.Parameter(tc.zeros(units, device='cpu')) 17 | 18 | def forward(self, inputs, eps=1e-8): 19 | mu = tc.mean(inputs, dim=-1, keepdim=True) 20 | centered = inputs - mu 21 | sigma2 = tc.mean(tc.square(centered), dim=-1, keepdim=True) 22 | sigma = tc.sqrt(sigma2 + eps) 23 | standardized = centered / sigma 24 | 25 | g, b = self._g, self._b 26 | while len(list(g.shape)) < len(list(inputs.shape)): 27 | g = g.unsqueeze(0) 28 | b = b.unsqueeze(0) 29 | 30 | scaled = g * standardized + b 31 | return scaled 32 | -------------------------------------------------------------------------------- /rl2/agents/architectures/gru.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements GRU for RL^2. 
3 | """ 4 | 5 | from typing import Tuple 6 | 7 | import torch as tc 8 | 9 | from rl2.agents.architectures.common.normalization import LayerNorm 10 | 11 | 12 | class GRU(tc.nn.Module): 13 | def __init__( 14 | self, 15 | input_dim, 16 | hidden_dim, 17 | forget_bias=1.0, 18 | use_ln=True, 19 | reset_after=True 20 | ): 21 | super().__init__() 22 | self._input_dim = input_dim 23 | self._hidden_dim = hidden_dim 24 | self._forget_bias = forget_bias 25 | self._use_ln = use_ln 26 | self._reset_after = reset_after 27 | 28 | self._x2zr = tc.nn.Linear( 29 | in_features=self._input_dim, 30 | out_features=(2 * self._hidden_dim), 31 | bias=(not self._use_ln)) 32 | tc.nn.init.xavier_normal_(self._x2zr.weight) 33 | if not self._use_ln: 34 | tc.nn.init.zeros_(self._x2zr.bias) 35 | 36 | self._h2zr = tc.nn.Linear( 37 | in_features=self._hidden_dim, 38 | out_features=(2 * self._hidden_dim), 39 | bias=False) 40 | tc.nn.init.xavier_normal_(self._h2zr.weight) 41 | 42 | self._x2hhat = tc.nn.Linear( 43 | in_features=self._input_dim, 44 | out_features=self._hidden_dim, 45 | bias=(not self._use_ln)) 46 | tc.nn.init.xavier_normal_(self._x2hhat.weight) 47 | if not self._use_ln: 48 | tc.nn.init.zeros_(self._x2hhat.bias) 49 | 50 | self._h2hhat = tc.nn.Linear( 51 | in_features=self._hidden_dim, 52 | out_features=self._hidden_dim, 53 | bias=False) 54 | tc.nn.init.orthogonal_(self._h2hhat.weight) 55 | 56 | if self._use_ln: 57 | self._x2zr_ln = LayerNorm(units=(2 * self._hidden_dim)) 58 | self._h2zr_ln = LayerNorm(units=(2 * self._hidden_dim)) 59 | self._x2hhat_ln = LayerNorm(units=self._hidden_dim) 60 | self._h2hhat_ln = LayerNorm(units=self._hidden_dim) 61 | 62 | self._initial_state = tc.zeros(self._hidden_dim) 63 | 64 | def initial_state(self, batch_size: int) -> tc.FloatTensor: 65 | """ 66 | Return initial state of zeros. 67 | Args: 68 | batch_size: batch size to tile the initial state by. 69 | Returns: 70 | initial_state FloatTensor. 71 | """ 72 | return self._initial_state.unsqueeze(0).repeat(batch_size, 1) 73 | 74 | @property 75 | def output_dim(self): 76 | return self._hidden_dim 77 | 78 | def forward( 79 | self, 80 | inputs: tc.FloatTensor, 81 | prev_state: tc.FloatTensor 82 | ) -> Tuple[tc.FloatTensor, tc.FloatTensor]: 83 | """ 84 | Run recurrent state update, compute features. 85 | Args: 86 | inputs: input vec tensor with shape [B, ..., ?] 87 | prev_state: prev hidden state w/ shape [B, H]. 88 | Notes: 89 | '...' must be either one dimensional or must not exist. 90 | Returns: 91 | features, new_state. 
92 | """ 93 | assert len(list(inputs.shape)) in [2, 3] 94 | if len(list(inputs.shape)) == 2: 95 | inputs = inputs.unsqueeze(1) 96 | 97 | T = inputs.shape[1] 98 | features_by_timestep = [] 99 | state = prev_state 100 | for t in range(0, T): # 0, ..., T-1 101 | h_prev = state 102 | zr_from_x = self._x2zr(inputs[:,t,:]) 103 | zr_from_h = self._h2zr(h_prev) 104 | if self._use_ln: 105 | zr_from_x = self._x2zr_ln(zr_from_x) 106 | zr_from_h = self._h2zr_ln(zr_from_h) 107 | zr = zr_from_x + zr_from_h 108 | z, r = tc.chunk(zr, 2, dim=-1) 109 | 110 | z = tc.nn.Sigmoid()(z-self._forget_bias) 111 | r = tc.nn.Sigmoid()(r) 112 | 113 | if self._reset_after: 114 | hhat_from_x = self._x2hhat(inputs[:,t,:]) 115 | hhat_from_h = self._h2hhat(h_prev) 116 | if self._use_ln: 117 | hhat_from_x = self._x2hhat_ln(hhat_from_x) 118 | hhat_from_h = self._h2hhat_ln(hhat_from_h) 119 | hhat = hhat_from_x + r * hhat_from_h 120 | else: 121 | hhat_from_x = self._x2hhat(inputs[:,t,:]) 122 | hhat_from_h = self._h2hhat(r * h_prev) 123 | if self._use_ln: 124 | hhat_from_x = self._x2hhat_ln(hhat_from_x) 125 | hhat_from_h = self._h2hhat_ln(hhat_from_h) 126 | hhat = hhat_from_x + hhat_from_h 127 | 128 | hhat = tc.nn.ReLU()(hhat) 129 | h_new = (1. - z) * h_prev + z * hhat 130 | 131 | features_by_timestep.append(h_new) 132 | state = h_new 133 | 134 | features = tc.stack(features_by_timestep, dim=1) 135 | if T == 1: 136 | features = features.squeeze(1) 137 | new_state = state 138 | 139 | return features, new_state 140 | -------------------------------------------------------------------------------- /rl2/agents/architectures/lstm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements LSTM for RL^2. 3 | """ 4 | 5 | from typing import Tuple 6 | 7 | import torch as tc 8 | 9 | from rl2.agents.architectures.common.normalization import LayerNorm 10 | 11 | 12 | class LSTM(tc.nn.Module): 13 | def __init__( 14 | self, 15 | input_dim, 16 | hidden_dim, 17 | forget_bias=1.0, 18 | use_ln=True 19 | ): 20 | super().__init__() 21 | self._input_dim = input_dim 22 | self._hidden_dim = hidden_dim 23 | self._use_ln = use_ln 24 | self._forget_bias = forget_bias 25 | 26 | self._x2fioj = tc.nn.Linear( 27 | in_features=self._input_dim, 28 | out_features=(4 * self._hidden_dim), 29 | bias=(not self._use_ln)) 30 | tc.nn.init.xavier_normal_(self._x2fioj.weight) 31 | if not self._use_ln: 32 | tc.nn.init.zeros_(self._x2fioj.bias) 33 | 34 | self._h2fioj = tc.nn.Linear( 35 | in_features=self._hidden_dim, 36 | out_features=(4 * self._hidden_dim), 37 | bias=False) 38 | tc.nn.init.xavier_normal_(self._h2fioj.weight) 39 | 40 | if self._use_ln: 41 | self._x2fioj_ln = LayerNorm(units=(4 * self._hidden_dim)) 42 | self._h2fioj_ln = LayerNorm(units=(4 * self._hidden_dim)) 43 | 44 | self._initial_state = tc.zeros(2 * self._hidden_dim) 45 | 46 | def initial_state(self, batch_size: int) -> tc.FloatTensor: 47 | """ 48 | Return initial state of zeros. 49 | Args: 50 | batch_size: batch size to tile the initial state by. 51 | Returns: 52 | initial_state FloatTensor. 53 | """ 54 | return self._initial_state.unsqueeze(0).repeat(batch_size, 1) 55 | 56 | @property 57 | def output_dim(self): 58 | return self._hidden_dim 59 | 60 | def forward( 61 | self, 62 | inputs: tc.FloatTensor, 63 | prev_state: tc.FloatTensor 64 | ) -> Tuple[tc.FloatTensor, tc.FloatTensor]: 65 | """ 66 | Run recurrent state update, compute features. 67 | Args: 68 | inputs: input vec tensor with shape [B, ..., ?] 69 | prev_state: prev lstm state w/ shape [B, 2*H]. 
70 | Notes: 71 | '...' must be either one dimensional or must not exist. 72 | Returns: 73 | features, new_state. 74 | """ 75 | assert len(list(inputs.shape)) in [2, 3] 76 | if len(list(inputs.shape)) == 2: 77 | inputs = inputs.unsqueeze(1) 78 | 79 | T = inputs.shape[1] 80 | features_by_timestep = [] 81 | state = prev_state 82 | for t in range(0, T): # 0, ..., T-1 83 | h_prev, c_prev = tc.chunk(state, 2, dim=-1) 84 | fioj_from_x = self._x2fioj(inputs[:,t,:]) 85 | fioj_from_h = self._h2fioj(h_prev) 86 | if self._use_ln: 87 | fioj_from_x = self._x2fioj_ln(fioj_from_x) 88 | fioj_from_h = self._h2fioj_ln(fioj_from_h) 89 | fioj = fioj_from_x + fioj_from_h 90 | f, i, o, j = tc.chunk(fioj, 4, dim=-1) 91 | f = tc.nn.Sigmoid()(f + self._forget_bias) 92 | i = tc.nn.Sigmoid()(i) 93 | o = tc.nn.Sigmoid()(o) 94 | j = tc.nn.ReLU()(j) 95 | c_new = f * c_prev + i * j 96 | h_new = o * c_new 97 | 98 | features_by_timestep.append(h_new) 99 | state = tc.cat((h_new, c_new), dim=-1) 100 | 101 | features = tc.stack(features_by_timestep, dim=1) 102 | if T == 1: 103 | features = features.squeeze(1) 104 | new_state = state 105 | 106 | return features, new_state 107 | -------------------------------------------------------------------------------- /rl2/agents/architectures/snail.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements SNAIL architecture (Mishra et al., 2017) for RL^2. 3 | """ 4 | 5 | from typing import Optional 6 | 7 | import torch as tc 8 | import numpy as np 9 | 10 | from rl2.agents.architectures.common.normalization import LayerNorm 11 | from rl2.agents.architectures.common.attention import MultiheadSelfAttention 12 | 13 | 14 | class CausalConv(tc.nn.Module): 15 | def __init__( 16 | self, 17 | input_dim, 18 | feature_dim, 19 | kernel_size, 20 | dilation_rate, 21 | use_bias=True 22 | ): 23 | super().__init__() 24 | self._input_dim = input_dim 25 | self._feature_dim = feature_dim 26 | self._kernel_size = kernel_size 27 | self._dilation_rate = dilation_rate 28 | self._use_bias = use_bias 29 | 30 | self._conv = tc.nn.Conv1d( 31 | in_channels=self._input_dim, 32 | out_channels=self._feature_dim, 33 | kernel_size=self._kernel_size, 34 | stride=(1,), 35 | padding=(0,), 36 | dilation=self._dilation_rate, 37 | bias=self._use_bias) 38 | 39 | @property 40 | def effective_kernel_size(self): 41 | k = self._kernel_size 42 | r = self._dilation_rate 43 | return k + (k - 1) * (r - 1) 44 | 45 | def forward( 46 | self, 47 | inputs: tc.FloatTensor, 48 | past_inputs: Optional[tc.FloatTensor] = None 49 | ) -> tc.FloatTensor: 50 | """ 51 | Args: 52 | inputs: present input tensor of shape [B, T2, I] 53 | past_inputs: optional past input tensor of shape [B, T1, I] 54 | 55 | Returns: 56 | Causal convolution of the (padded) present inputs. 57 | The present inputs are padded with past inputs (if any), 58 | and possibly zero padding. 
59 | """ 60 | batch_size = inputs.shape[0] 61 | effective_kernel_size = self.effective_kernel_size 62 | 63 | if past_inputs is not None: 64 | t1 = past_inputs.shape[1] 65 | if t1 < effective_kernel_size - 1: 66 | zpl = effective_kernel_size - 1 - t1 67 | zps = (batch_size, zpl, self._input_dim) 68 | zp = tc.zeros(size=zps, dtype=tc.float32) 69 | inputs = tc.cat((zp, past_inputs, inputs), dim=1) 70 | else: 71 | crop_len = effective_kernel_size - 1 72 | cropped_past_inputs = past_inputs[:, -crop_len:, :] 73 | inputs = tc.cat((cropped_past_inputs, inputs), dim=1) 74 | else: 75 | zpl = effective_kernel_size - 1 76 | zps = (batch_size, zpl, self._input_dim) 77 | zp = tc.zeros(size=zps, dtype=tc.float32) 78 | inputs = tc.cat((zp, inputs), dim=1) 79 | 80 | conv = self._conv(inputs.permute(0, 2, 1)).permute(0, 2, 1) 81 | return conv 82 | 83 | 84 | class DenseBlock(tc.nn.Module): 85 | def __init__( 86 | self, 87 | input_dim, 88 | feature_dim, 89 | kernel_size, 90 | dilation_rate, 91 | use_ln=True 92 | ): 93 | super().__init__() 94 | self._input_dim = input_dim 95 | self._feature_dim = feature_dim 96 | self._kernel_size = kernel_size 97 | self._dilation_rate = dilation_rate 98 | self._use_ln = use_ln 99 | 100 | self._conv = CausalConv( 101 | input_dim=self._input_dim, 102 | feature_dim=(2 * self._feature_dim), 103 | kernel_size=self._kernel_size, 104 | dilation_rate=self._dilation_rate, 105 | use_bias=(not self._use_ln)) 106 | 107 | if self._use_ln: 108 | self._conv_ln = LayerNorm(units=(2 * self._feature_dim)) 109 | 110 | def forward(self, inputs, past_inputs=None): 111 | conv = self._conv(inputs=inputs, past_inputs=past_inputs) 112 | 113 | if self._use_ln: 114 | conv = self._conv_ln(conv) 115 | 116 | xg, xf = tc.chunk(conv, 2, dim=-1) 117 | xg, xf = tc.nn.Sigmoid()(xg), tc.nn.Tanh()(xf) 118 | activations = xg * xf 119 | 120 | output = tc.cat((inputs, activations), dim=-1) 121 | return output 122 | 123 | 124 | class TCBlock(tc.nn.Module): 125 | def __init__( 126 | self, 127 | input_dim, 128 | feature_dim, 129 | context_size, 130 | use_ln=True 131 | ): 132 | super().__init__() 133 | self._input_dim = input_dim 134 | self._feature_dim = feature_dim 135 | self._context_size = context_size 136 | self._use_ln = use_ln 137 | 138 | self._dense_blocks = tc.nn.ModuleList([ 139 | DenseBlock( 140 | input_dim=(self._input_dim + l * self._feature_dim), 141 | feature_dim=self._feature_dim, 142 | kernel_size=2, 143 | dilation_rate=2 ** l, 144 | use_ln=self._use_ln) 145 | for l in range(0, self.num_layers) 146 | ]) 147 | 148 | @property 149 | def num_layers(self): 150 | log2_context_size = np.log(self._context_size) / np.log(2) 151 | return int(np.ceil(log2_context_size)) 152 | 153 | @property 154 | def output_dim(self): 155 | return self._input_dim + self.num_layers * self._feature_dim 156 | 157 | @property 158 | def state_dim(self): 159 | return self.output_dim - self._feature_dim 160 | 161 | def forward(self, inputs, past_inputs=None): 162 | """ 163 | Args: 164 | inputs: inputs tensor of shape [B, T2, I] 165 | past_inputs: optional past inputs tensor of shape [B, T1, I+(L-1)*F] 166 | 167 | Returns: 168 | tensor of shape [B, T2, I+L*F] 169 | """ 170 | for l in range(0, self.num_layers): # 0, ..., num_layers-1 171 | if past_inputs is None: 172 | past_inputs_for_layer = None 173 | else: 174 | end_idx = self._input_dim + l * self._feature_dim 175 | past_inputs_for_layer = past_inputs[:, :, 0:end_idx] 176 | 177 | inputs = self._dense_blocks[l]( 178 | inputs=inputs, past_inputs=past_inputs_for_layer) 179 | 180 | 
return inputs # [B, T2, I+L*F] 181 | 182 | 183 | class SNAIL(tc.nn.Module): 184 | def __init__(self, input_dim, feature_dim, context_size, use_ln=True): 185 | super().__init__() 186 | self._input_dim = input_dim 187 | self._feature_dim = feature_dim 188 | self._context_size = context_size 189 | self._use_ln = use_ln 190 | 191 | self._tc1 = TCBlock( 192 | input_dim=self._input_dim, 193 | feature_dim=self._feature_dim, 194 | context_size=self._context_size, 195 | use_ln=self._use_ln) 196 | 197 | self._tc2 = TCBlock( 198 | input_dim=self._tc1.output_dim, 199 | feature_dim=self._feature_dim, 200 | context_size=self._context_size, 201 | use_ln=self._use_ln) 202 | 203 | self._attn = MultiheadSelfAttention( 204 | input_dim=self._tc2.output_dim, 205 | num_heads=1, 206 | num_head_features=self._feature_dim, 207 | position_encoding_style='abs', 208 | attention_style='full') 209 | 210 | @property 211 | def output_dim(self): 212 | return self._tc2.output_dim + self._feature_dim 213 | 214 | def initial_state(self, batch_size: int) -> None: 215 | return None 216 | 217 | def forward(self, inputs, prev_state=None): 218 | """ 219 | Run state update, compute features. 220 | 221 | Args: 222 | inputs: input vec tensor with shape [B, ..., ?] 223 | prev_state: previous architecture state 224 | 225 | Notes: 226 | '...' must be either one dimensional or must not exist 227 | 228 | Returns: 229 | tuple containing features with shape [B, ..., F], and new_state. 230 | """ 231 | assert len(list(inputs.shape)) in [2, 3] 232 | if len(list(inputs.shape)) == 2: 233 | inputs = inputs.unsqueeze(1) 234 | 235 | if prev_state is None: 236 | tc1_out = self._tc1(inputs=inputs, past_inputs=None) 237 | tc2_out = self._tc2(inputs=tc1_out, past_inputs=None) 238 | attn_out, new_attn_kv = self._attn(inputs=tc2_out, past_kvs=None) 239 | features = tc.cat((tc2_out, attn_out), dim=2) 240 | new_state = (tc2_out, new_attn_kv) 241 | else: 242 | tc1_out = self._tc1( 243 | inputs=inputs, past_inputs=prev_state[0][:, :, 0:self._tc1.state_dim]) 244 | tc2_out = self._tc2( 245 | inputs=tc1_out, past_inputs=prev_state[0][:, :, 0:self._tc2.state_dim]) 246 | attn_out, new_attn_kv = self._attn( 247 | inputs=tc2_out, past_kvs=prev_state[1]) 248 | features = tc.cat((tc2_out, attn_out), dim=2) 249 | new_state = ( 250 | tc.cat((prev_state[0], tc2_out), dim=1), 251 | new_attn_kv 252 | ) 253 | 254 | if features.shape[1] == 1: 255 | features = features.squeeze(1) 256 | 257 | return features, new_state 258 | -------------------------------------------------------------------------------- /rl2/agents/architectures/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements Transformer architectures for RL^2. 
3 | """ 4 | 5 | import torch as tc 6 | 7 | from rl2.agents.architectures.common.normalization import LayerNorm 8 | from rl2.agents.architectures.common.attention import ( 9 | MultiheadSelfAttention, 10 | sinusoidal_embeddings 11 | ) 12 | 13 | 14 | class FF(tc.nn.Module): 15 | def __init__( 16 | self, 17 | input_dim, 18 | hidden_dim, 19 | output_dim, 20 | activation=tc.nn.ReLU 21 | ): 22 | super().__init__() 23 | self._input_dim = input_dim 24 | self._hidden_dim = hidden_dim 25 | self._output_dim = output_dim 26 | 27 | self._lin1 = tc.nn.Linear( 28 | in_features=self._input_dim, 29 | out_features=self._hidden_dim, 30 | bias=True) 31 | tc.nn.init.xavier_normal_(self._lin1.weight) 32 | tc.nn.init.zeros_(self._lin1.bias) 33 | 34 | self._act = activation() 35 | 36 | self._lin2 = tc.nn.Linear( 37 | in_features=self._hidden_dim, 38 | out_features=self._output_dim, 39 | bias=True) 40 | tc.nn.init.xavier_normal_(self._lin2.weight) 41 | tc.nn.init.zeros_(self._lin2.bias) 42 | 43 | def forward(self, inputs): 44 | """ 45 | Args: 46 | inputs: input vec tensor of shape [B, T2, I] 47 | 48 | Returns: 49 | output tensor of shape [B, T2, O] 50 | """ 51 | x = inputs 52 | x = self._lin1(x) 53 | x = self._act(x) 54 | x = self._lin2(x) 55 | return x 56 | 57 | 58 | class TransformerLayer(tc.nn.Module): 59 | def __init__( 60 | self, 61 | input_dim, 62 | feature_dim, 63 | num_heads, 64 | position_encoding_style, 65 | attention_style, 66 | connection_style, 67 | layer_ordering, 68 | row_len=None, 69 | activation=tc.nn.ReLU 70 | ): 71 | """ 72 | Args: 73 | input_dim: input dimensionality. 74 | feature_dim: feature dimensionality for each sublayer. 75 | num_heads: number of attention heads. 76 | position_encoding_style: one of 'abs', 'rel'. 77 | attention_style: one of 'full', 'row', 'column', or 'previous_row'. 78 | connection_style: one of 'plain', 'residual', 'dense'. 79 | layer_ordering: ordering of activation, function, and normalization. 80 | should be letters chosen from 'a', 'f', 'n' in any order. 81 | letter 'f' cannot be omitted. 82 | row_len: required if attention_style is not 'full' 83 | activation: activation function to use in ff and anywhere else. 
84 | """ 85 | assert position_encoding_style in ['abs', 'rel'] 86 | assert attention_style in ['full', 'row', 'previous_row', 'column'] 87 | assert connection_style in ['plain', 'residual', 'dense'] 88 | assert attention_style == 'full' or row_len is not None 89 | assert len(layer_ordering) == len(set(layer_ordering)) 90 | assert set(layer_ordering) <= {'a', 'f', 'n'} 91 | assert 'f' in set(layer_ordering) 92 | 93 | super().__init__() 94 | self._input_dim = input_dim 95 | self._feature_dim = feature_dim 96 | self._num_heads = num_heads 97 | self._num_features_per_head = self._feature_dim // self._num_heads 98 | self._position_encoding_style = position_encoding_style 99 | self._attention_style = attention_style 100 | self._connection_style = connection_style 101 | self._layer_ordering = list(layer_ordering) 102 | self._row_len = row_len 103 | self._activation = activation 104 | 105 | self._attn = MultiheadSelfAttention( 106 | input_dim=self._attn_input_dim, 107 | num_heads=self._num_heads, 108 | num_head_features=self._num_features_per_head, 109 | position_encoding_style=self._position_encoding_style, 110 | attention_style=self._attention_style, 111 | row_len=self._row_len) 112 | self._proj = tc.nn.Linear( 113 | in_features=(self._num_heads * self._num_features_per_head), 114 | out_features=self._feature_dim, 115 | bias=False) 116 | tc.nn.init.xavier_normal_(self._proj.weight) 117 | 118 | self._ff = FF( 119 | input_dim=self._ff_input_dim, 120 | hidden_dim=self._feature_dim, 121 | output_dim=self._feature_dim, 122 | activation=self._activation) 123 | 124 | if 'n' in self._layer_ordering: 125 | if self._layer_ordering.index('n') < self._layer_ordering.index('f'): 126 | self._attn_layer_norm = LayerNorm(units=self._attn_input_dim) 127 | self._ff_layer_norm = LayerNorm(units=self._ff_input_dim) 128 | else: 129 | self._attn_layer_norm = LayerNorm(units=self._feature_dim) 130 | self._ff_layer_norm = LayerNorm(units=self._feature_dim) 131 | 132 | if 'a' in self._layer_ordering: 133 | self._attn_act = self._activation() 134 | self._ff_act = self._activation() 135 | 136 | @property 137 | def _attn_input_dim(self): 138 | return self._input_dim 139 | 140 | @property 141 | def _ff_input_dim(self): 142 | if self._connection_style != 'dense': 143 | return self._feature_dim 144 | return self._attn_input_dim + self._feature_dim 145 | 146 | @property 147 | def output_dim(self): 148 | if self._connection_style != 'dense': 149 | return self._feature_dim 150 | return self._ff_input_dim + self._feature_dim 151 | 152 | def forward(self, inputs, past_kvs=None): 153 | """ 154 | Args: 155 | inputs: input vec tensor of shape [B, T2, I] 156 | past_kvs: optional past kvs 157 | 158 | Returns: 159 | output tensor of shape [B, T2, I], and new kvs 160 | """ 161 | x = inputs 162 | 163 | i = inputs 164 | for letter in self._layer_ordering: 165 | if letter == 'a': 166 | x = self._attn_act(x) 167 | elif letter == 'n': 168 | x = self._attn_layer_norm(x) 169 | elif letter == 'f': 170 | x, new_kvs = self._attn(x, past_kvs=past_kvs) 171 | x = self._proj(x) 172 | if self._connection_style == 'residual': 173 | x += i 174 | else: 175 | raise NotImplementedError 176 | 177 | if self._connection_style == 'dense': 178 | x = tc.cat((i, x), dim=-1) 179 | 180 | i = x 181 | for letter in self._layer_ordering: 182 | if letter == 'a': 183 | x = self._ff_act(x) 184 | elif letter == 'n': 185 | x = self._ff_layer_norm(x) 186 | elif letter == 'f': 187 | x = self._ff(x) 188 | if self._connection_style == 'residual': 189 | x += i 190 | else: 191 | 
raise NotImplementedError 192 | 193 | if self._connection_style == 'dense': 194 | x = tc.cat((i, x), dim=-1) 195 | 196 | return x, new_kvs 197 | 198 | 199 | class Transformer(tc.nn.Module): 200 | def __init__( 201 | self, 202 | input_dim, 203 | feature_dim, 204 | n_layer, 205 | n_head, 206 | n_context, 207 | position_encoding_style='abs', 208 | attention_style='sparse', 209 | connection_style='dense', 210 | layer_ordering='fn', 211 | input_logic='', 212 | output_logic='', 213 | activation=tc.nn.ReLU 214 | ): 215 | """ 216 | Args: 217 | input_dim: input dimensionality. 218 | feature_dim: feature dimensionality for each sublayer. 219 | n_layer: number of transformer layers. 220 | n_head: number of attention heads. 221 | n_context: meta-episode length. 222 | position_encoding_style: one of 'abs', 'rel'. 223 | attention_style: one of 'full', 'sparse'. 224 | connection_style: one of 'plain', 'residual', 'dense'. 225 | layer_ordering: ordering of activation, function, and normalization. 226 | string should be letters chosen from 'a', 'f', 'n' in any order. 227 | letter 'f' cannot be omitted. 228 | input_logic: ordering of activation and normalization after 229 | input linear projection and prior to first transformer layer. 230 | string should be letters chosen from 'a', 'n' in any order. 231 | string defaults to empty. 232 | output_logic: ordering of activation and normalization before features 233 | are returned and after last transformer layer. 234 | string should be letters chosen from 'a', 'n' in any order. 235 | string defaults to empty. 236 | activation: activation function to use in ff and anywhere else. 237 | """ 238 | assert position_encoding_style in ['abs', 'rel'] 239 | assert attention_style in ['full', 'sparse'] 240 | assert connection_style in ['residual', 'dense'] 241 | assert len(layer_ordering) == len(set(layer_ordering)) 242 | assert set(layer_ordering) <= {'a', 'f', 'n'} 243 | assert 'f' in set(layer_ordering) 244 | assert len(input_logic) == len(set(input_logic)) 245 | assert set(input_logic) <= {'a', 'n'} 246 | assert len(output_logic) == len(set(output_logic)) 247 | assert set(output_logic) <= {'a', 'n'} 248 | 249 | super().__init__() 250 | self._input_dim = input_dim 251 | self._feature_dim = feature_dim 252 | self._n_layer = n_layer 253 | self._n_head = n_head 254 | self._n_context = n_context 255 | self._position_encoding_style = position_encoding_style 256 | self._connection_style = connection_style 257 | self._layer_ordering = list(layer_ordering) 258 | self._input_logic = list(input_logic) 259 | self._output_logic = list(output_logic) 260 | self._activation = activation 261 | 262 | # input 263 | self._input_proj = tc.nn.Linear( 264 | in_features=self._input_dim, 265 | out_features=self._feature_dim) 266 | tc.nn.init.xavier_normal_(self._input_proj.weight) 267 | if 'n' in self._input_logic: 268 | self._input_layer_norm = LayerNorm(units=self._feature_dim) 269 | if 'a' in self._input_logic: 270 | self._input_act = self._activation() 271 | 272 | if self._position_encoding_style == 'abs': 273 | self._position_embeddings = sinusoidal_embeddings( 274 | self._n_context, self._feature_dim, reverse=False) 275 | 276 | # middle 277 | self._transformer_layers = tc.nn.ModuleList([ 278 | TransformerLayer( 279 | input_dim=self._get_input_dim(l), 280 | feature_dim=self._feature_dim, 281 | num_heads=self._n_head, 282 | position_encoding_style=self._position_encoding_style, 283 | attention_style=self._get_attention_style(attention_style, l), 284 | 
connection_style=self._connection_style, 285 | layer_ordering=''.join(self._layer_ordering), 286 | row_len=self._get_row_len(attention_style), 287 | activation=self._activation) 288 | for l in range(self._n_layer) 289 | ]) 290 | 291 | # output 292 | if 'n' in self._output_logic: 293 | self._output_layer_norm = LayerNorm(units=self.output_dim) 294 | if 'a' in self._output_logic: 295 | self._output_act = self._activation() 296 | 297 | def _get_input_dim(self, l): 298 | if self._connection_style != 'dense': 299 | return self._feature_dim 300 | else: 301 | if self._position_encoding_style == 'abs': 302 | return (2*l+2) * self._feature_dim 303 | return (2*l+1) * self._feature_dim 304 | 305 | def _get_attention_style(self, attention_style, l): 306 | if attention_style == 'full': 307 | return 'full' 308 | sparse_attention_styles = ['row', 'column', 'previous_row'] 309 | return sparse_attention_styles[l % 3] 310 | 311 | def _get_row_len(self, attention_style): 312 | if attention_style == 'full': 313 | return None 314 | small = int(self._n_context ** 0.5) 315 | while self._n_context % small != 0: 316 | small -= 1 317 | return small 318 | 319 | def _get_past_len(self, prev_state): 320 | assert prev_state is None or isinstance(prev_state, list) 321 | if prev_state is None: 322 | return 0 323 | k, _ = prev_state[0] # layer 0, get keys 324 | if isinstance(k, list): 325 | return len(k) 326 | if isinstance(k, tc.Tensor): 327 | return k.shape[1] 328 | raise NotImplementedError 329 | 330 | def _add_position_embeddings(self, inputs, prev_state): 331 | t1 = self._get_past_len(prev_state) 332 | t2 = inputs.shape[1] 333 | assert t1 + t2 <= self._n_context 334 | pos_embs = self._position_embeddings[t1:t1+t2, :] 335 | pos_embs = pos_embs.unsqueeze(0) 336 | if self._connection_style != 'dense': 337 | inputs = inputs + pos_embs 338 | else: 339 | pos_embs = tc.tile(pos_embs, [inputs.shape[0], 1, 1]) 340 | inputs = tc.cat((inputs, pos_embs), dim=-1) 341 | return inputs 342 | 343 | def _run_input_logic(self, inputs): 344 | for letter in self._input_logic: 345 | if letter == 'n': 346 | inputs = self._input_layer_norm(inputs) 347 | elif letter == 'a': 348 | inputs = self._input_act(inputs) 349 | return inputs 350 | 351 | def _run_output_logic(self, inputs): 352 | for letter in self._output_logic: 353 | if letter == 'n': 354 | inputs = self._output_layer_norm(inputs) 355 | elif letter == 'a': 356 | inputs = self._output_act(inputs) 357 | return inputs 358 | 359 | @property 360 | def output_dim(self): 361 | return self._transformer_layers[-1].output_dim 362 | 363 | def initial_state(self, batch_size): 364 | return None 365 | 366 | def forward(self, inputs, prev_state=None): 367 | """ 368 | Args: 369 | inputs: input vec tensor of shape [B, ..., I] 370 | prev_state: optional previous state. 371 | 372 | Notes: 373 | '...' must be either one dimensional or must not exist 374 | 375 | Returns: 376 | output feature tensor and new state. 
377 | """ 378 | assert len(list(inputs.shape)) in [2, 3] 379 | if len(list(inputs.shape)) == 2: 380 | inputs = inputs.unsqueeze(1) 381 | 382 | # input 383 | inputs = self._input_proj(inputs) 384 | inputs = self._run_input_logic(inputs) 385 | if self._position_encoding_style == 'abs': 386 | inputs = self._add_position_embeddings(inputs, prev_state) 387 | 388 | # middle 389 | past_kvs = [None] * self._n_layer if prev_state is None else prev_state 390 | new_kvs = [] 391 | for l in range(0, self._n_layer): 392 | inputs, new_kvs_l = self._transformer_layers[l]( 393 | inputs=inputs, past_kvs=past_kvs[l]) 394 | new_kvs.append(new_kvs_l) 395 | 396 | # output 397 | inputs = self._run_output_logic(inputs) 398 | 399 | features = inputs 400 | if features.shape[1] == 1: 401 | features = features.squeeze(1) 402 | 403 | return features, new_kvs 404 | -------------------------------------------------------------------------------- /rl2/agents/heads/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/agents/heads/__init__.py -------------------------------------------------------------------------------- /rl2/agents/heads/policy_heads.py: -------------------------------------------------------------------------------- 1 | """ 2 | Policy heads for RL^2 agents. 3 | """ 4 | 5 | import torch as tc 6 | 7 | 8 | class LinearPolicyHead(tc.nn.Module): 9 | """ 10 | Policy head for a reinforcement learning agent. 11 | """ 12 | def __init__(self, num_features, num_actions): 13 | super().__init__() 14 | self._num_features = num_features 15 | self._num_actions = num_actions 16 | self._linear = tc.nn.Linear( 17 | in_features=self._num_features, 18 | out_features=self._num_actions, 19 | bias=True) 20 | tc.nn.init.xavier_normal_(self._linear.weight) 21 | tc.nn.init.zeros_(self._linear.bias) 22 | 23 | def forward(self, features: tc.FloatTensor) -> tc.distributions.Categorical: 24 | """ 25 | Computes a policy distribution from features and returns it. 26 | 27 | Args: 28 | features: a tc.FloatTensor of shape [B, ..., F]. 29 | 30 | Returns: 31 | tc.distributions.Categorical over actions, with batch shape [B, ...] 32 | """ 33 | logits = self._linear(features) 34 | dists = tc.distributions.Categorical(logits=logits) 35 | return dists 36 | -------------------------------------------------------------------------------- /rl2/agents/heads/value_heads.py: -------------------------------------------------------------------------------- 1 | """ 2 | Value heads for RL^2 agents. 3 | """ 4 | 5 | import torch as tc 6 | 7 | 8 | class LinearValueHead(tc.nn.Module): 9 | """ 10 | Value head for a reinforcement learning agent. 11 | """ 12 | def __init__(self, num_features): 13 | super().__init__() 14 | self._num_features = num_features 15 | self._linear = tc.nn.Linear( 16 | in_features=self._num_features, 17 | out_features=1, 18 | bias=True) 19 | tc.nn.init.xavier_normal_(self._linear.weight) 20 | tc.nn.init.zeros_(self._linear.bias) 21 | 22 | def forward(self, features: tc.FloatTensor) -> tc.FloatTensor: 23 | """ 24 | Computes a value estimate from features and returns it. 25 | 26 | Args: 27 | features: tc.FloatTensor of features with shape [B, ..., F]. 28 | 29 | Returns: 30 | tc.FloatTensor of value estimates with shape [B, ...]. 
31 | """ 32 | v_preds = self._linear(features).squeeze(-1) 33 | return v_preds 34 | -------------------------------------------------------------------------------- /rl2/agents/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/agents/integration/__init__.py -------------------------------------------------------------------------------- /rl2/agents/integration/policy_net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements StatefulPolicyNet class. 3 | """ 4 | 5 | from typing import Union, Tuple, Optional, TypeVar, Generic 6 | 7 | import torch as tc 8 | 9 | 10 | ArchitectureState = TypeVar('ArchitectureState') 11 | 12 | 13 | class StatefulPolicyNet(tc.nn.Module, Generic[ArchitectureState]): 14 | def __init__(self, preprocessing, architecture, policy_head): 15 | super().__init__() 16 | self._preprocessing = preprocessing 17 | self._architecture = architecture 18 | self._policy_head = policy_head 19 | 20 | def initial_state(self, batch_size: int) -> Optional[ArchitectureState]: 21 | return self._architecture.initial_state(batch_size=batch_size) 22 | 23 | def forward( 24 | self, 25 | curr_obs: Union[tc.LongTensor, tc.FloatTensor], 26 | prev_action: tc.LongTensor, 27 | prev_reward: tc.FloatTensor, 28 | prev_done: tc.FloatTensor, 29 | prev_state: Optional[ArchitectureState] 30 | ) -> Tuple[tc.distributions.Categorical, ArchitectureState]: 31 | """ 32 | Runs preprocessing and the architecture's state update; 33 | returns policy distribution(s) and new state. 34 | 35 | Args: 36 | curr_obs: current observation(s) tensor with shape [B, ..., ?]. 37 | prev_action: previous action(s) tensor with shape [B, ...] 38 | prev_reward: previous rewards(s) tensor with shape [B, ...] 39 | prev_done: previous done flag(s) tensor with shape [B, ...] 40 | prev_state: the architecture's previous state. 41 | 42 | Notes: 43 | '...' must be either one dimensional or must not exist 44 | 45 | Returns: 46 | Tuple containing policy distribution(s) with batch shape [B, ...] 47 | and the architecture's new state. 48 | """ 49 | inputs = self._preprocessing( 50 | curr_obs, prev_action, prev_reward, prev_done) 51 | 52 | features, new_state = self._architecture( 53 | inputs=inputs, prev_state=prev_state) 54 | 55 | dist = self._policy_head(features) 56 | 57 | return dist, new_state 58 | -------------------------------------------------------------------------------- /rl2/agents/integration/value_net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements StatefulValueNet class. 
3 | """ 4 | 5 | from typing import Union, Tuple, Optional, TypeVar, Generic 6 | 7 | import torch as tc 8 | 9 | 10 | ArchitectureState = TypeVar('ArchitectureState') 11 | 12 | 13 | class StatefulValueNet(tc.nn.Module, Generic[ArchitectureState]): 14 | def __init__(self, preprocessing, architecture, value_head): 15 | super().__init__() 16 | self._preprocessing = preprocessing 17 | self._architecture = architecture 18 | self._value_head = value_head 19 | 20 | def initial_state(self, batch_size: int) -> Optional[ArchitectureState]: 21 | return self._architecture.initial_state(batch_size=batch_size) 22 | 23 | def forward( 24 | self, 25 | curr_obs: Union[tc.LongTensor, tc.FloatTensor], 26 | prev_action: tc.LongTensor, 27 | prev_reward: tc.FloatTensor, 28 | prev_done: tc.FloatTensor, 29 | prev_state: Optional[ArchitectureState] 30 | ) -> Tuple[tc.FloatTensor, ArchitectureState]: 31 | """ 32 | Runs preprocessing and the architecture's state update; 33 | returns value estimate(s) and new state. 34 | 35 | Args: 36 | curr_obs: current observation(s) tensor with shape [B, ..., ?]. 37 | prev_action: previous action(s) tensor with shape [B, ...] 38 | prev_reward: previous rewards(s) tensor with shape [B, ...] 39 | prev_done: previous done flag(s) tensor with shape [B, ...] 40 | prev_state: the architecture's previous state. 41 | 42 | Notes: 43 | '...' must be either one dimensional or must not exist 44 | 45 | Returns: 46 | Tuple containing value estimate(s) with batch shape [B, ...] 47 | and the architecture's new state. 48 | """ 49 | inputs = self._preprocessing( 50 | curr_obs, prev_action, prev_reward, prev_done) 51 | 52 | features, new_state = self._architecture( 53 | inputs=inputs, prev_state=prev_state) 54 | 55 | vpred = self._value_head(features) 56 | 57 | return vpred, new_state 58 | -------------------------------------------------------------------------------- /rl2/agents/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/agents/preprocessing/__init__.py -------------------------------------------------------------------------------- /rl2/agents/preprocessing/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements common agent components used in Duan et al., 2016 3 | - 'RL^2 : Fast Reinforcement Learning via Slow Reinforcement Learning'. 4 | """ 5 | 6 | from typing import Union 7 | import abc 8 | 9 | import torch as tc 10 | 11 | 12 | class Preprocessing(abc.ABC, tc.nn.Module): 13 | def forward( 14 | self, 15 | curr_obs: Union[tc.LongTensor, tc.FloatTensor], 16 | prev_action: tc.LongTensor, 17 | prev_reward: tc.FloatTensor, 18 | prev_done: tc.FloatTensor 19 | ) -> tc.FloatTensor: 20 | """ 21 | Creates an input vector for a meta-learning agent. 22 | 23 | Args: 24 | curr_obs: either tc.LongTensor or tc.FloatTensor of shape [B, ...]. 25 | prev_action: tc.LongTensor of shape [B, ...] 26 | prev_reward: tc.FloatTensor of shape [B, ...] 27 | prev_done: tc.FloatTensor of shape [B, ...] 28 | 29 | Returns: 30 | tc.FloatTensor of shape [B, ..., ?] 31 | """ 32 | pass 33 | 34 | 35 | def one_hot(ys: tc.LongTensor, depth: int) -> tc.FloatTensor: 36 | """ 37 | Applies one-hot encoding to a batch of vectors. 38 | 39 | Args: 40 | ys: tc.LongTensor of shape [B]. 41 | depth: int specifying the number of possible y values. 42 | 43 | Returns: 44 | the one-hot encodings of tensor ys. 
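
    Example (illustrative):

        one_hot(tc.LongTensor([0, 2]), depth=3)
        # -> tensor([[1., 0., 0.],
        #            [0., 0., 1.]])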
45 | """ 46 | 47 | vecs_shape = list(ys.shape) + [depth] 48 | vecs = tc.zeros(dtype=tc.float32, size=vecs_shape) 49 | vecs.scatter_(dim=-1, index=ys.unsqueeze(-1), 50 | src=tc.ones(dtype=tc.float32, size=vecs_shape)) 51 | return vecs.float() -------------------------------------------------------------------------------- /rl2/agents/preprocessing/tabular.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements preprocessing for tabular MABs and MDPs. 3 | """ 4 | 5 | import torch as tc 6 | 7 | from rl2.agents.preprocessing.common import one_hot, Preprocessing 8 | 9 | 10 | class MABPreprocessing(Preprocessing): 11 | def __init__(self, num_actions: int): 12 | super().__init__() 13 | self._num_actions = num_actions 14 | 15 | @property 16 | def output_dim(self): 17 | return self._num_actions + 2 18 | 19 | def forward( 20 | self, 21 | curr_obs: tc.LongTensor, 22 | prev_action: tc.LongTensor, 23 | prev_reward: tc.FloatTensor, 24 | prev_done: tc.FloatTensor 25 | ) -> tc.FloatTensor: 26 | """ 27 | Creates an input vector for a meta-learning agent. 28 | 29 | Args: 30 | curr_obs: tc.LongTensor of shape [B, ...]; will be ignored. 31 | prev_action: tc.LongTensor of shape [B, ...] 32 | prev_reward: tc.FloatTensor of shape [B, ...] 33 | prev_done: tc.FloatTensor of shape [B, ...] 34 | 35 | Returns: 36 | tc.FloatTensor of shape [B, ..., A+2] 37 | """ 38 | 39 | emb_a = one_hot(prev_action, depth=self._num_actions) 40 | prev_reward = prev_reward.unsqueeze(-1) 41 | prev_done = prev_done.unsqueeze(-1) 42 | vec = tc.cat((emb_a, prev_reward, prev_done), dim=-1).float() 43 | return vec 44 | 45 | 46 | class MDPPreprocessing(Preprocessing): 47 | def __init__(self, num_states: int, num_actions: int): 48 | super().__init__() 49 | self._num_states = num_states 50 | self._num_actions = num_actions 51 | 52 | @property 53 | def output_dim(self): 54 | return self._num_states + self._num_actions + 2 55 | 56 | def forward( 57 | self, 58 | curr_obs: tc.LongTensor, 59 | prev_action: tc.LongTensor, 60 | prev_reward: tc.FloatTensor, 61 | prev_done: tc.FloatTensor 62 | ) -> tc.FloatTensor: 63 | """ 64 | Creates an input vector for a meta-learning agent. 65 | 66 | Args: 67 | curr_obs: tc.FloatTensor of shape [B, ..., C, H, W] 68 | prev_action: tc.LongTensor of shape [B, ...] 69 | prev_reward: tc.FloatTensor of shape [B, ...] 70 | prev_done: tc.FloatTensor of shape [B, ...] 71 | 72 | Returns: 73 | tc.FloatTensor of shape [B, ..., S+A+2] 74 | """ 75 | 76 | emb_o = one_hot(curr_obs, depth=self._num_states) 77 | emb_a = one_hot(prev_action, depth=self._num_actions) 78 | prev_reward = prev_reward.unsqueeze(-1) 79 | prev_done = prev_done.unsqueeze(-1) 80 | vec = tc.cat( 81 | (emb_o, emb_a, prev_reward, prev_done), dim=-1).float() 82 | return vec 83 | -------------------------------------------------------------------------------- /rl2/agents/preprocessing/vision.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements preprocessing for vision-based MDPs/POMDPs. 3 | """ 4 | 5 | import abc 6 | 7 | import torch as tc 8 | 9 | from rl2.agents.preprocessing.common import one_hot, Preprocessing 10 | 11 | 12 | class VisionNet(abc.ABC, tc.nn.Module): 13 | """ 14 | Vision network abstract class. 15 | """ 16 | @property 17 | @abc.abstractmethod 18 | def output_dim(self) -> int: 19 | pass 20 | 21 | @abc.abstractmethod 22 | def forward(self, curr_obs: tc.FloatTensor) -> tc.FloatTensor: 23 | """ 24 | Embeds visual observations into feature vectors. 
25 | 26 | Args: 27 | curr_obs: tc.FloatTensor of shape [B, C, H, W] 28 | 29 | Returns: 30 | a tc.FloatTensor of shape [B, F] 31 | """ 32 | pass 33 | 34 | 35 | class MDPPreprocessing(Preprocessing): 36 | def __init__(self, num_actions: int, vision_net: VisionNet): 37 | super().__init__() 38 | self._num_actions = num_actions 39 | self._vision_net = vision_net 40 | 41 | @property 42 | def output_dim(self): 43 | return self._vision_net.output_dim + self._num_actions + 2 44 | 45 | def forward( 46 | self, 47 | curr_obs: tc.FloatTensor, 48 | prev_action: tc.LongTensor, 49 | prev_reward: tc.FloatTensor, 50 | prev_done: tc.FloatTensor 51 | ) -> tc.FloatTensor: 52 | """ 53 | Creates an input vector for a meta-learning agent. 54 | 55 | Args: 56 | curr_obs: tc.FloatTensor of shape [B, ..., C, H, W] 57 | prev_action: tc.LongTensor of shape [B, ...] 58 | prev_reward: tc.FloatTensor of shape [B, ...] 59 | prev_done: tc.FloatTensor of shape [B, ...] 60 | 61 | Returns: 62 | tc.FloatTensor of shape [B, ..., F+A+2] 63 | """ 64 | 65 | curr_obs_shape = list(curr_obs.shape) 66 | curr_obs = curr_obs.view(-1, *curr_obs_shape[-3:]) 67 | emb_o = self._vision_net(curr_obs) 68 | emb_o = emb_o.view(*curr_obs_shape[:-3], emb_o.shape[-1]) 69 | 70 | emb_a = one_hot(prev_action, depth=self._num_actions) 71 | prev_reward = prev_reward.unsqueeze(-1) 72 | prev_done = prev_done.unsqueeze(-1) 73 | vec = tc.cat( 74 | (emb_o, emb_a, prev_reward, prev_done), dim=-1).float() 75 | return vec 76 | -------------------------------------------------------------------------------- /rl2/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/algos/__init__.py -------------------------------------------------------------------------------- /rl2/algos/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements common algorithmic components for training 3 | stateful meta-reinforcement learning agents. 4 | """ 5 | 6 | import torch as tc 7 | import numpy as np 8 | 9 | from rl2.envs.abstract import MetaEpisodicEnv 10 | from rl2.agents.integration.policy_net import StatefulPolicyNet 11 | from rl2.agents.integration.value_net import StatefulValueNet 12 | 13 | 14 | class MetaEpisode: 15 | def __init__(self, num_timesteps, dummy_obs): 16 | self.horizon = num_timesteps 17 | self.obs = np.array([dummy_obs for _ in range(self.horizon)]) 18 | self.acs = np.zeros(self.horizon, 'int64') 19 | self.rews = np.zeros(self.horizon, 'float32') 20 | self.dones = np.zeros(self.horizon, 'float32') 21 | self.logpacs = np.zeros(self.horizon, 'float32') 22 | self.vpreds = np.zeros(self.horizon, 'float32') 23 | self.advs = np.zeros(self.horizon, 'float32') 24 | self.tdlam_rets = np.zeros(self.horizon, 'float32') 25 | 26 | 27 | @tc.no_grad() 28 | def generate_meta_episode( 29 | env: MetaEpisodicEnv, 30 | policy_net: StatefulPolicyNet, 31 | value_net: StatefulValueNet, 32 | meta_episode_len: int 33 | ) -> MetaEpisode: 34 | """ 35 | Generates a meta-episode: a sequence of episodes concatenated together, 36 | with decisions being made by a recurrent agent with state preserved 37 | across episode boundaries. 38 | 39 | Args: 40 | env: environment. 41 | policy_net: policy network. 42 | value_net: value network. 43 | meta_episode_len: timesteps per meta-episode. 44 | 45 | Returns: 46 | meta_episode: an instance of the meta-episode class. 
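
    Example (sketch; `env`, `policy_net`, and `value_net` are assumed to have
    been constructed elsewhere, e.g. as in train.py):

        meta_ep = generate_meta_episode(
            env=env, policy_net=policy_net, value_net=value_net,
            meta_episode_len=100)
        meta_ep.rews.shape   # -> (100,)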
47 | """ 48 | 49 | env.new_env() 50 | meta_episode = MetaEpisode( 51 | num_timesteps=meta_episode_len, 52 | dummy_obs=env.reset()) 53 | 54 | o_t = np.array([env.reset()]) 55 | a_tm1 = np.array([0]) 56 | r_tm1 = np.array([0.0]) 57 | d_tm1 = np.array([1.0]) 58 | h_tm1_policy_net = policy_net.initial_state(batch_size=1) 59 | h_tm1_value_net = value_net.initial_state(batch_size=1) 60 | 61 | for t in range(0, meta_episode_len): 62 | pi_dist_t, h_t_policy_net = policy_net( 63 | curr_obs=tc.LongTensor(o_t), 64 | prev_action=tc.LongTensor(a_tm1), 65 | prev_reward=tc.FloatTensor(r_tm1), 66 | prev_done=tc.FloatTensor(d_tm1), 67 | prev_state=h_tm1_policy_net) 68 | 69 | vpred_t, h_t_value_net = value_net( 70 | curr_obs=tc.LongTensor(o_t), 71 | prev_action=tc.LongTensor(a_tm1), 72 | prev_reward=tc.FloatTensor(r_tm1), 73 | prev_done=tc.FloatTensor(d_tm1), 74 | prev_state=h_tm1_value_net) 75 | 76 | a_t = pi_dist_t.sample() 77 | log_prob_a_t = pi_dist_t.log_prob(a_t) 78 | 79 | o_tp1, r_t, done_t, _ = env.step( 80 | action=a_t.squeeze(0).detach().numpy(), 81 | auto_reset=True) 82 | 83 | meta_episode.obs[t] = o_t[0] 84 | meta_episode.acs[t] = a_t.squeeze(0).detach().numpy() 85 | meta_episode.rews[t] = r_t 86 | meta_episode.dones[t] = float(done_t) 87 | meta_episode.logpacs[t] = log_prob_a_t.squeeze(0).detach().numpy() 88 | meta_episode.vpreds[t] = vpred_t.squeeze(0).detach().numpy() 89 | 90 | o_t = np.array([o_tp1]) 91 | a_tm1 = np.array([meta_episode.acs[t]]) 92 | r_tm1 = np.array([meta_episode.rews[t]]) 93 | d_tm1 = np.array([meta_episode.dones[t]]) 94 | h_tm1_policy_net = h_t_policy_net 95 | h_tm1_value_net = h_t_value_net 96 | 97 | return meta_episode 98 | 99 | 100 | @tc.no_grad() 101 | def assign_credit( 102 | meta_episode: MetaEpisode, 103 | gamma: float, 104 | lam: float 105 | ) -> MetaEpisode: 106 | """ 107 | Compute td lambda returns and generalized advantage estimates. 108 | 109 | Note that in the meta-episodic setting of RL^2, the objective is 110 | to maximize the expected discounted return of the meta-episode, 111 | so we do not utilize the usual 'done' masking in this function. 112 | 113 | Args: 114 | meta_episode: meta-episode. 115 | gamma: discount factor. 116 | lam: GAE decay parameter. 117 | 118 | Returns: 119 | meta_episode: an instance of the meta-episode class, 120 | with generalized advantage estimates and td lambda returns computed. 121 | """ 122 | T = len(meta_episode.acs) 123 | for t in reversed(range(0, T)): # T-1, ..., 0. 124 | r_t = meta_episode.rews[t] 125 | V_t = meta_episode.vpreds[t] 126 | V_tp1 = meta_episode.vpreds[t+1] if t+1 < T else 0.0 127 | A_tp1 = meta_episode.advs[t+1] if t+1 < T else 0.0 128 | delta_t = -V_t + r_t + gamma * V_tp1 129 | A_t = delta_t + gamma * lam * A_tp1 130 | meta_episode.advs[t] = A_t 131 | 132 | meta_episode.tdlam_rets = meta_episode.vpreds + meta_episode.advs 133 | return meta_episode 134 | 135 | 136 | def huber_func(y_pred, y_true, delta=1.0): 137 | a = y_pred-y_true 138 | a_abs = tc.abs(a) 139 | a2 = tc.square(a) 140 | terms = tc.where( 141 | tc.less(a_abs, delta * tc.ones_like(a2)), 142 | 0.5 * a2, 143 | delta * (a_abs - 0.5 * delta) 144 | ) 145 | return terms 146 | -------------------------------------------------------------------------------- /rl2/algos/ppo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements ppo loss computations for training 3 | stateful meta-reinforcement learning agents. 
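
For reference, compute_losses below implements the clipped surrogate
objective of Schulman et al., 2017 ('Proximal Policy Optimization Algorithms'):

    ratio_t  = exp(logpi_new(a_t) - logpi_old(a_t))
    L_clip   = mean_t[ min( ratio_t * A_t,
                            clip(ratio_t, 1 - clip_param, 1 + clip_param) * A_t ) ]
    pol_loss = -(L_clip + ent_coef * mean_t[ entropy_t ])

where A_t are the GAE advantages computed in rl2/algos/common.py, and the
value network is trained with a Huber loss against the TD(lambda) returns.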
4 | """ 5 | 6 | from typing import List, Dict, Optional, Callable 7 | from collections import deque 8 | 9 | import torch as tc 10 | import numpy as np 11 | from mpi4py import MPI 12 | 13 | from rl2.envs.abstract import MetaEpisodicEnv 14 | from rl2.agents.integration.policy_net import StatefulPolicyNet 15 | from rl2.agents.integration.value_net import StatefulValueNet 16 | from rl2.algos.common import ( 17 | MetaEpisode, 18 | generate_meta_episode, 19 | assign_credit, 20 | huber_func, 21 | ) 22 | from rl2.utils.comm_util import sync_grads 23 | from rl2.utils.constants import ROOT_RANK 24 | 25 | 26 | def compute_losses( 27 | meta_episodes: List[MetaEpisode], 28 | policy_net: StatefulPolicyNet, 29 | value_net: StatefulValueNet, 30 | clip_param: float, 31 | ent_coef: float 32 | ) -> Dict[str, tc.Tensor]: 33 | """ 34 | Computes the losses for Proximal Policy Optimization. 35 | 36 | Args: 37 | meta_episodes: list of meta-episodes. 38 | policy_net: policy network. 39 | value_net: value network. 40 | clip_param: clip parameter for PPO. 41 | ent_coef: entropy coefficient for PPO. 42 | 43 | Returns: 44 | loss_dict: a dictionary of losses. 45 | """ 46 | def get_tensor(field, dtype=None): 47 | mb_field = np.stack( 48 | list(map(lambda metaep: getattr(metaep, field), meta_episodes)), 49 | axis=0) 50 | if dtype == 'long': 51 | return tc.LongTensor(mb_field) 52 | return tc.FloatTensor(mb_field) 53 | 54 | # minibatch data tensors 55 | mb_obs = get_tensor('obs', 'long') 56 | mb_acs = get_tensor('acs', 'long') 57 | mb_rews = get_tensor('rews') 58 | mb_dones = get_tensor('dones') 59 | mb_logpacs = get_tensor('logpacs') 60 | mb_advs = get_tensor('advs') 61 | mb_tdlam_rets = get_tensor('tdlam_rets') 62 | 63 | # input for loss calculations 64 | B = len(meta_episodes) 65 | ac_dummy = tc.zeros(dtype=tc.int64, size=(B,)) 66 | rew_dummy = tc.zeros(dtype=tc.float32, size=(B,)) 67 | done_dummy = tc.ones(dtype=tc.float32, size=(B,)) 68 | 69 | curr_obs = mb_obs 70 | prev_action = tc.cat((ac_dummy.unsqueeze(1), mb_acs[:, 0:-1]), dim=1) 71 | prev_reward = tc.cat((rew_dummy.unsqueeze(1), mb_rews[:, 0:-1]), dim=1) 72 | prev_done = tc.cat((done_dummy.unsqueeze(1), mb_dones[:, 0:-1]), dim=1) 73 | prev_state_policy_net = policy_net.initial_state(batch_size=B) 74 | prev_state_value_net = value_net.initial_state(batch_size=B) 75 | 76 | # forward pass implements unroll for recurrent/attentive architectures. 
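    # Shape note (for the tabular envs in this repo): with B = len(meta_episodes)
    # and T = meta_episode_len, each of curr_obs / prev_action / prev_reward /
    # prev_done is a [B, T] tensor, so pi_dists has batch shape [B, T] and
    # vpreds has shape [B, T]. The recurrent states returned here are discarded.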
77 | pi_dists, _ = policy_net( 78 | curr_obs=curr_obs, 79 | prev_action=prev_action, 80 | prev_reward=prev_reward, 81 | prev_done=prev_done, 82 | prev_state=prev_state_policy_net) 83 | 84 | vpreds, _ = value_net( 85 | curr_obs=curr_obs, 86 | prev_action=prev_action, 87 | prev_reward=prev_reward, 88 | prev_done=prev_done, 89 | prev_state=prev_state_value_net) 90 | 91 | entropies = pi_dists.entropy() 92 | logpacs_new = pi_dists.log_prob(mb_acs) 93 | vpreds_new = vpreds 94 | 95 | # entropy bonus 96 | meanent = tc.mean(entropies) 97 | policy_entropy_bonus = ent_coef * meanent 98 | 99 | # policy surrogate objective 100 | policy_ratios = tc.exp(logpacs_new - mb_logpacs) 101 | clipped_policy_ratios = tc.clip(policy_ratios, 1-clip_param, 1+clip_param) 102 | surr1 = mb_advs * policy_ratios 103 | surr2 = mb_advs * clipped_policy_ratios 104 | policy_surrogate_objective = tc.mean(tc.min(surr1, surr2)) 105 | 106 | # composite policy loss 107 | policy_loss = -(policy_surrogate_objective + policy_entropy_bonus) 108 | 109 | # value loss 110 | value_loss = tc.mean(huber_func(mb_tdlam_rets, vpreds_new)) 111 | 112 | # clipfrac 113 | clipfrac = tc.mean(tc.greater(surr1, surr2).float()) 114 | 115 | return { 116 | "policy_loss": policy_loss, 117 | "value_loss": value_loss, 118 | "meanent": meanent, 119 | "clipfrac": clipfrac 120 | } 121 | 122 | 123 | def training_loop( 124 | env: MetaEpisodicEnv, 125 | policy_net: StatefulPolicyNet, 126 | value_net: StatefulValueNet, 127 | policy_optimizer: tc.optim.Optimizer, 128 | value_optimizer: tc.optim.Optimizer, 129 | policy_scheduler: Optional[tc.optim.lr_scheduler._LRScheduler], # pylint: disable=W0212 130 | value_scheduler: Optional[tc.optim.lr_scheduler._LRScheduler], # pylint: disable=W0212 131 | meta_episodes_per_policy_update: int, 132 | meta_episodes_per_learner_batch: int, 133 | meta_episode_len: int, 134 | ppo_opt_epochs: int, 135 | ppo_clip_param: float, 136 | ppo_ent_coef: float, 137 | discount_gamma: float, 138 | gae_lambda: float, 139 | standardize_advs: bool, 140 | max_pol_iters: int, 141 | pol_iters_so_far: int, 142 | policy_checkpoint_fn: Callable[[int], None], 143 | value_checkpoint_fn: Callable[[int], None], 144 | comm: type(MPI.COMM_WORLD), 145 | ) -> None: 146 | """ 147 | Train a stateful RL^2 agent via PPO to maximize discounted cumulative reward 148 | in Tabular MDPs, sampled from the distribution used in Duan et al., 2016. 149 | 150 | Args: 151 | env: environment. 152 | policy_net: policy network. 153 | value_net: value network, 154 | policy_optimizer: policy optimizer. 155 | value_optimizer: value optimizer. 156 | policy_scheduler: policy lr scheduler. 157 | value_scheduler: value lr scheduler. 158 | meta_episodes_per_policy_update: meta-episodes per policy improvement, 159 | on each process. 160 | meta_episodes_per_learner_batch: meta-episodes per batch on each process. 161 | meta_episode_len: timesteps per meta-episode. 162 | ppo_opt_epochs: optimization epochs for proximal policy optimization. 163 | ppo_clip_param: clip parameter for proximal policy optimization. 164 | ppo_ent_coef: entropy bonus coefficient for proximal policy optimization 165 | discount_gamma: discount factor gamma. 166 | gae_lambda: decay parameter lambda for generalized advantage estimation. 167 | standardize_advs: standardize advantages to mean 0 and stddev 1? 168 | max_pol_iters: the maximum number policy improvements to make. 169 | pol_iters_so_far: the number of policy improvements made so far. 
170 | policy_checkpoint_fn: a callback for saving checkpoints of policy net. 171 | value_checkpoint_fn: a callback for saving checkpoints of value net. 172 | comm: mpi comm_world communicator object. 173 | 174 | Returns: 175 | None 176 | """ 177 | meta_ep_returns = deque(maxlen=1000) 178 | 179 | for pol_iter in range(pol_iters_so_far, max_pol_iters): 180 | # collect meta-episodes... 181 | meta_episodes = list() 182 | for _ in range(0, meta_episodes_per_policy_update): 183 | # collect one meta-episode and append it to the list 184 | meta_episode = generate_meta_episode( 185 | env=env, 186 | policy_net=policy_net, 187 | value_net=value_net, 188 | meta_episode_len=meta_episode_len) 189 | meta_episode = assign_credit( 190 | meta_episode=meta_episode, 191 | gamma=discount_gamma, 192 | lam=gae_lambda) 193 | meta_episodes.append(meta_episode) 194 | 195 | # logging 196 | l_meta_ep_returns = [np.sum(meta_episode.rews)] 197 | g_meta_ep_returns = comm.allgather(l_meta_ep_returns) 198 | g_meta_ep_returns = [x for loc in g_meta_ep_returns for x in loc] 199 | meta_ep_returns.extend(g_meta_ep_returns) 200 | 201 | # maybe standardize advantages... 202 | if standardize_advs: 203 | num_procs = comm.Get_size() 204 | adv_eps = 1e-8 205 | 206 | l_advs = list(map(lambda m: m.advs, meta_episodes)) 207 | l_adv_mu = np.mean(l_advs) 208 | g_adv_mu = comm.allreduce(l_adv_mu, op=MPI.SUM) / num_procs 209 | 210 | l_advs_centered = list(map(lambda adv: adv - g_adv_mu, l_advs)) 211 | l_adv_sigma2 = np.var(l_advs_centered) 212 | g_adv_sigma2 = comm.allreduce(l_adv_sigma2, op=MPI.SUM) / num_procs 213 | g_adv_sigma = np.sqrt(g_adv_sigma2) + adv_eps 214 | 215 | l_advs_standardized = list(map(lambda adv: adv / g_adv_sigma, l_advs_centered)) 216 | for m, a in zip(meta_episodes, l_advs_standardized): 217 | setattr(m, 'advs', a) 218 | setattr(m, 'tdlam_rets', m.vpreds + a) 219 | 220 | if comm.Get_rank() == ROOT_RANK: 221 | mean_adv_r0 = np.mean( 222 | list(map(lambda m: m.advs, meta_episodes))) 223 | print(f"Mean advantage: {mean_adv_r0}") 224 | 225 | # update policy... 226 | for opt_epoch in range(ppo_opt_epochs): 227 | idxs = np.random.permutation(meta_episodes_per_policy_update) 228 | for i in range(0, meta_episodes_per_policy_update, meta_episodes_per_learner_batch): 229 | mb_idxs = idxs[i:i+meta_episodes_per_learner_batch] 230 | mb_meta_eps = [meta_episodes[idx] for idx in mb_idxs] 231 | losses = compute_losses( 232 | meta_episodes=mb_meta_eps, 233 | policy_net=policy_net, 234 | value_net=value_net, 235 | clip_param=ppo_clip_param, 236 | ent_coef=ppo_ent_coef) 237 | 238 | policy_optimizer.zero_grad() 239 | losses['policy_loss'].backward() 240 | sync_grads(model=policy_net, comm=comm) 241 | policy_optimizer.step() 242 | if policy_scheduler: 243 | policy_scheduler.step() 244 | 245 | value_optimizer.zero_grad() 246 | losses['value_loss'].backward() 247 | sync_grads(model=value_net, comm=comm) 248 | value_optimizer.step() 249 | if value_scheduler: 250 | value_scheduler.step() 251 | 252 | # logging 253 | global_losses = {} 254 | for name in losses: 255 | loss_sum = comm.allreduce(losses[name], op=MPI.SUM) 256 | loss_avg = loss_sum / comm.Get_size() 257 | global_losses[name] = loss_avg 258 | 259 | if comm.Get_rank() == ROOT_RANK: 260 | print(f"pol update {pol_iter}, opt_epoch: {opt_epoch}...") 261 | for name, value in global_losses.items(): 262 | print(f"\t{name}: {value:>0.6f}") 263 | 264 | # misc.: print metrics, save checkpoint. 
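        # only the root process prints and writes checkpoints; checkpoints are
        # indexed by the number of completed policy improvement iterations.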
265 | if comm.Get_rank() == ROOT_RANK: 266 | print("-" * 100) 267 | print(f"mean meta-episode return: {np.mean(meta_ep_returns):>0.3f}") 268 | print("-" * 100) 269 | policy_checkpoint_fn(pol_iter + 1) 270 | value_checkpoint_fn(pol_iter + 1) 271 | -------------------------------------------------------------------------------- /rl2/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/envs/__init__.py -------------------------------------------------------------------------------- /rl2/envs/abstract.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements abstract class for meta-reinforcement learning environments. 3 | """ 4 | 5 | from typing import Generic, TypeVar, Tuple 6 | import abc 7 | 8 | 9 | ObsType = TypeVar('ObsType') 10 | 11 | 12 | class MetaEpisodicEnv(abc.ABC, Generic[ObsType]): 13 | @property 14 | @abc.abstractmethod 15 | def max_episode_len(self) -> int: 16 | """ 17 | Return the maximum episode length. 18 | """ 19 | pass 20 | 21 | @abc.abstractmethod 22 | def new_env(self) -> None: 23 | """ 24 | Reset the environment's structure by resampling 25 | the state transition probabilities and/or reward function 26 | from a prior distribution. 27 | 28 | Returns: 29 | None 30 | """ 31 | pass 32 | 33 | @abc.abstractmethod 34 | def reset(self) -> ObsType: 35 | """ 36 | Resets the environment's state to some designated initial state. 37 | This is distinct from resetting the environment's structure 38 | via self.new_env(). 39 | 40 | Returns: 41 | initial observation. 42 | """ 43 | pass 44 | 45 | @abc.abstractmethod 46 | def step( 47 | self, 48 | action: int, 49 | auto_reset: bool = True 50 | ) -> Tuple[ObsType, float, bool, dict]: 51 | """ 52 | Step the env. 53 | 54 | Args: 55 | action: integer action indicating which action to take 56 | auto_reset: whether or not to automatically reset the environment 57 | on done. if true, next observation will be given by self.reset() 58 | 59 | Returns: 60 | next observation, reward, and done flat 61 | """ 62 | pass 63 | -------------------------------------------------------------------------------- /rl2/envs/bandit_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the Bernoulli bandit environment from Duan et al., 2016 3 | - 'RL^2 : Fast Reinforcement Learning via Slow Reinforcement Learning'. 4 | """ 5 | 6 | from typing import Tuple 7 | 8 | import numpy as np 9 | 10 | from rl2.envs.abstract import MetaEpisodicEnv 11 | 12 | 13 | class BanditEnv(MetaEpisodicEnv): 14 | """ 15 | An environment where each step is a pull from a multi-armed bandit. 16 | The bandit is a stationary bernoulli bandit. 17 | The environment can be reset so that a new MAB replaces the old one, 18 | thus permitting memory-augmented agents to meta-learn exploration 19 | strategies that generalize across environments in the distribution. 20 | """ 21 | def __init__(self, num_actions): 22 | self._num_actions = num_actions 23 | self._state = None 24 | self._payout_probabilities = None 25 | self.new_env() 26 | 27 | @property 28 | def max_episode_len(self): 29 | return 1 30 | 31 | @property 32 | def num_actions(self): 33 | """Get num_actions.""" 34 | return self._num_actions 35 | 36 | def _new_payout_probabilities(self): 37 | """ 38 | Samples a p_i ~ Uniform[0,1] to determine new payout probs for each arm. 
39 | Returns: 40 | None 41 | """ 42 | self._payout_probabilities = np.random.uniform( 43 | low=0.0, high=1.0, size=self._num_actions) 44 | 45 | def new_env(self) -> None: 46 | """ 47 | Sample a new multi-armed bandit problem from distribution over problems. 48 | 49 | Returns: 50 | None 51 | """ 52 | self._new_payout_probabilities() 53 | 54 | def reset(self) -> int: 55 | """ 56 | Reset the environment. For MAB problems, the env is stateless, 57 | and this has no effect and is only included for compatibility 58 | with MetaEpisodicEnv abstract class. 59 | 60 | Returns: 61 | initial state. 62 | """ 63 | self._state = 0 64 | return self._state 65 | 66 | def step(self, action, auto_reset=True) -> Tuple[int, float, bool, dict]: 67 | """ 68 | Pull one arm of the multi-armed bandit, and observe one outcome. 69 | Args: 70 | action: action corresponding to an arm index. 71 | auto_reset: auto reset. if true, new_state will be from self.reset() 72 | 73 | Returns: 74 | new_state, reward, done, info. 75 | """ 76 | 77 | # bernoulli bandit 78 | reward = np.random.binomial( 79 | n=1, p=self._payout_probabilities[action], size=1)[0] 80 | 81 | new_state = self._state 82 | done = True 83 | if done and auto_reset: 84 | new_state = self.reset() 85 | info = {} 86 | return new_state, reward, done, info 87 | -------------------------------------------------------------------------------- /rl2/envs/mdp_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the Tabular MDP environment(s) from Duan et al., 2016 3 | - 'RL^2 : Fast Reinforcement Learning via Slow Reinforcement Learning'. 4 | """ 5 | 6 | from typing import Tuple 7 | 8 | import numpy as np 9 | 10 | from rl2.envs.abstract import MetaEpisodicEnv 11 | 12 | 13 | class MDPEnv(MetaEpisodicEnv): 14 | """ 15 | Tabular MDP env with support for resettable MDP params (new meta-episode), 16 | in addition to the usual reset (new episode). 17 | """ 18 | def __init__(self, num_states, num_actions, max_episode_length=10): 19 | # structural 20 | self._num_states = num_states 21 | self._num_actions = num_actions 22 | self._max_ep_length = max_episode_length 23 | 24 | # per-environment-sample quantities. 25 | self._reward_means = None 26 | self._state_transition_probabilities = None 27 | self.new_env() 28 | 29 | # mdp state. 30 | self._ep_steps_so_far = 0 31 | self._state = 0 32 | 33 | @property 34 | def max_episode_len(self): 35 | return self._max_ep_length 36 | 37 | @property 38 | def num_actions(self): 39 | """Get self._num_actions.""" 40 | return self._num_actions 41 | 42 | @property 43 | def num_states(self): 44 | """Get self._num_states.""" 45 | return self._num_states 46 | 47 | def _new_reward_means(self): 48 | self._reward_means = np.random.normal( 49 | loc=1.0, scale=1.0, size=(self._num_states, self._num_actions)) 50 | 51 | def _new_state_transition_dynamics(self): 52 | p_aijs = [] 53 | for a in range(self._num_actions): 54 | dirichlet_samples_ij = np.random.dirichlet( 55 | alpha=np.ones(dtype=np.float32, shape=(self._num_states,)), 56 | size=(self._num_states,)) 57 | p_aijs.append(dirichlet_samples_ij) 58 | self._state_transition_probabilities = np.stack(p_aijs, axis=0) 59 | 60 | def new_env(self) -> None: 61 | """ 62 | Sample a new MDP from the distribution over MDPs. 63 | 64 | Returns: 65 | None 66 | """ 67 | self._new_reward_means() 68 | self._new_state_transition_dynamics() 69 | self._state = 0 70 | 71 | def reset(self) -> int: 72 | """ 73 | Reset the environment. 74 | 75 | Returns: 76 | initial state. 
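
        Example (illustrative):

            env = MDPEnv(num_states=10, num_actions=5)
            env.reset()   # -> 0 (the initial state is always state index 0)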
77 | """ 78 | self._ep_steps_so_far = 0 79 | self._state = 0 80 | return self._state 81 | 82 | def step(self, action, auto_reset=True) -> Tuple[int, float, bool, dict]: 83 | """ 84 | Take action in the MDP, and observe next state, reward, done, etc. 85 | 86 | Args: 87 | action: action corresponding to an arm index. 88 | auto_reset: auto reset. if true, new_state will be from self.reset() 89 | 90 | Returns: 91 | new_state, reward, done, info. 92 | """ 93 | self._ep_steps_so_far += 1 94 | t = self._ep_steps_so_far 95 | 96 | s_t = self._state 97 | a_t = action 98 | 99 | s_tp1 = np.random.choice( 100 | a=self._num_states, 101 | p=self._state_transition_probabilities[a_t, s_t]) 102 | self._state = s_tp1 103 | 104 | r_t = np.random.normal( 105 | loc=self._reward_means[s_t, a_t], 106 | scale=1.0) 107 | 108 | done_t = False if t < self._max_ep_length else True 109 | if done_t and auto_reset: 110 | s_tp1 = self.reset() 111 | 112 | return s_tp1, r_t, done_t, {} 113 | -------------------------------------------------------------------------------- /rl2/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucaslingle/pytorch_rl2/51afa4b6c6d05b469c7f74076b9490ae0809b289/rl2/utils/__init__.py -------------------------------------------------------------------------------- /rl2/utils/checkpoint_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility module for saving and loading checkpoints. 3 | """ 4 | 5 | import os 6 | 7 | import torch as tc 8 | 9 | 10 | def _format_name(kind, steps): 11 | filename = f"{kind}_{steps}.pth" 12 | return filename 13 | 14 | 15 | def _parse_name(filename): 16 | kind, steps = filename.split(".")[0].split("_") 17 | steps = int(steps) 18 | return { 19 | "kind": kind, 20 | "steps": steps 21 | } 22 | 23 | 24 | def _latest_n_checkpoint_steps(base_path, n=5): 25 | steps = set(map(lambda x: _parse_name(x)['steps'], os.listdir(base_path))) 26 | latest_steps = sorted(steps) 27 | latest_n = latest_steps[-n:] 28 | return latest_n 29 | 30 | 31 | def _latest_step(base_path): 32 | return _latest_n_checkpoint_steps(base_path, n=1)[-1] 33 | 34 | 35 | def save_checkpoint( 36 | steps, 37 | checkpoint_dir, 38 | model_name, 39 | model, 40 | optimizer, 41 | scheduler 42 | ): 43 | """ 44 | Saves a checkpoint of the latest model, optimizer, scheduler state. 45 | Also tidies up checkpoint_dir/model_name/ by keeping only last 5 ckpts. 46 | 47 | Args: 48 | steps: num steps for the checkpoint to save. 49 | checkpoint_dir: checkpoint dir for checkpointing. 50 | model_name: model name for checkpointing. 51 | model: model to be updated from checkpoint. 52 | optimizer: optimizer to be updated from checkpoint. 53 | scheduler: scheduler to be updated from checkpoint. 
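
    With the default command-line settings in train.py, files end up at paths
    like 'checkpoints/defaults/policy_net/model_100.pth', i.e.
    '{checkpoint_dir}/{model_name}/{kind}_{steps}.pth'.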
54 | 55 | Returns: 56 | None 57 | """ 58 | base_path = os.path.join(checkpoint_dir, model_name) 59 | os.makedirs(base_path, exist_ok=True) 60 | 61 | model_path = os.path.join(base_path, _format_name('model', steps)) 62 | optim_path = os.path.join(base_path, _format_name('optimizer', steps)) 63 | sched_path = os.path.join(base_path, _format_name('scheduler', steps)) 64 | 65 | # save everything 66 | tc.save(model.state_dict(), model_path) 67 | tc.save(optimizer.state_dict(), optim_path) 68 | if scheduler is not None: 69 | tc.save(scheduler.state_dict(), sched_path) 70 | 71 | # keep only last n checkpoints 72 | latest_n_steps = _latest_n_checkpoint_steps(base_path, n=5) 73 | for file in os.listdir(base_path): 74 | if _parse_name(file)['steps'] not in latest_n_steps: 75 | os.remove(os.path.join(base_path, file)) 76 | 77 | 78 | def maybe_load_checkpoint( 79 | checkpoint_dir, 80 | model_name, 81 | model, 82 | optimizer, 83 | scheduler, 84 | steps 85 | ): 86 | """ 87 | Tries to load a checkpoint from checkpoint_dir/model_name/. 88 | If there isn't one, it fails gracefully, allowing the script to proceed 89 | from a newly initialized model. 90 | 91 | Args: 92 | checkpoint_dir: checkpoint dir for checkpointing. 93 | model_name: model name for checkpointing. 94 | model: model to be updated from checkpoint. 95 | optimizer: optimizer to be updated from checkpoint. 96 | scheduler: scheduler to be updated from checkpoint. 97 | steps: num steps for the checkpoint to locate. if none, use latest. 98 | 99 | Returns: 100 | number of env steps experienced by loaded checkpoint. 101 | """ 102 | base_path = os.path.join(checkpoint_dir, model_name) 103 | try: 104 | if steps is None: 105 | steps = _latest_step(base_path) 106 | 107 | model_path = os.path.join(base_path, _format_name('model', steps)) 108 | optim_path = os.path.join(base_path, _format_name('optimizer', steps)) 109 | sched_path = os.path.join(base_path, _format_name('scheduler', steps)) 110 | 111 | model.load_state_dict(tc.load(model_path)) 112 | optimizer.load_state_dict(tc.load(optim_path)) 113 | if scheduler is not None: 114 | scheduler.load_state_dict(tc.load(sched_path)) 115 | 116 | print(f"Loaded checkpoint from {base_path}, with step {steps}.") 117 | print("Continuing from checkpoint.") 118 | except FileNotFoundError: 119 | print(f"Bad checkpoint or none at {base_path} with step {steps}.") 120 | print("Running from scratch.") 121 | steps = 0 122 | 123 | return steps 124 | -------------------------------------------------------------------------------- /rl2/utils/comm_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility module for synchronizing state across processes. 3 | """ 4 | 5 | from mpi4py import MPI 6 | import torch as tc 7 | import numpy as np 8 | 9 | 10 | def get_comm(): 11 | comm = MPI.COMM_WORLD 12 | return comm 13 | 14 | 15 | @tc.no_grad() 16 | def sync_state(model, optimizer, scheduler, comm, root): 17 | """ 18 | Synchronize state of a model, its optimizer, and possibly its scheduler, 19 | using MPI. 20 | 21 | Args: 22 | model: model. 23 | optimizer: optimizer for the model. 24 | scheduler: optional lr scheduler for the model. 25 | comm: mpi4py comm_world object. 26 | root: root mpi process rank to broadcast from. 
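
    In train.py this is called once at startup, after any checkpoint load on
    the root rank, so that every process begins from identical parameters:

        sync_state(model=policy_net, optimizer=policy_optimizer,
                   scheduler=policy_scheduler, comm=comm, root=ROOT_RANK)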
27 | 28 | Returns: 29 | None 30 | """ 31 | model_state_dict = comm.bcast(model.state_dict(), root=root) 32 | optimizer_state_dict = comm.bcast(optimizer.state_dict(), root=root) 33 | if scheduler is not None: 34 | scheduler_state_dict = comm.bcast(scheduler.state_dict(), root=root) 35 | 36 | model.load_state_dict(model_state_dict) 37 | optimizer.load_state_dict(optimizer_state_dict) 38 | if scheduler is not None: 39 | scheduler.load_state_dict(scheduler_state_dict) 40 | 41 | 42 | @tc.no_grad() 43 | def sync_grads(model, comm): 44 | """ 45 | Sync gradients for a model across processes using MPI. 46 | The resulting synchronized gradient is stored in the p.grad field within 47 | each model parameter p. The stored value is the average. 48 | 49 | This allows us to do data-parallel training using multiple processes, 50 | as the p.grad fields are passed to the optimizer on each process, 51 | keeping the optimizer states synchronized as well. 52 | 53 | Args: 54 | model: model. 55 | comm: mpi4py comm_world object. 56 | 57 | Returns: 58 | None 59 | """ 60 | for p in model.parameters(): 61 | p_grad_local = p.grad.numpy() 62 | p_grad_global = np.zeros_like(p_grad_local) 63 | comm.Allreduce(sendbuf=p_grad_local, recvbuf=p_grad_global, op=MPI.SUM) 64 | p.grad.copy_(tc.FloatTensor(p_grad_global) / comm.Get_size()) 65 | -------------------------------------------------------------------------------- /rl2/utils/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global constants for the rl2 python package. 3 | """ 4 | 5 | ROOT_RANK = 0 -------------------------------------------------------------------------------- /rl2/utils/optim_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility module for optimization of RL^2 agents. 3 | """ 4 | 5 | import torch as tc 6 | 7 | 8 | def get_weight_decay_param_groups(model, weight_decay): 9 | decay, no_decay = [], [] 10 | for name, param in model.named_parameters(): 11 | if not param.requires_grad: 12 | continue 13 | elif len(param.shape) == 1: 14 | no_decay.append(param) 15 | else: 16 | decay.append(param) 17 | return [ 18 | {'params': decay, 'weight_decay': weight_decay}, 19 | {'params': no_decay, 'weight_decay': 0.0} 20 | ] 21 | -------------------------------------------------------------------------------- /rl2/utils/stat_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility module containing basic statistical operations. 3 | """ 4 | 5 | import numpy as np 6 | 7 | 8 | def standardize(arr: np.ndarray) -> np.ndarray: 9 | """ 10 | Computes the empirical z-scores of an array. 11 | 12 | Args: 13 | arr: a numpy array 14 | 15 | Returns: 16 | a numpy array of z-scores. 17 | """ 18 | eps = 1e-8 19 | mu = arr.mean() 20 | sigma = arr.std() 21 | standardized = (arr - mu) / (eps + sigma) 22 | return standardized 23 | 24 | 25 | def explained_variance(ypred: np.ndarray, y: np.ndarray) -> np.float32: 26 | """ 27 | Computes the explained variance. 28 | See https://en.wikipedia.org/wiki/Explained_variation 29 | 30 | Args: 31 | ypred: predicted values. 32 | y: actual values. 33 | 34 | Returns: 35 | The explained variance, a number between 0 and 1. 
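
    Note: the value can be negative when the predictions are worse than
    simply predicting the mean of y.

    Example (illustrative):

        explained_variance(np.array([1., 2., 3.]), np.array([1., 2., 3.]))
        # -> 1.0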
36 | """ 37 | vary = y.var() 38 | return 1 - (y-ypred).var()/(1e-8 + vary) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | setup( 5 | name="pytorch_rl2_mdp_lstm", 6 | py_modules=["rl2"], 7 | version="2.0.0", 8 | description="A Pytorch implementation of RL^2.", 9 | author="Lucas D. Lingle", 10 | install_requires=[ 11 | 'mpi4py==3.0.3', 12 | 'torch==1.8.1' 13 | ] 14 | ) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for training stateful meta-reinforcement learning agents 3 | """ 4 | 5 | import argparse 6 | from functools import partial 7 | 8 | import torch as tc 9 | 10 | from rl2.envs.bandit_env import BanditEnv 11 | from rl2.envs.mdp_env import MDPEnv 12 | 13 | from rl2.agents.preprocessing.tabular import MABPreprocessing, MDPPreprocessing 14 | from rl2.agents.architectures.gru import GRU 15 | from rl2.agents.architectures.lstm import LSTM 16 | from rl2.agents.architectures.snail import SNAIL 17 | from rl2.agents.architectures.transformer import Transformer 18 | from rl2.agents.heads.policy_heads import LinearPolicyHead 19 | from rl2.agents.heads.value_heads import LinearValueHead 20 | from rl2.agents.integration.policy_net import StatefulPolicyNet 21 | from rl2.agents.integration.value_net import StatefulValueNet 22 | from rl2.algos.ppo import training_loop 23 | 24 | from rl2.utils.checkpoint_util import maybe_load_checkpoint, save_checkpoint 25 | from rl2.utils.comm_util import get_comm, sync_state 26 | from rl2.utils.constants import ROOT_RANK 27 | from rl2.utils.optim_util import get_weight_decay_param_groups 28 | 29 | 30 | def create_argparser(): 31 | parser = argparse.ArgumentParser( 32 | description="""Training script for RL^2.""") 33 | 34 | ### Environment 35 | parser.add_argument("--environment", choices=['bandit', 'tabular_mdp'], 36 | default='bandit') 37 | parser.add_argument("--num_states", type=int, default=10, 38 | help="Ignored if environment is bandit.") 39 | parser.add_argument("--num_actions", type=int, default=5) 40 | parser.add_argument("--max_episode_len", type=int, default=10, 41 | help="Timesteps before automatic episode reset. 
" + 42 | "Ignored if environment is bandit.") 43 | parser.add_argument("--meta_episode_len", type=int, default=100, 44 | help="Timesteps per meta-episode.") 45 | 46 | ### Architecture 47 | parser.add_argument( 48 | "--architecture", choices=['gru', 'lstm', 'snail', 'transformer'], 49 | default='gru') 50 | parser.add_argument("--num_features", type=int, default=256) 51 | 52 | ### Checkpointing 53 | parser.add_argument("--model_name", type=str, default='defaults') 54 | parser.add_argument("--checkpoint_dir", type=str, default='checkpoints') 55 | 56 | ### Training 57 | parser.add_argument("--max_pol_iters", type=int, default=12000) 58 | parser.add_argument("--meta_episodes_per_policy_update", type=int, default=-1, 59 | help="If -1, quantity is determined using a formula") 60 | parser.add_argument("--meta_episodes_per_learner_batch", type=int, default=60) 61 | parser.add_argument("--ppo_opt_epochs", type=int, default=8) 62 | parser.add_argument("--ppo_clip_param", type=float, default=0.10) 63 | parser.add_argument("--ppo_ent_coef", type=float, default=0.01) 64 | parser.add_argument("--discount_gamma", type=float, default=0.99) 65 | parser.add_argument("--gae_lambda", type=float, default=0.3) 66 | parser.add_argument("--standardize_advs", type=int, choices=[0,1], default=0) 67 | parser.add_argument("--adam_lr", type=float, default=2e-4) 68 | parser.add_argument("--adam_eps", type=float, default=1e-5) 69 | parser.add_argument("--adam_wd", type=float, default=0.01) 70 | return parser 71 | 72 | 73 | def create_env(environment, num_states, num_actions, max_episode_len): 74 | if environment == 'bandit': 75 | return BanditEnv( 76 | num_actions=num_actions) 77 | if environment == 'tabular_mdp': 78 | return MDPEnv( 79 | num_states=num_states, 80 | num_actions=num_actions, 81 | max_episode_length=max_episode_len) 82 | raise NotImplementedError 83 | 84 | 85 | def create_preprocessing(environment, num_states, num_actions): 86 | if environment == 'bandit': 87 | return MABPreprocessing( 88 | num_actions=num_actions) 89 | if environment == 'tabular_mdp': 90 | return MDPPreprocessing( 91 | num_states=num_states, 92 | num_actions=num_actions) 93 | raise NotImplementedError 94 | 95 | 96 | def create_architecture(architecture, input_dim, num_features, context_size): 97 | if architecture == 'gru': 98 | return GRU( 99 | input_dim=input_dim, 100 | hidden_dim=num_features, 101 | forget_bias=1.0, 102 | use_ln=True, 103 | reset_after=True) 104 | if architecture == 'lstm': 105 | return LSTM( 106 | input_dim=input_dim, 107 | hidden_dim=num_features, 108 | forget_bias=1.0, 109 | use_ln=True) 110 | if architecture == 'snail': 111 | return SNAIL( 112 | input_dim=input_dim, 113 | feature_dim=num_features, 114 | context_size=context_size, 115 | use_ln=True) 116 | if architecture == 'transformer': 117 | return Transformer( 118 | input_dim=input_dim, 119 | feature_dim=num_features, 120 | n_layer=9, 121 | n_head=2, 122 | n_context=context_size) 123 | raise NotImplementedError 124 | 125 | 126 | def create_head(head_type, num_features, num_actions): 127 | if head_type == 'policy': 128 | return LinearPolicyHead( 129 | num_features=num_features, 130 | num_actions=num_actions) 131 | if head_type == 'value': 132 | return LinearValueHead( 133 | num_features=num_features) 134 | raise NotImplementedError 135 | 136 | 137 | def create_net( 138 | net_type, environment, architecture, num_states, num_actions, 139 | num_features, context_size 140 | ): 141 | preprocessing = create_preprocessing( 142 | environment=environment, 143 | 
num_states=num_states, 144 | num_actions=num_actions) 145 | architecture = create_architecture( 146 | architecture=architecture, 147 | input_dim=preprocessing.output_dim, 148 | num_features=num_features, 149 | context_size=context_size) 150 | head = create_head( 151 | head_type=net_type, 152 | num_features=architecture.output_dim, 153 | num_actions=num_actions) 154 | 155 | if net_type == 'policy': 156 | return StatefulPolicyNet( 157 | preprocessing=preprocessing, 158 | architecture=architecture, 159 | policy_head=head) 160 | if net_type == 'value': 161 | return StatefulValueNet( 162 | preprocessing=preprocessing, 163 | architecture=architecture, 164 | value_head=head) 165 | raise NotImplementedError 166 | 167 | 168 | def main(): 169 | args = create_argparser().parse_args() 170 | comm = get_comm() 171 | 172 | # create env. 173 | env = create_env( 174 | environment=args.environment, 175 | num_states=args.num_states, 176 | num_actions=args.num_actions, 177 | max_episode_len=args.max_episode_len) 178 | 179 | # create learning system. 180 | policy_net = create_net( 181 | net_type='policy', 182 | environment=args.environment, 183 | architecture=args.architecture, 184 | num_states=args.num_states, 185 | num_actions=args.num_actions, 186 | num_features=args.num_features, 187 | context_size=args.meta_episode_len) 188 | 189 | value_net = create_net( 190 | net_type='value', 191 | environment=args.environment, 192 | architecture=args.architecture, 193 | num_states=args.num_states, 194 | num_actions=args.num_actions, 195 | num_features=args.num_features, 196 | context_size=args.meta_episode_len) 197 | 198 | policy_optimizer = tc.optim.AdamW( 199 | get_weight_decay_param_groups(policy_net, args.adam_wd), 200 | lr=args.adam_lr, 201 | eps=args.adam_eps) 202 | value_optimizer = tc.optim.AdamW( 203 | get_weight_decay_param_groups(value_net, args.adam_wd), 204 | lr=args.adam_lr, 205 | eps=args.adam_eps) 206 | 207 | policy_scheduler = None 208 | value_scheduler = None 209 | 210 | # load checkpoint, if applicable. 211 | pol_iters_so_far = 0 212 | if comm.Get_rank() == ROOT_RANK: 213 | a = maybe_load_checkpoint( 214 | checkpoint_dir=args.checkpoint_dir, 215 | model_name=f"{args.model_name}/policy_net", 216 | model=policy_net, 217 | optimizer=policy_optimizer, 218 | scheduler=policy_scheduler, 219 | steps=None) 220 | 221 | b = maybe_load_checkpoint( 222 | checkpoint_dir=args.checkpoint_dir, 223 | model_name=f"{args.model_name}/value_net", 224 | model=value_net, 225 | optimizer=value_optimizer, 226 | scheduler=value_scheduler, 227 | steps=None) 228 | 229 | if a != b: 230 | raise RuntimeError( 231 | "Policy and value iterates not aligned in latest checkpoint!") 232 | pol_iters_so_far = a 233 | 234 | # sync state. 235 | pol_iters_so_far = comm.bcast(pol_iters_so_far, root=ROOT_RANK) 236 | sync_state( 237 | model=policy_net, 238 | optimizer=policy_optimizer, 239 | scheduler=policy_scheduler, 240 | comm=comm, 241 | root=ROOT_RANK) 242 | sync_state( 243 | model=value_net, 244 | optimizer=value_optimizer, 245 | scheduler=value_scheduler, 246 | comm=comm, 247 | root=ROOT_RANK) 248 | 249 | # make callback functions for checkpointing. 
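    # partial(...) binds everything except the `steps` argument, so the
    # training loop can checkpoint with a single call, e.g.
    # policy_checkpoint_fn(pol_iter + 1).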
250 | policy_checkpoint_fn = partial( 251 | save_checkpoint, 252 | checkpoint_dir=args.checkpoint_dir, 253 | model_name=f"{args.model_name}/policy_net", 254 | model=policy_net, 255 | optimizer=policy_optimizer, 256 | scheduler=policy_scheduler) 257 | 258 | value_checkpoint_fn = partial( 259 | save_checkpoint, 260 | checkpoint_dir=args.checkpoint_dir, 261 | model_name=f"{args.model_name}/value_net", 262 | model=value_net, 263 | optimizer=value_optimizer, 264 | scheduler=value_scheduler) 265 | 266 | # run it! 267 | if args.meta_episodes_per_policy_update == -1: 268 | numer = 240000 269 | denom = comm.Get_size() * args.meta_episode_len 270 | meta_episodes_per_policy_update = numer // denom 271 | else: 272 | meta_episodes_per_policy_update = args.meta_episodes_per_policy_update 273 | 274 | training_loop( 275 | env=env, 276 | policy_net=policy_net, 277 | value_net=value_net, 278 | policy_optimizer=policy_optimizer, 279 | value_optimizer=value_optimizer, 280 | policy_scheduler=policy_scheduler, 281 | value_scheduler=value_scheduler, 282 | meta_episodes_per_policy_update=meta_episodes_per_policy_update, 283 | meta_episodes_per_learner_batch=args.meta_episodes_per_learner_batch, 284 | meta_episode_len=args.meta_episode_len, 285 | ppo_opt_epochs=args.ppo_opt_epochs, 286 | ppo_clip_param=args.ppo_clip_param, 287 | ppo_ent_coef=args.ppo_ent_coef, 288 | discount_gamma=args.discount_gamma, 289 | gae_lambda=args.gae_lambda, 290 | standardize_advs=bool(args.standardize_advs), 291 | max_pol_iters=args.max_pol_iters, 292 | pol_iters_so_far=pol_iters_so_far, 293 | policy_checkpoint_fn=policy_checkpoint_fn, 294 | value_checkpoint_fn=value_checkpoint_fn, 295 | comm=comm) 296 | 297 | 298 | if __name__ == '__main__': 299 | main() 300 | --------------------------------------------------------------------------------