├── README.md ├── diagnostic ├── LICENSE ├── README.md ├── copy_task_data.py ├── copy_task_main.py ├── data.py ├── eval_utils.py ├── layers.py ├── model.py ├── rtrl_copy_task_main.py ├── rtrl_layers.py └── utils │ └── copy_task_generator.py └── reinforcement_learning ├── LICENSE ├── README.md ├── atari_data.py ├── list_atari_games.txt ├── nest ├── README.md ├── nest │ ├── nest.h │ ├── nest_pybind.cc │ └── nest_pybind.h ├── nest_test.py └── setup.py ├── scripts └── install_grpc.sh ├── setup.py ├── tests ├── batching_queue_test.py ├── contiguous_arrays_env.py ├── contiguous_arrays_test.py ├── core_agent_state_env.py ├── core_agent_state_test.py ├── dynamic_batcher_test.py ├── inference_speed_profiling.py ├── lint_changed.sh ├── polybeast_inference_test.py ├── polybeast_learn_function_test.py ├── polybeast_loss_functions_test.py ├── polybeast_net_test.py └── vtrace_test.py ├── torchbeast ├── atari_wrappers.py ├── core │ ├── environment.py │ ├── file_writer.py │ ├── prof.py │ └── vtrace.py ├── layer.py ├── model.py └── polybeast.py ├── torchbeast_atari ├── atari_wrappers.py ├── model.py ├── polybeast.py ├── polybeast_env.py └── polybeast_learner.py ├── torchbeast_dmlab ├── atari_wrappers.py ├── core │ ├── .history.kazuki │ ├── _environment.py │ ├── environment.py │ ├── file_writer.py │ ├── prof.py │ └── vtrace.py ├── dmlab30.py ├── dmlab_wrappers.py ├── model.py ├── polybeast.py ├── polybeast_env.py └── polybeast_learner.py └── torchbeast_procgen ├── model.py ├── polybeast.py ├── polybeast_env.py ├── polybeast_learner.py └── procgen_wrappers.py /README.md: -------------------------------------------------------------------------------- 1 | # Real-Time Recurrent Learning with eLSTM 2 | 3 | This is the official repository containing code for the paper: 4 | 5 | [Exploring the Promise and Limits of Real-Time Recurrent Learning (ICLR 2024)](https://arxiv.org/abs/2305.19044) 6 | 7 | ## Contents 8 | * `diagnostic` directory contains code for the copy task (Sec 4.1) 9 | * `reinforcement_learning` directory contains code for the RL experiments (Sec. 4.2 and 4.3) 10 | 11 | Please refer to the readme file in each directory for further instructions. 12 | Separate license files can be found in each directory. 13 | 14 | ## BibTex 15 | ``` 16 | @inproceedings{irie2023exploring, 17 | title={Exploring the Promise and Limits of Real-Time Recurrent Learning}, 18 | author={Irie, Kazuki and Gopalakrishnan, Anand and Schmidhuber, J{\"u}rgen}, 19 | booktitle={International Conference on Learning Representations (ICLR)}, 20 | address={Vienna, Austria}, 21 | month=may, 22 | year=2024 23 | } 24 | ``` 25 | -------------------------------------------------------------------------------- /diagnostic/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021, 2023 Kazuki Irie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /diagnostic/README.md: -------------------------------------------------------------------------------- 1 | # Diagnostic Task/Copy Task 2 | 3 | This repository was originally forked from [IDSIA/recurrent-fwp/algorithmic](https://github.com/IDSIA/recurrent-fwp/tree/master/algorithmic). 4 | 5 | ## Requirements 6 | * PyTorch (`>= 1.6.0` recommended) 7 | 8 | ## Data Generation 9 | 10 | ``` 11 | cd utils 12 | mkdir data_copy_task 13 | # `code_length` is the number bits in the pattern to be memorised 14 | # set it to 50 or 500 to get the dataset used in the paper 15 | python copy_task_generator.py --dump_dir data_copy_task --code_length 500 16 | ``` 17 | 18 | ## Training 19 | The training script is as follows: 20 | 21 | `model_type` specifies the learning algo/model type: 22 | * `0`: BPTT LSTM 23 | * `10`: BPTT eLSTM 24 | * `11`: RTRL eLSTM; use this with `rtrl_copy_task_main.py` 25 | 26 | **NB: in the code, "eLSTM" (the name of our RNN architecture in the paper) is called "QuasiLSTM".** 27 | 28 | `--level` specifies the max sequence length of the patterns to be memorised (either 50 or 500 in the paper) 29 | 30 | ``` 31 | DATA_DIR='utils/data_copy_task' 32 | 33 | python rtrl_copy_task_main.py \ 34 | --data_dir ${DATA_DIR} \ 35 | --level 500 \ 36 | --model_type 11 \ 37 | --no_embedding \ 38 | --num_layer 1 \ 39 | --hidden_size 2048 \ 40 | --dropout 0.0 \ 41 | --batch_size 128 \ 42 | --learning_rate 3e-5 \ 43 | --clip 1.0 \ 44 | --grad_cummulate 1 \ 45 | --num_epoch 500 \ 46 | --seed 1 \ 47 | ``` 48 | 49 | `rtrl_copy_task_main.py` should be replaced by `copy_task_main.py` for all non-RTRL settings (there are many code duplications; we leave them as is). 50 | 51 | Note that unlike prior work, we do not use any curriculum learning. 52 | 53 | ## Evaluation 54 | Evalution is automatically run at the end of training using the best performing checkpoint based on the validation accuracy (which should be 100% for this task). 55 | 56 | ## Gradient Test 57 | Basic implementation of RTRL forward recursion equations and the corresponding gradient test can be found in `rtrl_layers.py` 58 | -------------------------------------------------------------------------------- /diagnostic/copy_task_data.py: -------------------------------------------------------------------------------- 1 | # Dataset 2 | import os 3 | 4 | import numpy 5 | import random 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | # From https://pytorch.org/docs/stable/notes/randomness.html 12 | def seed_worker(worker_id): 13 | worker_seed = torch.initial_seed() % 2**32 14 | numpy.random.seed(worker_seed) 15 | random.seed(worker_seed) 16 | 17 | 18 | class Vocabulary(object): 19 | def __init__(self, vocab_dict=None, vocab_file=None, 20 | include_unk=False, unk_str='', 21 | include_eos=False, eos_str=''): 22 | # If provided, contruction from dict is prioritized. 23 | self.str2idx = {} 24 | self.idx2str = [] 25 | 26 | if include_eos: 27 | self.add_str(eos_str) 28 | self.eos_str = eos_str 29 | 30 | if include_unk: 31 | self.add_str(unk_str) 32 | self.unk_str = unk_str 33 | 34 | if vocab_dict is not None: 35 | self.contruct_from_dict(vocab_dict) 36 | elif vocab_file is not None: 37 | self.contruct_from_file(vocab_file) 38 | 39 | def contruct_from_file(self, vocab_file): 40 | # Expect each line to contain "token_str idx", space separated. 41 | print(f"Creating vocab from: {vocab_file}") 42 | tmp_idx2str_dict = {} 43 | with open(vocab_file, 'r') as text: 44 | for line in text: 45 | vocab_pair = line.split() 46 | assert vocab_pair == 2, "Unexpected vocab format." 47 | token_str, token_idx = vocab_pair 48 | # TODO 49 | assert False 50 | 51 | def contruct_from_dict(self, vocab_dict): 52 | self.str2idx = vocab_dict 53 | vocab_size = len(vocab_dict.keys()) 54 | # TODO 55 | assert False 56 | 57 | def get_idx(self, stg): 58 | return self.str2idx[stg] 59 | 60 | def get_str(self, idx): 61 | return self.idx2str(idx) 62 | 63 | # Increment the vocab size, give the new index to the new token. 64 | def add_str(self, stg): 65 | if stg not in self.str2idx.keys(): 66 | self.idx2str.append(stg) 67 | self.str2idx[stg] = len(self.idx2str) - 1 68 | 69 | # Return vocab size. 70 | def size(self): 71 | return len(self.idx2str) 72 | 73 | def get_unk_str(self): 74 | return self.unk_str 75 | 76 | 77 | class CopyTaskDataset(Dataset): 78 | 79 | def __init__(self, src_file, tgt_file, src_pad_idx, tgt_pad_idx, 80 | src_vocab=None, tgt_vocab=None, device='cuda'): 81 | 82 | self.src_max_seq_length = None # set by text_to_data 83 | self.tgt_max_seq_length = None 84 | 85 | build_src_vocab = False 86 | if src_vocab is None: 87 | build_src_vocab = True 88 | self.src_vocab = Vocabulary() 89 | else: 90 | self.src_vocab = src_vocab 91 | 92 | build_tgt_vocab = False 93 | if tgt_vocab is None: 94 | build_tgt_vocab = True 95 | self.tgt_vocab = Vocabulary() 96 | else: 97 | self.tgt_vocab = tgt_vocab 98 | 99 | self.data = self.text_to_data( 100 | src_file, tgt_file, src_pad_idx, tgt_pad_idx, 101 | build_src_vocab, build_tgt_vocab, device) 102 | 103 | self.data_size = len(self.data) 104 | 105 | def __len__(self): # To be used by PyTorch Dataloader. 106 | return self.data_size 107 | 108 | def __getitem__(self, index): # To be used by PyTorch Dataloader. 109 | return self.data[index] 110 | 111 | def text_to_data(self, src_file, tgt_file, src_pad_idx, tgt_pad_idx, 112 | build_src_vocab=None, build_tgt_vocab=None, 113 | device='cuda'): 114 | # Convert paired src/tgt texts into torch.tensor data. 115 | # All sequences are padded to the length of the longest sequence 116 | # of the respective file. 117 | 118 | assert os.path.exists(src_file) 119 | assert os.path.exists(tgt_file) 120 | 121 | data_list = [] 122 | # Check the max length, if needed construct vocab file. 123 | src_max = 0 124 | with open(src_file, 'r') as text: 125 | for line in text: 126 | tokens = line.split() 127 | length = len(tokens) 128 | if src_max < length: 129 | src_max = length 130 | if build_src_vocab: 131 | for token in tokens: 132 | self.src_vocab.add_str(token) 133 | self.src_max_seq_length = src_max 134 | 135 | tgt_max = 0 136 | with open(tgt_file, 'r') as text: 137 | for line in text: 138 | tokens = line.split() 139 | length = len(tokens) 140 | if tgt_max < length: 141 | tgt_max = length 142 | if build_tgt_vocab: 143 | for token in tokens: 144 | self.tgt_vocab.add_str(token) 145 | self.tgt_max_seq_length = tgt_max 146 | 147 | # Construct data 148 | src_list = [] 149 | print(f"Loading source file from: {src_file}") 150 | with open(src_file, 'r') as text: 151 | for line in text: 152 | seq = [] 153 | tokens = line.split() 154 | for token in tokens: 155 | seq.append(self.src_vocab.get_idx(token)) 156 | var_len = len(seq) 157 | var_seq = torch.tensor(seq, device=device, dtype=torch.int64) 158 | # padding 159 | new_seq = var_seq.data.new(src_max).fill_(src_pad_idx) 160 | new_seq[:var_len] = var_seq 161 | src_list.append(new_seq) 162 | 163 | tgt_list = [] 164 | print(f"Loading target file from: {tgt_file}") 165 | with open(tgt_file, 'r') as text: 166 | for line in text: 167 | seq = [] 168 | tokens = line.split() 169 | for token in tokens: 170 | seq.append(self.tgt_vocab.get_idx(token)) 171 | 172 | var_len = len(seq) 173 | var_seq = torch.tensor(seq, device=device, dtype=torch.int64) 174 | # padding 175 | new_seq = var_seq.data.new(tgt_max).fill_(tgt_pad_idx) 176 | new_seq[:var_len] = var_seq 177 | tgt_list.append(new_seq) 178 | 179 | # src_file and tgt_file are assumed to be aligned. 180 | assert len(src_list) == len(tgt_list) 181 | for i in range(len(src_list)): 182 | data_list.append((src_list[i], tgt_list[i])) 183 | 184 | return data_list 185 | 186 | -------------------------------------------------------------------------------- /diagnostic/data.py: -------------------------------------------------------------------------------- 1 | # Dataset 2 | import os 3 | 4 | import numpy 5 | import random 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | # From https://pytorch.org/docs/stable/notes/randomness.html 12 | def seed_worker(worker_id): 13 | worker_seed = torch.initial_seed() % 2**32 14 | numpy.random.seed(worker_seed) 15 | random.seed(worker_seed) 16 | 17 | 18 | class Vocabulary(object): 19 | def __init__(self, vocab_dict=None, vocab_file=None, 20 | include_unk=False, unk_str='', 21 | include_eos=False, eos_str='', 22 | no_out_str='_', 23 | pad_id=None, pad_str=None): 24 | # If provided, contruction from dict is prioritized. 25 | self.str2idx = {} 26 | self.idx2str = [] 27 | 28 | self.no_out_str = no_out_str 29 | 30 | if include_eos: 31 | self.add_str(eos_str) 32 | self.eos_str = eos_str 33 | 34 | if include_unk: 35 | self.add_str(unk_str) 36 | self.unk_str = unk_str 37 | 38 | if vocab_dict is not None: 39 | self.contruct_from_dict(vocab_dict) 40 | elif vocab_file is not None: 41 | self.contruct_from_file(vocab_file) 42 | 43 | def contruct_from_file(self, vocab_file): 44 | # Expect each line to contain "token_str idx", space separated. 45 | print(f"Creating vocab from: {vocab_file}") 46 | tmp_idx2str_dict = {} 47 | with open(vocab_file, 'r') as text: 48 | for line in text: 49 | vocab_pair = line.split() 50 | assert vocab_pair == 2, "Unexpected vocab format." 51 | token_str, token_idx = vocab_pair 52 | # TODO 53 | assert False 54 | 55 | def contruct_from_dict(self, vocab_dict): 56 | self.str2idx = vocab_dict 57 | vocab_size = len(vocab_dict.keys()) 58 | # TODO 59 | assert False 60 | 61 | def get_idx(self, stg): 62 | return self.str2idx[stg] 63 | 64 | def get_str(self, idx): 65 | return self.idx2str(idx) 66 | 67 | # Increment the vocab size, give the new index to the new token. 68 | def add_str(self, stg): 69 | if stg not in self.str2idx.keys(): 70 | self.idx2str.append(stg) 71 | self.str2idx[stg] = len(self.idx2str) - 1 72 | 73 | # Return vocab size. 74 | def size(self): 75 | return len(self.idx2str) 76 | 77 | def get_no_op_id(self): 78 | return self.str2idx[self.no_out_str] 79 | 80 | def get_unk_str(self): 81 | return self.unk_str 82 | 83 | 84 | class LTEDataset(Dataset): 85 | 86 | def __init__(self, src_file, tgt_file, src_pad_idx, tgt_pad_idx, 87 | src_vocab=None, tgt_vocab=None, device='cuda'): 88 | 89 | self.src_max_seq_length = None # set by text_to_data 90 | self.tgt_max_seq_length = None 91 | 92 | build_src_vocab = False 93 | if src_vocab is None: 94 | build_src_vocab = True 95 | self.src_vocab = Vocabulary() 96 | else: 97 | self.src_vocab = src_vocab 98 | 99 | build_tgt_vocab = False 100 | if tgt_vocab is None: 101 | build_tgt_vocab = True 102 | self.tgt_vocab = Vocabulary() 103 | else: 104 | self.tgt_vocab = tgt_vocab 105 | 106 | self.data = self.text_to_data( 107 | src_file, tgt_file, src_pad_idx, tgt_pad_idx, 108 | build_src_vocab, build_tgt_vocab, device) 109 | 110 | self.data_size = len(self.data) 111 | 112 | def __len__(self): # To be used by PyTorch Dataloader. 113 | return self.data_size 114 | 115 | def __getitem__(self, index): # To be used by PyTorch Dataloader. 116 | return self.data[index] 117 | 118 | def text_to_data(self, src_file, tgt_file, src_pad_idx, tgt_pad_idx, 119 | build_src_vocab=None, build_tgt_vocab=None, 120 | device='cuda'): 121 | # Convert paired src/tgt texts into torch.tensor data. 122 | # All sequences are padded to the length of the longest sequence 123 | # of the respective file. 124 | 125 | assert os.path.exists(src_file) 126 | assert os.path.exists(tgt_file) 127 | 128 | data_list = [] 129 | # Check the max length, if needed construct vocab file. 130 | src_max = 0 131 | with open(src_file, 'r') as text: 132 | for line in text: 133 | tokens = line.split() 134 | length = len(tokens) 135 | if src_max < length: 136 | src_max = length 137 | if build_src_vocab: 138 | for token in tokens: 139 | self.src_vocab.add_str(token) 140 | self.src_max_seq_length = src_max 141 | 142 | tgt_max = 0 143 | with open(tgt_file, 'r') as text: 144 | for line in text: 145 | tokens = line.split() 146 | length = len(tokens) 147 | if tgt_max < length: 148 | tgt_max = length 149 | if build_tgt_vocab: 150 | for token in tokens: 151 | self.tgt_vocab.add_str(token) 152 | self.tgt_max_seq_length = tgt_max 153 | 154 | # Construct data 155 | src_list = [] 156 | print(f"Loading source file from: {src_file}") 157 | with open(src_file, 'r') as text: 158 | for line in text: 159 | seq = [] 160 | tokens = line.split() 161 | for token in tokens: 162 | seq.append(self.src_vocab.get_idx(token)) 163 | var_len = len(seq) 164 | var_seq = torch.tensor(seq, device=device, dtype=torch.int64) 165 | # padding 166 | new_seq = var_seq.data.new(src_max).fill_(src_pad_idx) 167 | new_seq[:var_len] = var_seq 168 | src_list.append(new_seq) 169 | 170 | tgt_list = [] 171 | print(f"Loading target file from: {tgt_file}") 172 | with open(tgt_file, 'r') as text: 173 | for line in text: 174 | seq = [] 175 | tokens = line.split() 176 | for token in tokens: 177 | seq.append(self.tgt_vocab.get_idx(token)) 178 | 179 | var_len = len(seq) 180 | var_seq = torch.tensor(seq, device=device, dtype=torch.int64) 181 | # padding 182 | new_seq = var_seq.data.new(tgt_max).fill_(tgt_pad_idx) 183 | new_seq[:var_len] = var_seq 184 | tgt_list.append(new_seq) 185 | 186 | # src_file and tgt_file are assumed to be aligned. 187 | assert len(src_list) == len(tgt_list) 188 | for i in range(len(src_list)): 189 | data_list.append((src_list[i], tgt_list[i])) 190 | 191 | return data_list 192 | 193 | 194 | if __name__ == '__main__': 195 | from datetime import datetime 196 | import random 197 | import argparse 198 | 199 | from torch.utils.data import DataLoader 200 | 201 | torch.manual_seed(123) 202 | random.seed(123) 203 | 204 | if torch.cuda.is_available(): 205 | torch.cuda.manual_seed_all(123) 206 | parser = argparse.ArgumentParser(description='Learning to execute') 207 | parser.add_argument( 208 | '--data_dir', type=str, 209 | default='./data/', 210 | help='location of the data corpus') 211 | 212 | args = parser.parse_args() 213 | data_path = args.data_dir 214 | 215 | file_src = f"{data_path}/valid_3.src" 216 | file_tgt = f"{data_path}/valid_3.tgt" 217 | 218 | bsz = 3 219 | 220 | dummy_data = LTEDataset(src_file=file_src, tgt_file=file_tgt, 221 | src_pad_idx=0, tgt_pad_idx=-1, 222 | src_vocab=None, tgt_vocab=None) 223 | 224 | data_loader = DataLoader(dataset=dummy_data, batch_size=bsz, shuffle=True) 225 | 226 | stop_ = 2 227 | 228 | for idx, batch in enumerate(data_loader): 229 | src, tgt = batch 230 | if idx < stop_: 231 | print(src[:, 0:20]) 232 | -------------------------------------------------------------------------------- /diagnostic/eval_utils.py: -------------------------------------------------------------------------------- 1 | # Utils 2 | import torch 3 | 4 | 5 | def compute_accuracy(model, data_iterator, loss_fn, no_print_idx, pad_value=-1, 6 | show_example=False, only_nbatch=-1): 7 | """Compute accuracies and loss. 8 | 9 | :param str, split_name: for printing the accuracy with the split name. 10 | :param bool, show_example: if True, print some decoding output examples. 11 | :param int, only_nbatch: Only use given number of batches. If -1, use all 12 | data (default). 13 | returns loss, accucary char-level accuracy, print accuracy 14 | """ 15 | model.eval() 16 | 17 | total_loss = 0.0 18 | corr = 0 19 | corr_char = 0 20 | corr_print = 0 21 | 22 | step = 0 23 | total_num_seqs = 0 24 | total_char = 0 25 | total_print = 0 26 | 27 | for idx, batch in enumerate(data_iterator): 28 | step += 1 29 | src, tgt = batch 30 | logits = model(src) 31 | target = tgt # (B, len) 32 | 33 | # to compute accuracy 34 | output = torch.argmax(logits, dim=-1).squeeze() 35 | 36 | # compute loss 37 | logits = logits.contiguous().view(-1, logits.shape[-1]) 38 | labels = tgt.view(-1) 39 | loss = loss_fn(logits, labels) 40 | total_loss += loss 41 | 42 | # sequence level accuracy 43 | seq_match = (torch.eq(target, output) | (target == pad_value) 44 | ).all(1).sum().item() 45 | corr += seq_match 46 | total_num_seqs += src.size()[0] 47 | 48 | # padded part should not be counted as correct 49 | char_match = torch.logical_and( 50 | torch.logical_and(torch.eq(target, output), target != pad_value), 51 | target == no_print_idx).sum().item() 52 | corr_char += char_match 53 | total_char += torch.logical_and( 54 | target != pad_value, target == no_print_idx).sum().item() 55 | 56 | # Ignore non-print outputs 57 | print_match = torch.logical_and( 58 | torch.logical_and(torch.eq(target, output), target != pad_value), 59 | target != no_print_idx).sum().item() 60 | corr_print += print_match 61 | 62 | total_print += torch.logical_and( 63 | target != pad_value, target != no_print_idx).sum().item() 64 | 65 | if only_nbatch > 0: 66 | if idx > only_nbatch: 67 | break 68 | 69 | res_loss = total_loss.item() / float(step) 70 | acc = corr / float(total_num_seqs) * 100 71 | if total_char > 0: 72 | no_op_acc = corr_char / float(total_char) * 100 73 | else: 74 | no_op_acc = -0 75 | print_acc = corr_print / float(total_print) * 100 76 | 77 | return res_loss, acc, no_op_acc, print_acc 78 | -------------------------------------------------------------------------------- /diagnostic/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | @torch.jit.script 7 | def elu_p1(x): 8 | return F.elu(x, 1., False) + 1. 9 | 10 | 11 | @torch.jit.script 12 | def sum_norm(x): 13 | return x / x.sum(-1, keepdim=True) 14 | 15 | 16 | # Quasi RNN-like https://arxiv.org/abs/1611.01576 17 | # But the output gate is conditioned by c(t) instead of c(t-1) 18 | class QuasiLSTMlayer(nn.Module): 19 | def __init__(self, input_dim, hidden_dim, forget_bias=0.): 20 | super().__init__() 21 | 22 | self.input_dim = input_dim 23 | self.hidden_dim = hidden_dim 24 | 25 | # weight matrices 26 | self.wm_z = nn.Parameter(torch.rand(hidden_dim, input_dim)) 27 | self.wm_f = nn.Parameter(torch.rand(hidden_dim, input_dim)) 28 | 29 | # weight vectors 30 | self.wv_z = nn.Parameter(torch.rand(1, hidden_dim)) # append B dim 31 | self.wv_f = nn.Parameter(torch.rand(1, hidden_dim)) 32 | 33 | # biases 34 | self.bias_z = nn.Parameter(torch.rand(1, hidden_dim)) 35 | self.bias_f = nn.Parameter(torch.rand(1, hidden_dim)) 36 | self.forget_bias = forget_bias 37 | 38 | self.init_weights() 39 | 40 | def init_weights(self): 41 | torch.nn.init.normal_(self.wm_z, mean=0.0, std=0.1) 42 | torch.nn.init.normal_(self.wm_f, mean=0.0, std=0.1) 43 | 44 | torch.nn.init.normal_(self.wv_z, mean=0.0, std=0.1) 45 | torch.nn.init.normal_(self.wv_f, mean=0.0, std=0.1) 46 | 47 | torch.nn.init.normal_(self.bias_z, mean=0.0, std=0.1) 48 | torch.nn.init.normal_(self.bias_f, mean=0.0, std=0.1) 49 | with torch.no_grad(): 50 | self.bias_f.copy_(self.bias_f + self.forget_bias) 51 | 52 | def forward(self, x, state=None): 53 | # x shape: (len, B, n_head * d_head) 54 | # state is a tuple 55 | slen, bsz, x_dim = x.size() 56 | 57 | if state is None: 58 | hidden_prev = torch.zeros( 59 | [bsz, self.hidden_dim], device=x.device) 60 | else: 61 | hidden_prev = state.squeeze(0) # layer dim compat. 62 | 63 | weight_matrix = torch.cat([self.wm_z, self.wm_f], dim=0) 64 | out = x.reshape(slen * bsz, x_dim) 65 | out = F.linear(out, weight_matrix) 66 | out = out.view(slen, bsz, self.hidden_dim * 2) 67 | 68 | out_z, out_f = torch.split(out, (self.hidden_dim,) * 2, -1) 69 | 70 | output_list = [] 71 | 72 | new_cell = hidden_prev 73 | for z_, f_ in zip(out_z, out_f): 74 | z_part = torch.tanh( 75 | z_ + self.wv_z * new_cell + self.bias_z) 76 | f_part = torch.sigmoid( 77 | f_ + self.wv_f * new_cell + self.bias_f) 78 | new_cell = new_cell * f_part + (1. - f_part) * z_part 79 | output_list.append(new_cell.clone()) 80 | 81 | new_cells = torch.stack(output_list) # (len, B, dim) 82 | 83 | return new_cells 84 | -------------------------------------------------------------------------------- /diagnostic/model.py: -------------------------------------------------------------------------------- 1 | # Contains model implementations. 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | from layers import QuasiLSTMlayer 8 | from rtrl_layers import RTRLQuasiLSTMlayer 9 | 10 | 11 | class BaseModel(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | # return number of parameters 16 | def num_params(self): 17 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 18 | 19 | def reset_grad(self): 20 | # More efficient than optimizer.zero_grad() according to: 21 | # Szymon Migacz "PYTORCH PERFORMANCE TUNING GUIDE" at GTC-21. 22 | # - doesn't execute memset for every parameter 23 | # - memory is zeroed-out by the allocator in a more efficient way 24 | # - backward pass updates gradients with "=" operator (write) (unlike 25 | # zero_grad() which would result in "+="). 26 | # In PyT >= 1.7, one can do `model.zero_grad(set_to_none=True)` 27 | for p in self.parameters(): 28 | p.grad = None 29 | 30 | def print_params(self): 31 | for p in self.named_parameters(): 32 | print(p) 33 | 34 | 35 | # Pure PyTorch LSTM model 36 | class LSTMModel(BaseModel): 37 | def __init__(self, emb_dim, hidden_size, in_vocab_size, out_vocab_size, 38 | dropout=0.0, num_layers=1, no_embedding=False): 39 | super().__init__() 40 | self.in_vocab_size = in_vocab_size 41 | self.out_vocab_size = out_vocab_size 42 | self.hidden_size = hidden_size 43 | 44 | self.no_embedding = no_embedding 45 | rnn_input_size = in_vocab_size 46 | 47 | if not no_embedding: 48 | self.embedding = nn.Embedding( 49 | num_embeddings=in_vocab_size, embedding_dim=emb_dim) 50 | rnn_input_size = emb_dim 51 | else: 52 | self.num_classes = in_vocab_size 53 | 54 | self.rnn_func = nn.LSTM( 55 | input_size=rnn_input_size, hidden_size=hidden_size, 56 | num_layers=num_layers) 57 | 58 | self.dropout = dropout 59 | if dropout > 0.: 60 | self.dropout = nn.Dropout(dropout) 61 | self.out_layer = nn.Linear(hidden_size, out_vocab_size) 62 | 63 | def forward(self, x): 64 | if self.no_embedding: 65 | out = torch.nn.functional.one_hot(x, self.num_classes).permute(1, 0, 2).float() 66 | else: 67 | out = self.embedding(x).permute(1, 0, 2) # seq dim first 68 | 69 | # if self.dropout: 70 | # out = self.dropout(out) 71 | out, _ = self.rnn_func(out) 72 | 73 | if self.dropout: 74 | out = self.dropout(out) 75 | logits = self.out_layer(out).permute(1, 0, 2) 76 | 77 | return logits 78 | 79 | 80 | class QuasiLSTMModel(BaseModel): 81 | def __init__(self, emb_dim, hidden_size, in_vocab_size, out_vocab_size, 82 | dropout=0.0, num_layers=1, no_embedding=False): 83 | super().__init__() 84 | self.in_vocab_size = in_vocab_size 85 | self.out_vocab_size = out_vocab_size 86 | self.hidden_size = hidden_size 87 | 88 | self.no_embedding = no_embedding 89 | rnn_input_size = in_vocab_size 90 | if not no_embedding: 91 | self.embedding = nn.Embedding( 92 | num_embeddings=in_vocab_size, embedding_dim=emb_dim) 93 | rnn_input_size = emb_dim 94 | else: 95 | self.num_classes = in_vocab_size 96 | 97 | self.rnn_func = QuasiLSTMlayer( 98 | input_dim=rnn_input_size, hidden_dim=hidden_size) 99 | 100 | self.output_gate = nn.Linear( 101 | rnn_input_size + hidden_size, hidden_size) 102 | 103 | self.dropout = dropout 104 | if dropout > 0.: 105 | self.dropout = nn.Dropout(dropout) 106 | self.out_layer = nn.Linear(hidden_size, out_vocab_size) 107 | 108 | def forward(self, x): 109 | if self.no_embedding: 110 | out = torch.nn.functional.one_hot(x, self.num_classes).permute(1, 0, 2).float() 111 | else: 112 | out = self.embedding(x).permute(1, 0, 2) # seq dim first 113 | 114 | # if self.dropout: 115 | # out = self.dropout(out) 116 | cell_out = self.rnn_func(out) 117 | 118 | gate_out = self.output_gate(torch.cat([out, cell_out], dim=-1)) 119 | gate_out = torch.sigmoid(gate_out) 120 | gate_out = cell_out * gate_out 121 | 122 | if self.dropout: 123 | gate_out = self.dropout(gate_out) 124 | logits = self.out_layer(gate_out).permute(1, 0, 2) 125 | 126 | return logits 127 | 128 | 129 | class RTRLQuasiLSTMModel(BaseModel): 130 | def __init__(self, emb_dim, hidden_size, in_vocab_size, out_vocab_size, 131 | dropout=0.0, num_layers=1, no_embedding=False): 132 | super().__init__() 133 | self.in_vocab_size = in_vocab_size 134 | self.out_vocab_size = out_vocab_size 135 | self.hidden_size = hidden_size 136 | 137 | self.no_embedding = no_embedding 138 | rnn_input_size = in_vocab_size 139 | self.num_classes = in_vocab_size 140 | if not no_embedding: 141 | self.embedding = nn.Embedding( 142 | num_embeddings=in_vocab_size, embedding_dim=emb_dim) 143 | rnn_input_size = emb_dim 144 | 145 | self.rnn_func = RTRLQuasiLSTMlayer( 146 | input_dim=rnn_input_size, hidden_dim=hidden_size) 147 | 148 | self.output_gate = nn.Linear( 149 | rnn_input_size + hidden_size, hidden_size) 150 | 151 | self.dropout = dropout 152 | if dropout > 0.: 153 | self.dropout = nn.Dropout(dropout) 154 | self.out_layer = nn.Linear(hidden_size, out_vocab_size) 155 | 156 | def forward(self, x, state): 157 | if self.no_embedding: 158 | out = torch.nn.functional.one_hot(x, self.num_classes).float() 159 | else: 160 | out = self.embedding(x) # seq dim first 161 | 162 | # RTRLQuasiLSTMlayer can take inputs of shape (B, dim) 163 | cell_out, state = self.rnn_func(out, state) 164 | cell_out.requires_grad_() 165 | cell_out.retain_grad() 166 | 167 | gate_out = self.output_gate(torch.cat([out, cell_out], dim=-1)) 168 | gate_out = torch.sigmoid(gate_out) 169 | gate_out = cell_out * gate_out 170 | 171 | logits = self.out_layer(gate_out) 172 | 173 | return logits, cell_out, state 174 | 175 | def compute_gradient_rtrl(self, top_grad_, rtrl_state): 176 | Z_state, F_state, wz_state, wf_state, bz_state, bf_state = rtrl_state 177 | 178 | self.rnn_func.wm_z.grad += (top_grad_.unsqueeze(-1) * Z_state).sum(dim=0) 179 | self.rnn_func.wm_f.grad += (top_grad_.unsqueeze(-1) * F_state).sum(dim=0) 180 | 181 | self.rnn_func.wv_z.grad += (top_grad_ * wz_state).sum(dim=0) 182 | self.rnn_func.wv_f.grad += (top_grad_ * wf_state).sum(dim=0) 183 | 184 | self.rnn_func.bias_z.grad += (top_grad_ * bz_state).sum(dim=0) 185 | self.rnn_func.bias_f.grad += (top_grad_ * bf_state).sum(dim=0) 186 | 187 | def get_init_states(self, batch_size, device): 188 | return self.rnn_func.get_init_states(batch_size, device) 189 | 190 | def rtrl_reset_grad(self): 191 | self.rnn_func.wm_z.grad = torch.zeros_like(self.rnn_func.wm_z) 192 | self.rnn_func.wm_f.grad = torch.zeros_like(self.rnn_func.wm_f) 193 | 194 | self.rnn_func.wv_z.grad = torch.zeros_like(self.rnn_func.wv_z) 195 | self.rnn_func.wv_f.grad = torch.zeros_like(self.rnn_func.wv_f) 196 | 197 | self.rnn_func.bias_z.grad = torch.zeros_like(self.rnn_func.bias_z) 198 | self.rnn_func.bias_f.grad = torch.zeros_like(self.rnn_func.bias_f) 199 | -------------------------------------------------------------------------------- /diagnostic/utils/copy_task_generator.py: -------------------------------------------------------------------------------- 1 | # Code to generate the data set for copy task. 2 | # 3 | # Sketch: 4 | # For a given max code length T, 5 | # sample length t between 1 and T, 6 | # randomly sample 0 or 1 for the given sequence length t to get `seq` 7 | # create a sequence of the same length filled with `2` (i.e, special token) `2*` 8 | # input = `seq` + `2*` 9 | # output = `2` + `seq` 10 | 11 | import sys 12 | from random import randrange as drw 13 | import random 14 | import numpy as np 15 | 16 | # Task hyper-parameters 17 | rnd_seed = 42 18 | random.seed(rnd_seed) 19 | np.random.seed(rnd_seed) 20 | 21 | 22 | # Get number of characters in the string w/o spaces 23 | # NB: line break '\n' counts as one character. 24 | def num_token(string): 25 | return len(string.split()) 26 | 27 | 28 | # max_seq_length is the max seq length of the pattern/code to be memorized 29 | # for length-padding, use the same token as memory token and skip from the loss 30 | def get_data_pair(max_seq_length, pad_id=2): 31 | '''Get one example of input/output pair.''' 32 | slen = drw(1, max_seq_length + 1) 33 | pattern = np.random.randint(2, size=slen) 34 | spaces = np.ones_like(pattern) * pad_id 35 | padding = np.ones([max_seq_length - slen]).astype(int) * pad_id 36 | input_str = np.concatenate((pattern, spaces, padding)) 37 | tgt_str = np.concatenate((spaces, pattern, padding)) 38 | 39 | input_str = ' '.join(map(str, input_str)) 40 | tgt_str = ' '.join(map(str, tgt_str)) 41 | 42 | return input_str, tgt_str 43 | 44 | # Visualize alignment 45 | def visualize(code_str, tgt_str): 46 | 47 | print("=== Code string ============ ") 48 | print(code_str) 49 | 50 | print("\n=== Target string ========== ") 51 | print(tgt_str) 52 | 53 | print("=== END ") 54 | 55 | 56 | if __name__ == '__main__': 57 | 58 | import argparse 59 | from tqdm import tqdm 60 | 61 | parser = argparse.ArgumentParser(description='Generate data.') 62 | parser.add_argument('--dump_dir', 63 | required=True, help='where to store the data') 64 | parser.add_argument('--train_size', required=False, default=10000, 65 | type=int, help='Number of examples in the train set.') 66 | parser.add_argument('--valid_size', required=False, default=1000, 67 | type=int, help='Number of examples in the valid set.') 68 | parser.add_argument('--test_size', required=False, default=1000, 69 | type=int, help='Number of examples in the test set.') 70 | 71 | parser.add_argument('--code_length', required=False, default=50, 72 | type=int, help='Number of statements in each example.') 73 | parser.add_argument('--show_example', required=False, action='store_true', 74 | help='Only show one example.') 75 | 76 | args = parser.parse_args() 77 | 78 | in_sfx = ".src" 79 | out_sfx = ".tgt" 80 | 81 | train_file_name = f"train_{args.code_length}" 82 | valid_file_name = f"valid_{args.code_length}" 83 | test_file_name = f"test_{args.code_length}" 84 | 85 | tr_src = f"{args.dump_dir}/{train_file_name}{in_sfx}" 86 | tr_tgt = f"{args.dump_dir}/{train_file_name}{out_sfx}" 87 | 88 | valid_src = f"{args.dump_dir}/{valid_file_name}{in_sfx}" 89 | valid_tgt = f"{args.dump_dir}/{valid_file_name}{out_sfx}" 90 | 91 | test_src = f"{args.dump_dir}/{test_file_name}{in_sfx}" 92 | test_tgt = f"{args.dump_dir}/{test_file_name}{out_sfx}" 93 | 94 | if args.show_example: 95 | code_str, tgt_str = get_data_pair(args.code_length) 96 | visualize(code_str, tgt_str) 97 | sys.exit(0) 98 | 99 | # train 100 | print("Generating train data...") 101 | with open(tr_src, 'a') as txt_in, open(tr_tgt, 'a') as txt_out: 102 | for i in tqdm(range(args.train_size)): 103 | code_str, tgt_str = get_data_pair(args.code_length) 104 | # input_seq = ' '.join(code_str.split()) 105 | input_seq = code_str 106 | output_seq = tgt_str 107 | # visualize(code_str, tgt_str) 108 | # print(input_seq) 109 | # print(tgt_str) 110 | if i != args.train_size - 1: 111 | txt_in.write(input_seq + '\n') 112 | txt_out.write(output_seq + '\n') 113 | 114 | # valid 115 | print("done.") 116 | print("Generating valid data...") 117 | with open(valid_src, 'a') as txt_in, open(valid_tgt, 'a') as txt_out: 118 | for i in tqdm(range(args.valid_size)): 119 | code_str, tgt_str = get_data_pair(args.code_length) 120 | # input_seq = ' '.join(code_str.split()) 121 | input_seq = code_str 122 | output_seq = tgt_str 123 | # visualize(code_str, tgt_str) 124 | # print(input_seq) 125 | # print(tgt_str) 126 | 127 | if i != args.valid_size - 1: 128 | txt_in.write(input_seq + '\n') 129 | txt_out.write(output_seq + '\n') 130 | 131 | # test 132 | print("done.") 133 | print("Generating test data...") 134 | with open(test_src, 'a') as txt_in, open(test_tgt, 'a') as txt_out: 135 | for i in tqdm(range(args.test_size)): 136 | code_str, tgt_str = get_data_pair(args.code_length) 137 | input_seq = code_str 138 | # input_seq = ' '.join(code_str.split()) 139 | output_seq = tgt_str 140 | # visualize(code_str, tgt_str) 141 | # print(input_seq) 142 | # print(tgt_str) 143 | if i != args.test_size - 1: 144 | txt_in.write(input_seq + '\n') 145 | txt_out.write(output_seq + '\n') 146 | -------------------------------------------------------------------------------- /reinforcement_learning/README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning Experiments using Real-Time Recurrent Actor-Critic Method (R2AC) 2 | 3 | This repository is originally forked from [IDSIA/modern-srwm/reinforcement_learning](https://github.com/IDSIA/modern-srwm/tree/main/reinforcement_learning) which itself is a fork of the public PyTorch implementation of IMPALA, [Torchbeast](https://github.com/facebookresearch/torchbeast). 4 | 5 | ## Requirements 6 | * We use the `Polybeast` version of [Torchbeast](https://github.com/facebookresearch/torchbeast). 7 | * [DMLab](https://github.com/deepmind/lab) and [ProcGen](https://github.com/openai/procgen). 8 | 9 | We refer to instructions in the original repositories to install these packages. Please check the corresponding requirements. 10 | Note that intalling Polybeast or DMLab might not be straightforward depending on your system. 11 | 12 | * Optionally: install `wandb` for monitoring jobs (by using the `--use_wandb` flag) 13 | 14 | * We used PyTorch version `>= 1.4.0` for our experiments 15 | 16 | ## Training 17 | 18 | **NB: in the code, "eLSTM" (the name of our RNN architecture in the paper) is called "QuasiLSTM".** 19 | 20 | We have a separate main training code file for each environment: DMLab, ProcGen, and Atari. 21 | See example scripts below. 22 | * `--use_rtrl` flag should be removed to train a feedforward agent, 23 | * or it should be replaced by `--use_quasi_lstm` for the TBPTT-trained eLSTM 24 | * or use `--use_quasi_full_lstm` for TBPTT-trained "feLSTM" or `--use_snap` for SnAp-trained "feLSTM" 25 | 26 | Logs of our experiments/figures (~3 GB uncompressed) can be downloaded from [here/google-drive](https://drive.google.com/file/d/1d4EhyGzVMEILZdeIMXnE_7-OfeW8yrrR/view?usp=sharing). 27 | 28 | ### DMLab 29 | 30 | * Training from scratch: 31 | ``` 32 | SAVE_DIR=saved_models_dmlab 33 | 34 | GAME=rooms_keys_doors_puzzle 35 | MODEL=rtrl_elstm 36 | SEED= 37 | LEN=100 38 | SIZE=512 39 | 40 | python -m torchbeast_dmlab.polybeast \ 41 | --single_gpu \ 42 | --use_wandb \ 43 | --use_rtrl \ 44 | --seed ${SEED} \ 45 | --env ${GAME} \ 46 | --pipes_basename "unix:/tmp/pb_rgbdmlab_${MODEL}_${GAME}_${SIZE}_len${LEN}_seed${SEED}" \ 47 | --validate_every 240000 \ 48 | --disable_validation \ 49 | --validate_step_every 1_000_000 \ 50 | --num_actors 48 \ 51 | --num_servers 48 \ 52 | --total_steps 100_000_000 \ 53 | --learning_rate 0.0006 \ 54 | --grad_norm_clipping 40 \ 55 | --epsilon 0.01 \ 56 | --entropy_cost 0.01 \ 57 | --batch_size 32 \ 58 | --unroll_length ${LEN} \ 59 | --hidden_size ${SIZE} \ 60 | --num_learner_threads 1 \ 61 | --num_inference_threads 1 \ 62 | --project_name "2023_${GAME}" \ 63 | --xpid "rgb_${MODEL}_${GAME}_${MODE}_${SIZE}_len${LEN}_seed${SEED}" \ 64 | --savedir ${SAVE_DIR} 65 | ``` 66 | 67 | * Training using a pre-trained vision stem: 68 | ``` 69 | SAVE_DIR=saved_models_dmlab 70 | 71 | GAME=rooms_watermaze 72 | MODEL=frozen_rtrl_elstm 73 | SEED= 74 | LEN=100 75 | SIZE=512 76 | 77 | PRETRAIN='pretrained_models/me/model.tar' 78 | 79 | python -m torchbeast_dmlab.polybeast \ 80 | --single_gpu \ 81 | --use_rtrl \ 82 | --use_wandb \ 83 | --load_conv_net_from ${PRETRAIN} \ 84 | --freeze_conv \ 85 | --freeze_fc \ 86 | --env ${GAME} \ 87 | --pipes_basename "unix:/tmp/pb_rgbdmlab_${MODEL}_${GAME}_${SIZE}_len${LEN}_seed${SEED}" \ 88 | --validate_every 240000 \ 89 | --disable_validation \ 90 | --validate_step_every 1_000_000 \ 91 | --num_actors 48 \ 92 | --num_servers 48 \ 93 | --total_steps 100_000_000 \ 94 | --learning_rate 0.0006 \ 95 | --grad_norm_clipping 40 \ 96 | --epsilon 0.01 \ 97 | --entropy_cost 0.01 \ 98 | --batch_size 32 \ 99 | --unroll_length ${LEN} \ 100 | --hidden_size ${SIZE} \ 101 | --num_learner_threads 1 \ 102 | --num_inference_threads 1 \ 103 | --project_name "2023_${GAME}_frozen" \ 104 | --xpid "rgb_${MODEL}_${GAME}_${MODE}_${SIZE}_len${LEN}_seed${SEED}" \ 105 | --savedir ${SAVE_DIR} 106 | ``` 107 | 108 | ### ProcGen 109 | ``` 110 | SAVE_DIR=saved_models 111 | 112 | GAME=chaser 113 | MODE=hard 114 | MODEL=rtrl_quasi_lstm 115 | SEED= 116 | LEN=50 117 | SIZE=256 118 | 119 | python -m torchbeast_procgen.polybeast \ 120 | --single_gpu \ 121 | --use_wandb \ 122 | --use_rtrl \ 123 | --env procgen:procgen-${GAME}-v0 \ 124 | --pipes_basename "unix:/tmp/pb_${MODEL}_${GAME}_${SIZE}_len${LEN}_seed${SEED}" \ 125 | --validate_every 60 \ 126 | --num_actors 48 \ 127 | --num_servers 48 \ 128 | --total_steps 200_000_000 \ 129 | --save_extra_checkpoint 50_000_000 \ 130 | --learning_rate 0.0006 \ 131 | --grad_norm_clipping 40 \ 132 | --epsilon 0.01 \ 133 | --entropy_cost 0.01 \ 134 | --batch_size 32 \ 135 | --unroll_length ${LEN} \ 136 | --num_actions 15 \ 137 | --hidden_size ${SIZE} \ 138 | --num_learner_threads 1 \ 139 | --num_inference_threads 1 \ 140 | --project_name "2023_${GAME}" \ 141 | --xpid "${MODEL}_${GAME}_${MODE}_${SIZE}_len${LEN}_seed${SEED}" \ 142 | --num_levels 500 \ 143 | --start_level 0 \ 144 | --distribution_mode ${MODE} \ 145 | --valid_distribution_mode ${MODE} \ 146 | --valid_num_levels 500 \ 147 | --valid_start_level 500 \ 148 | --valid_num_episodes 10 \ 149 | --savedir ${SAVE_DIR} 150 | ``` 151 | 152 | ### Atari 153 | For Atari, the number of actions should be changed for each game; see `list_atari_games.txt` 154 | ``` 155 | SAVE_DIR=saved_models_atari 156 | 157 | GAME=QbertNoFrameskip-v4 158 | MODEL=rtrl_quasi_lstm 159 | SEED= 160 | LEN=50 161 | SIZE=256 162 | 163 | python -m torchbeast_atari.polybeast \ 164 | --single_gpu \ 165 | --env ${GAME} \ 166 | --use_rtrl \ 167 | --use_wandb \ 168 | --pipes_basename "unix:/tmp/pb_${MODEL}_${GAME}_${SIZE}_len${LEN}_seed${SEED}" \ 169 | --disable_validation \ 170 | --validate_every 60 \ 171 | --num_actors 48 \ 172 | --num_servers 48 \ 173 | --disable_validation \ 174 | --validate_every 6000 \ 175 | --validate_step_every 10_000_000 \ 176 | --total_steps 200_000_000 \ 177 | --save_extra_checkpoint 50_000_000 \ 178 | --learning_rate 0.0006 \ 179 | --grad_norm_clipping 40 \ 180 | --epsilon 0.01 \ 181 | --entropy_cost 0.01 \ 182 | --batch_size 32 \ 183 | --unroll_length ${LEN} \ 184 | --num_actions 6 \ 185 | --hidden_size ${SIZE} \ 186 | --num_learner_threads 1 \ 187 | --num_inference_threads 1 \ 188 | --project_name "2023_${GAME}" \ 189 | --xpid "${MODEL}_${GAME}_${MODE}_${SIZE}_len${LEN}_seed${SEED}" \ 190 | --savedir ${SAVE_DIR} 191 | ``` 192 | 193 | ## Evaluation 194 | **The final evalution is automatically carried out at the end of training.** 195 | NB: `polybeast_learner.py` files typically also contain code for the eval-only mode; please ignore them; they are just copy-pasted from some random environments and not adapted to each environment. 196 | -------------------------------------------------------------------------------- /reinforcement_learning/atari_data.py: -------------------------------------------------------------------------------- 1 | # Taken from https://github.com/deepmind/dqn_zoo/blob/master/dqn_zoo/atari_data.py 2 | # 3 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | """Utilities to compute human-normalized Atari scores. 18 | The data used in this module is human and random performance data on Atari-57. 19 | It comprises of evaluation scores (undiscounted returns), each averaged 20 | over at least 3 episode runs, on each of the 57 Atari games. Each episode begins 21 | with the environment already stepped with a uniform random number (between 1 and 22 | 30 inclusive) of noop actions. 23 | The two agents are: 24 | * 'random' (agent choosing its actions uniformly randomly on each step) 25 | * 'human' (professional human game tester) 26 | Scores are obtained by averaging returns over the episodes played by each agent, 27 | with episode length capped to 108,000 frames (i.e. timeout after 30 minutes). 28 | The term 'human-normalized' here means a linear per-game transformation of 29 | a game score in such a way that 0 corresponds to random performance and 1 30 | corresponds to human performance. 31 | """ 32 | 33 | # pylint: disable=g-bad-import-order 34 | 35 | import math 36 | 37 | # Game: score-tuple dictionary. Each score tuple contains 38 | # 0: score random (float) and 1: score human (float). 39 | _ATARI_DATA = { 40 | 'alien': (227.8, 7127.7), 41 | 'amidar': (5.8, 1719.5), 42 | 'assault': (222.4, 742.0), 43 | 'asterix': (210.0, 8503.3), 44 | 'asteroids': (719.1, 47388.7), 45 | 'atlantis': (12850.0, 29028.1), 46 | 'bank_heist': (14.2, 753.1), 47 | 'battle_zone': (2360.0, 37187.5), 48 | 'beam_rider': (363.9, 16926.5), 49 | 'berzerk': (123.7, 2630.4), 50 | 'bowling': (23.1, 160.7), 51 | 'boxing': (0.1, 12.1), 52 | 'breakout': (1.7, 30.5), 53 | 'centipede': (2090.9, 12017.0), 54 | 'chopper_command': (811.0, 7387.8), 55 | 'crazy_climber': (10780.5, 35829.4), 56 | 'defender': (2874.5, 18688.9), 57 | 'demon_attack': (152.1, 1971.0), 58 | 'double_dunk': (-18.6, -16.4), 59 | 'enduro': (0.0, 860.5), 60 | 'fishing_derby': (-91.7, -38.7), 61 | 'freeway': (0.0, 29.6), 62 | 'frostbite': (65.2, 4334.7), 63 | 'gopher': (257.6, 2412.5), 64 | 'gravitar': (173.0, 3351.4), 65 | 'hero': (1027.0, 30826.4), 66 | 'ice_hockey': (-11.2, 0.9), 67 | 'jamesbond': (29.0, 302.8), 68 | 'kangaroo': (52.0, 3035.0), 69 | 'krull': (1598.0, 2665.5), 70 | 'kung_fu_master': (258.5, 22736.3), 71 | 'montezuma_revenge': (0.0, 4753.3), 72 | 'ms_pacman': (307.3, 6951.6), 73 | 'name_this_game': (2292.3, 8049.0), 74 | 'phoenix': (761.4, 7242.6), 75 | 'pitfall': (-229.4, 6463.7), 76 | 'pong': (-20.7, 14.6), 77 | 'private_eye': (24.9, 69571.3), 78 | 'qbert': (163.9, 13455.0), 79 | 'riverraid': (1338.5, 17118.0), 80 | 'road_runner': (11.5, 7845.0), 81 | 'robotank': (2.2, 11.9), 82 | 'seaquest': (68.4, 42054.7), 83 | 'skiing': (-17098.1, -4336.9), 84 | 'solaris': (1236.3, 12326.7), 85 | 'space_invaders': (148.0, 1668.7), 86 | 'star_gunner': (664.0, 10250.0), 87 | 'surround': (-10.0, 6.5), 88 | 'tennis': (-23.8, -8.3), 89 | 'time_pilot': (3568.0, 5229.2), 90 | 'tutankham': (11.4, 167.6), 91 | 'up_n_down': (533.4, 11693.2), 92 | 'venture': (0.0, 1187.5), 93 | # Note the random agent score on Video Pinball is sometimes greater than the 94 | # human score under other evaluation methods. 95 | 'video_pinball': (16256.9, 17667.9), 96 | 'wizard_of_wor': (563.5, 4756.5), 97 | 'yars_revenge': (3092.9, 54576.9), 98 | 'zaxxon': (32.5, 9173.3), 99 | } 100 | 101 | _RANDOM_COL = 0 102 | _HUMAN_COL = 1 103 | 104 | ATARI_GAMES = tuple(sorted(_ATARI_DATA.keys())) 105 | 106 | 107 | def get_human_normalized_score(game: str, raw_score: float) -> float: 108 | """Converts game score to human-normalized score.""" 109 | game_scores = _ATARI_DATA.get(game, (math.nan, math.nan)) 110 | random, human = game_scores[_RANDOM_COL], game_scores[_HUMAN_COL] 111 | return (raw_score - random) / (human - random) 112 | 113 | -------------------------------------------------------------------------------- /reinforcement_learning/list_atari_games.txt: -------------------------------------------------------------------------------- 1 | AdventureNoFrameskip-v4 18 2 | AirRaidNoFrameskip-v4 6 3 | AlienNoFrameskip-v4 18 4 | AmidarNoFrameskip-v4 10 5 | AssaultNoFrameskip-v4 7 6 | AsterixNoFrameskip-v4 9 7 | AsteroidsNoFrameskip-v4 14 8 | AtlantisNoFrameskip-v4 4 9 | BankHeistNoFrameskip-v4 18 10 | BattleZoneNoFrameskip-v4 18 11 | BeamRiderNoFrameskip-v4 9 12 | BerzerkNoFrameskip-v4 18 13 | BowlingNoFrameskip-v4 6 14 | BoxingNoFrameskip-v4 18 15 | BreakoutNoFrameskip-v4 4 16 | CarnivalNoFrameskip-v4 6 17 | CentipedeNoFrameskip-v4 18 18 | ChopperCommandNoFrameskip-v4 18 19 | CrazyClimberNoFrameskip-v4 9 20 | DemonAttackNoFrameskip-v4 6 21 | DoubleDunkNoFrameskip-v4 18 22 | ElevatorActionNoFrameskip-v4 18 23 | EnduroNoFrameskip-v4 9 24 | FishingDerbyNoFrameskip-v4 18 25 | FreewayNoFrameskip-v4 3 26 | FrostbiteNoFrameskip-v4 18 27 | GopherNoFrameskip-v4 8 28 | GravitarNoFrameskip-v4 18 29 | HeroNoFrameskip-v4 18 30 | IceHockeyNoFrameskip-v4 18 31 | JamesbondNoFrameskip-v4 18 32 | JourneyEscapeNoFrameskip-v4 16 33 | KangarooNoFrameskip-v4 18 34 | KrullNoFrameskip-v4 18 35 | KungFuMasterNoFrameskip-v4 14 36 | MontezumaRevengeNoFrameskip-v4 18 37 | MsPacmanNoFrameskip-v4 9 38 | NameThisGameNoFrameskip-v4 6 39 | PhoenixNoFrameskip-v4 8 40 | PitfallNoFrameskip-v4 18 41 | PongNoFrameskip-v4 6 42 | PooyanNoFrameskip-v4 6 43 | PrivateEyeNoFrameskip-v4 18 44 | QbertNoFrameskip-v4 6 45 | RiverraidNoFrameskip-v4 18 46 | RoadRunnerNoFrameskip-v4 18 47 | RobotankNoFrameskip-v4 18 48 | SeaquestNoFrameskip-v4 18 49 | SkiingNoFrameskip-v4 3 50 | SolarisNoFrameskip-v4 18 51 | SpaceInvadersNoFrameskip-v4 6 52 | StarGunnerNoFrameskip-v4 18 53 | TennisNoFrameskip-v4 18 54 | TimePilotNoFrameskip-v4 10 55 | TutankhamNoFrameskip-v4 8 56 | UpNDownNoFrameskip-v4 6 57 | VentureNoFrameskip-v4 18 58 | VideoPinballNoFrameskip-v4 9 59 | WizardOfWorNoFrameskip-v4 10 60 | YarsRevengeNoFrameskip-v4 18 61 | ZaxxonNoFrameskip-v4 18 62 | -------------------------------------------------------------------------------- /reinforcement_learning/nest/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Nest library 3 | 4 | ```shell 5 | CXX=c++ pip install . -vv 6 | ``` 7 | 8 | Usage in Python: 9 | 10 | ```python 11 | import torch 12 | import nest 13 | 14 | t1 = torch.tensor(0) 15 | t2 = torch.tensor(1) 16 | d = {'hey': torch.tensor(2)} 17 | 18 | print(nest.map(lambda t: t + 42, (t1, t2, d))) 19 | # --> (tensor(42), tensor(43), {'hey': tensor(44)}) 20 | ``` 21 | -------------------------------------------------------------------------------- /reinforcement_learning/nest/nest/nest_pybind.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | #include "nest.h" 22 | #include "nest_pybind.h" 23 | 24 | namespace py = pybind11; 25 | 26 | typedef nest::Nest PyNest; 27 | 28 | class py_list_back_inserter { 29 | public: 30 | py_list_back_inserter(py::list &l) : list_(&l) {} 31 | py_list_back_inserter &operator=(const py::object &value) { 32 | list_->append(value); 33 | return *this; 34 | }; 35 | constexpr py_list_back_inserter &operator*() { return *this; }; 36 | constexpr py_list_back_inserter &operator++() { return *this; } 37 | constexpr py_list_back_inserter &operator++(int) { return *this; } 38 | 39 | private: 40 | py::list *list_; 41 | }; 42 | 43 | PYBIND11_MODULE(nest, m) { 44 | m.def("map", [](py::function f, const PyNest &n) { 45 | // This says const py::object, but f can actually modify it! 46 | std::function cppf = 47 | [&f](const py::object &arg) { return f(arg); }; 48 | return n.map(cppf); 49 | }); 50 | m.def("map_many", 51 | [](const std::function &)> &f, 52 | py::args args) { 53 | std::vector nests = args.cast>(); 54 | return PyNest::zip(nests).map(f); 55 | }); 56 | m.def("map_many2", [](const std::function &f, 58 | const PyNest &n1, const PyNest &n2) { 59 | try { 60 | return PyNest::map2(f, n1, n2); 61 | } catch (const std::invalid_argument &e) { 62 | // IDK why I have to do this manually. 63 | throw py::value_error(e.what()); 64 | } 65 | }); 66 | m.def("flatten", [](const PyNest &n) { 67 | py::list result; 68 | n.flatten(py_list_back_inserter(result)); 69 | return result; 70 | }); 71 | m.def("pack_as", [](const PyNest &n, const py::sequence &sequence) { 72 | try { 73 | return n.pack_as(sequence.begin(), sequence.end()); 74 | } catch (const std::exception &e) { 75 | // PyTorch pybind11 doesn't seem to translate exceptions? 76 | throw py::value_error(e.what()); 77 | } 78 | }); 79 | m.def("front", [](const PyNest &n) { return n.front(); }); 80 | } 81 | -------------------------------------------------------------------------------- /reinforcement_learning/nest/nest/nest_pybind.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | 22 | #include "nest.h" 23 | 24 | namespace pybind11 { 25 | namespace detail { 26 | template 27 | struct type_caster> { 28 | using ValueNest = nest::Nest; 29 | using value_conv = make_caster; 30 | 31 | public: 32 | PYBIND11_TYPE_CASTER(ValueNest, _("Nest[") + value_conv::name + _("]")); 33 | 34 | bool load(handle src, bool convert) { 35 | if (!src.ptr()) { 36 | return false; 37 | } 38 | if (isinstance(src) || isinstance(src)) { 39 | value.value = std::move(src).cast>(); 40 | return true; 41 | } 42 | if (isinstance(src)) { 43 | value.value = std::move(src).cast>(); 44 | return true; 45 | } 46 | 47 | value_conv conv; 48 | if (!conv.load(src, convert)) return false; 49 | 50 | value.value = cast_op(std::move(conv)); 51 | return true; 52 | } 53 | 54 | static handle cast(ValueNest&& src, return_value_policy policy, 55 | handle parent) { 56 | return std::visit( 57 | nest::overloaded{ 58 | [&policy, &parent](Value&& t) { 59 | return value_conv::cast(std::move(t), policy, parent); 60 | }, 61 | [&policy, &parent](std::vector&& v) { 62 | object py_list = reinterpret_steal( 63 | list_caster, ValueNest>::cast( 64 | std::move(v), policy, parent)); 65 | 66 | return handle(PyList_AsTuple(py_list.ptr())); 67 | }, 68 | [&policy, &parent](std::map&& m) { 69 | return map_caster::cast( 70 | std::move(m), policy, parent); 71 | }}, 72 | std::move(src.value)); 73 | } 74 | 75 | static handle cast(const ValueNest& src, return_value_policy policy, 76 | handle parent) { 77 | return std::visit( 78 | nest::overloaded{ 79 | [&policy, &parent](const Value& t) { 80 | return value_conv::cast(t, policy, parent); 81 | }, 82 | [&policy, &parent](const std::vector& v) { 83 | object py_list = reinterpret_steal( 84 | list_caster, ValueNest>::cast( 85 | v, policy, parent)); 86 | 87 | return handle(PyList_AsTuple(py_list.ptr())); 88 | }, 89 | [&policy, &parent](const std::map& m) { 90 | return map_caster::cast( 91 | m, policy, parent); 92 | }}, 93 | src.value); 94 | } 95 | }; 96 | } // namespace detail 97 | } // namespace pybind11 98 | -------------------------------------------------------------------------------- /reinforcement_learning/nest/nest_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import sys 16 | import unittest 17 | 18 | import nest 19 | import torch 20 | 21 | 22 | class NestTest(unittest.TestCase): 23 | def setUp(self): 24 | self.n1 = ("Test", ["More", 32], {"h": 4}) 25 | self.n2 = ("Test", ("More", 32, (None, 43, ())), {"h": 4}) 26 | 27 | def test_nest_flatten_no_asserts(self): 28 | t = torch.tensor(1) 29 | t2 = torch.tensor(2) 30 | n = (t, t2) 31 | d = {"hey": t} 32 | 33 | nest.flatten((t, t2)) 34 | nest.flatten(d) 35 | nest.flatten((d, t)) 36 | nest.flatten((d, n, t)) 37 | 38 | nest.flatten(((t, t2), (t, t2))) 39 | 40 | nest.flatten(self.n1) 41 | nest.flatten(self.n2) 42 | 43 | d2 = {"hey": t2, "there": d, "more": t2} 44 | nest.flatten(d2) # Careful here, order not necessarily as above. 45 | 46 | def test_nest_map(self): 47 | t1 = torch.tensor(0) 48 | t2 = torch.tensor(1) 49 | d = {"hey": t2} 50 | 51 | n = nest.map(lambda t: t + 42, (t1, t2)) 52 | 53 | self.assertSequenceEqual(n, [t1 + 42, t2 + 42]) 54 | self.assertSequenceEqual(n, nest.flatten(n)) 55 | 56 | n1 = (d, n, t1) 57 | n2 = nest.map(lambda t: t * 2, n1) 58 | 59 | self.assertEqual(n2[0], {"hey": torch.tensor(2)}) 60 | self.assertEqual(n2[1], (torch.tensor(84), torch.tensor(86))) 61 | self.assertEqual(n2[2], torch.tensor(0)) 62 | 63 | t = torch.tensor(42) 64 | 65 | # Doesn't work with pybind11/functional.h, but does with py::function. 66 | self.assertEqual(nest.map(t.add, t2), torch.tensor(43)) 67 | 68 | def test_nest_flatten(self): 69 | self.assertEqual(nest.flatten(None), [None]) 70 | self.assertEqual(nest.flatten(self.n1), ["Test", "More", 32, 4]) 71 | 72 | def test_nest_pack_as(self): 73 | self.assertEqual(self.n2, nest.pack_as(self.n2, nest.flatten(self.n2))) 74 | 75 | with self.assertRaisesRegex(ValueError, "didn't exhaust sequence"): 76 | nest.pack_as(self.n2, nest.flatten(self.n2) + [None]) 77 | with self.assertRaisesRegex(ValueError, "Too few elements"): 78 | nest.pack_as(self.n2, nest.flatten(self.n2)[1:]) 79 | 80 | def test_nest_map_many2(self): 81 | def f(a, b): 82 | return (b, a) 83 | 84 | self.assertEqual(nest.map_many2(f, (1, 2), (3, 4)), ((3, 1), (4, 2))) 85 | 86 | with self.assertRaisesRegex(ValueError, "got 2 vs 1"): 87 | nest.map_many2(f, (1, 2), (3,)) 88 | 89 | self.assertEqual(nest.map_many2(f, {"a": 1}, {"a": 2}), {"a": (2, 1)}) 90 | 91 | with self.assertRaisesRegex(ValueError, "same keys"): 92 | nest.map_many2(f, {"a": 1}, {"b": 2}) 93 | 94 | with self.assertRaisesRegex(ValueError, "1 vs 0"): 95 | nest.map_many2(f, {"a": 1}, {}) 96 | 97 | with self.assertRaisesRegex(ValueError, "nests don't match"): 98 | nest.map_many2(f, {"a": 1}, ()) 99 | 100 | def test_nest_map_many(self): 101 | def f(a): 102 | return (a[1], a[0]) 103 | 104 | self.assertEqual(nest.map_many(f, (1, 2), (3, 4)), ((3, 1), (4, 2))) 105 | 106 | return 107 | with self.assertRaisesRegex(ValueError, "got 2 vs 1"): 108 | nest.map_many(f, (1, 2), (3,)) 109 | 110 | self.assertEqual(nest.map_many(f, {"a": 1}, {"a": 2}), {"a": (2, 1)}) 111 | 112 | with self.assertRaisesRegex(ValueError, "same keys"): 113 | nest.map_many(f, {"a": 1}, {"b": 2}) 114 | 115 | with self.assertRaisesRegex(ValueError, "1 vs 0"): 116 | nest.map_many(f, {"a": 1}, {}) 117 | 118 | with self.assertRaisesRegex(ValueError, "nests don't match"): 119 | nest.map_many(f, {"a": 1}, ()) 120 | 121 | def test_front(self): 122 | self.assertEqual(nest.front((1, 2, 3)), 1) 123 | self.assertEqual(nest.front((2, 3)), 2) 124 | self.assertEqual(nest.front((3,)), 3) 125 | 126 | def test_refcount(self): 127 | obj = "my very large and random string with numbers 1234" 128 | 129 | rc = sys.getrefcount(obj) 130 | 131 | # Test nest.front. This doesn't involve returning nests 132 | # from C++ to Python. 133 | nest.front((None, obj)) 134 | self.assertEqual(rc, sys.getrefcount(obj)) 135 | 136 | nest.front(obj) 137 | self.assertEqual(rc, sys.getrefcount(obj)) 138 | 139 | nest.front((obj,)) 140 | self.assertEqual(rc, sys.getrefcount(obj)) 141 | 142 | nest.front((obj, obj, [obj, {"obj": obj}, obj])) 143 | self.assertEqual(rc, sys.getrefcount(obj)) 144 | 145 | # Test returning nests of Nones. 146 | nest.map(lambda x: None, (obj, obj, [obj, {"obj": obj}, obj])) 147 | self.assertEqual(rc, sys.getrefcount(obj)) 148 | 149 | # Test returning actual nests. 150 | nest.map(lambda s: s, obj) 151 | self.assertEqual(rc, sys.getrefcount(obj)) 152 | 153 | nest.map(lambda x: x, {"obj": obj}) 154 | self.assertEqual(rc, sys.getrefcount(obj)) 155 | 156 | nest.map(lambda x: x, (obj,)) 157 | self.assertEqual(rc, sys.getrefcount(obj)) 158 | 159 | nest.map(lambda s: s, (obj, obj)) 160 | nest.map(lambda s: s, (obj, obj)) 161 | self.assertEqual(rc, sys.getrefcount(obj)) 162 | 163 | n = nest.map(lambda s: s, (obj,)) 164 | self.assertEqual(rc + 1, sys.getrefcount(obj)) 165 | del n 166 | self.assertEqual(rc, sys.getrefcount(obj)) 167 | 168 | 169 | if __name__ == "__main__": 170 | unittest.main() 171 | -------------------------------------------------------------------------------- /reinforcement_learning/nest/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # 16 | # CXX=c++ python3 setup.py build develop 17 | # or 18 | # CXX=c++ pip install . -vv 19 | # 20 | 21 | import sys 22 | 23 | import setuptools 24 | import setuptools.command.build_ext 25 | 26 | 27 | class get_pybind_include(object): 28 | """Helper class to determine the pybind11 include path 29 | 30 | The purpose of this class is to postpone importing pybind11 31 | until it is actually installed, so that the ``get_include()`` 32 | method can be invoked. """ 33 | 34 | def __init__(self, user=False): 35 | self.user = user 36 | 37 | def __str__(self): 38 | import pybind11 39 | 40 | return pybind11.get_include(self.user) 41 | 42 | 43 | ext_modules = [ 44 | setuptools.Extension( 45 | "nest", 46 | ["nest/nest_pybind.cc"], 47 | include_dirs=[ 48 | # Path to pybind11 headers 49 | get_pybind_include(), 50 | get_pybind_include(user=True), 51 | ], 52 | depends=["nest/nest.h", "nest/nest_pybind.h"], 53 | language="c++", 54 | extra_compile_args=["-std=c++17"], 55 | ) 56 | ] 57 | 58 | 59 | class BuildExt(setuptools.command.build_ext.build_ext): 60 | """A custom build extension for adding compiler-specific options.""" 61 | 62 | c_opts = {"msvc": ["/EHsc"], "unix": []} 63 | 64 | if sys.platform == "darwin": 65 | c_opts["unix"] += ["-stdlib=libc++", "-mmacosx-version-min=10.14"] 66 | 67 | def build_extensions(self): 68 | ct = self.compiler.compiler_type 69 | opts = self.c_opts.get(ct, []) 70 | if ct == "unix": 71 | opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version()) 72 | opts.append("-std=c++17") 73 | opts.append("-fvisibility=hidden") 74 | elif ct == "msvc": 75 | opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()) 76 | for ext in self.extensions: 77 | ext.extra_compile_args += opts 78 | if sys.platform == "darwin": 79 | ext.extra_link_args = ["-stdlib=libc++"] 80 | 81 | super().build_extensions() 82 | 83 | 84 | setuptools.setup( 85 | name="nest", 86 | version="0.0.3", 87 | author="TorchBeast team", 88 | ext_modules=ext_modules, 89 | headers=["nest/nest.h", "nest/nest_pybind.h"], 90 | cmdclass={"build_ext": BuildExt}, 91 | install_requires=["pybind11>=2.3"], 92 | setup_requires=["pybind11>=2.3"], 93 | ) 94 | -------------------------------------------------------------------------------- /reinforcement_learning/scripts/install_grpc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | set -e 18 | set -x 19 | 20 | if [ -z ${GRPC_DIR+x} ]; then 21 | GRPC_DIR=$(pwd)/third_party/grpc; 22 | fi 23 | 24 | PREFIX=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} 25 | 26 | NPROCS=$(getconf _NPROCESSORS_ONLN) 27 | 28 | pushd ${GRPC_DIR} 29 | 30 | # Ask PyTorch if it has been compiled with -D_GLIBCXX_USE_CXX11_ABI=0 (old ABI). 31 | # See https://github.com/pytorch/pytorch/issues/17492. 32 | GLIBCXX_USE_CXX11_ABI=$(python3 -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))") 33 | export EXTRA_CXXFLAGS="-D_GLIBCXX_USE_CXX11_ABI=$GLIBCXX_USE_CXX11_ABI" 34 | 35 | # Install protobuf. We don't use the conda package as PyTorch insists 36 | # on using a different ABI. 37 | pushd ${GRPC_DIR}/third_party/protobuf 38 | ./autogen.sh && ./configure --prefix=${PREFIX} \ 39 | CFLAGS="-fPIC" CXXFLAGS="-fPIC ${EXTRA_CXXFLAGS}" 40 | make -j ${NPROCS} && make install 41 | ldconfig || true 42 | popd 43 | 44 | # Make make find libprotobuf 45 | export PATH=${PREFIX}/bin:${PATH} 46 | export CPATH=${PREFIX}/include:${CPATH} 47 | export LIBRARY_PATH=${PREFIX}/lib:${LIBRARY_PATH} 48 | export LD_LIBRARY_PATH=${PREFIX}/lib:${LD_LIBRARY_PATH} 49 | 50 | make -j ${NPROCS} prefix=${PREFIX} EXTRA_CXXFLAGS=${EXTRA_CXXFLAGS} \ 51 | HAS_SYSTEM_PROTOBUF=true HAS_SYSTEM_CARES=false 52 | make prefix=${PREFIX} \ 53 | HAS_SYSTEM_PROTOBUF=true HAS_SYSTEM_CARES=false install 54 | 55 | popd 56 | -------------------------------------------------------------------------------- /reinforcement_learning/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # 16 | # CXX=c++ python3 setup.py build develop 17 | # or 18 | # CXX=c++ pip install . -vv 19 | # 20 | # Potentially also set TORCHBEAST_LIBS_PREFIX. 21 | 22 | import os 23 | import subprocess 24 | import sys 25 | import unittest 26 | 27 | import numpy as np 28 | import setuptools 29 | from torch.utils import cpp_extension 30 | 31 | 32 | PREFIX = os.getenv("CONDA_PREFIX") 33 | 34 | if os.getenv("TORCHBEAST_LIBS_PREFIX"): 35 | PREFIX = os.getenv("TORCHBEAST_LIBS_PREFIX") 36 | if not PREFIX: 37 | PREFIX = "/usr/local" 38 | 39 | 40 | def build_pb(): 41 | protoc = f"{PREFIX}/bin/protoc" 42 | 43 | # Hard-code client.proto for now. 44 | source = os.path.join(os.path.dirname(__file__), "libtorchbeast", "rpcenv.proto") 45 | output = source.replace(".proto", ".pb.cc") 46 | 47 | if os.path.exists(output) and ( 48 | os.path.exists(source) and os.path.getmtime(source) < os.path.getmtime(output) 49 | ): 50 | return 51 | 52 | print("calling protoc") 53 | if ( 54 | subprocess.call( 55 | [protoc, "--cpp_out=libtorchbeast", "-Ilibtorchbeast", "rpcenv.proto"] 56 | ) 57 | != 0 58 | ): 59 | sys.exit(-1) 60 | if ( 61 | subprocess.call( 62 | protoc + " --grpc_out=libtorchbeast -Ilibtorchbeast" 63 | " --plugin=protoc-gen-grpc=`which grpc_cpp_plugin`" 64 | " rpcenv.proto", 65 | shell=True, 66 | ) 67 | != 0 68 | ): 69 | sys.exit(-1) 70 | 71 | 72 | def test_suite(): 73 | test_loader = unittest.TestLoader() 74 | test_suite = test_loader.discover("tests", pattern="*_test.py") 75 | return test_suite 76 | 77 | 78 | class build_ext(cpp_extension.BuildExtension): 79 | def run(self): 80 | build_pb() 81 | cpp_extension.BuildExtension.run(self) 82 | 83 | 84 | def main(): 85 | extra_compile_args = [] 86 | extra_link_args = [] 87 | 88 | grpc_objects = [ 89 | f"{PREFIX}/lib/libgrpc++.a", 90 | f"{PREFIX}/lib/libgrpc.a", 91 | f"{PREFIX}/lib/libgpr.a", 92 | f"{PREFIX}/lib/libaddress_sorting.a", 93 | ] 94 | 95 | include_dirs = cpp_extension.include_paths() + [ 96 | np.get_include(), 97 | f"{PREFIX}/include", 98 | ] 99 | libraries = [] 100 | 101 | if sys.platform == "darwin": 102 | extra_compile_args += ["-stdlib=libc++", "-mmacosx-version-min=10.14"] 103 | extra_link_args += ["-stdlib=libc++", "-mmacosx-version-min=10.14"] 104 | 105 | # Relevant only when c-cares is not embedded in grpc, e.g. when 106 | # installing grpc via homebrew. 107 | libraries.append("cares") 108 | elif sys.platform == "linux": 109 | libraries.append("z") 110 | 111 | grpc_objects.append(f"{PREFIX}/lib/libprotobuf.a") 112 | 113 | libtorchbeast = cpp_extension.CppExtension( 114 | name="libtorchbeast._C", 115 | sources=[ 116 | "libtorchbeast/libtorchbeast.cc", 117 | "libtorchbeast/actorpool.cc", 118 | "libtorchbeast/rpcenv.cc", 119 | "libtorchbeast/rpcenv.pb.cc", 120 | "libtorchbeast/rpcenv.grpc.pb.cc", 121 | ], 122 | include_dirs=include_dirs, 123 | libraries=libraries, 124 | language="c++", 125 | extra_compile_args=["-std=c++17"] + extra_compile_args, 126 | extra_link_args=extra_link_args, 127 | extra_objects=grpc_objects, 128 | ) 129 | 130 | setuptools.setup( 131 | name="libtorchbeast", 132 | packages=["libtorchbeast"], 133 | version="0.0.14", 134 | ext_modules=[libtorchbeast], 135 | cmdclass={"build_ext": build_ext}, 136 | test_suite="setup.test_suite", 137 | ) 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/batching_queue_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for actorpool.BatchingQueue. 15 | Basic functionalities actorpool.BatchingQueue are tested 16 | in libtorchbeast/actorpool_test.cc. 17 | """ 18 | 19 | import threading 20 | import time 21 | import unittest 22 | 23 | import numpy as np 24 | import torch 25 | import libtorchbeast 26 | 27 | 28 | class BatchingQueueTest(unittest.TestCase): 29 | def test_bad_construct(self): 30 | with self.assertRaisesRegex(ValueError, "Min batch size must be >= 1"): 31 | libtorchbeast.BatchingQueue( 32 | batch_dim=3, minimum_batch_size=0, maximum_batch_size=1 33 | ) 34 | 35 | with self.assertRaisesRegex( 36 | ValueError, "Max batch size must be >= min batch size" 37 | ): 38 | libtorchbeast.BatchingQueue( 39 | batch_dim=3, minimum_batch_size=1, maximum_batch_size=0 40 | ) 41 | 42 | def test_multiple_close_calls(self): 43 | queue = libtorchbeast.BatchingQueue() 44 | queue.close() 45 | with self.assertRaisesRegex(RuntimeError, "Queue was closed already"): 46 | queue.close() 47 | 48 | def test_check_inputs(self): 49 | queue = libtorchbeast.BatchingQueue(batch_dim=2) 50 | with self.assertRaisesRegex( 51 | ValueError, "Enqueued tensors must have more than batch_dim ==" 52 | ): 53 | queue.enqueue(torch.ones(5)) 54 | with self.assertRaisesRegex( 55 | ValueError, "Cannot enqueue empty vector of tensors" 56 | ): 57 | queue.enqueue([]) 58 | with self.assertRaisesRegex( 59 | libtorchbeast.ClosedBatchingQueue, "Enqueue to closed queue" 60 | ): 61 | queue.close() 62 | queue.enqueue(torch.ones(1, 1, 1)) 63 | 64 | def test_simple_run(self): 65 | queue = libtorchbeast.BatchingQueue( 66 | batch_dim=0, minimum_batch_size=1, maximum_batch_size=1 67 | ) 68 | 69 | inputs = torch.zeros(1, 2, 3) 70 | queue.enqueue(inputs) 71 | batch = next(queue) 72 | np.testing.assert_array_equal(batch, inputs) 73 | 74 | def test_batched_run(self, batch_size=2): 75 | queue = libtorchbeast.BatchingQueue( 76 | batch_dim=0, minimum_batch_size=batch_size, maximum_batch_size=batch_size 77 | ) 78 | 79 | inputs = [torch.full((1, 2, 3), i) for i in range(batch_size)] 80 | 81 | def enqueue_target(i): 82 | while queue.size() < i: 83 | # Make sure thread i calls enqueue before thread i + 1. 84 | time.sleep(0.05) 85 | queue.enqueue(inputs[i]) 86 | 87 | enqueue_threads = [] 88 | for i in range(batch_size): 89 | enqueue_threads.append( 90 | threading.Thread( 91 | target=enqueue_target, name=f"enqueue-thread-{i}", args=(i,) 92 | ) 93 | ) 94 | 95 | for t in enqueue_threads: 96 | t.start() 97 | 98 | batch = next(queue) 99 | np.testing.assert_array_equal(batch, torch.cat(inputs)) 100 | 101 | for t in enqueue_threads: 102 | t.join() 103 | 104 | 105 | class BatchingQueueProducerConsumerTest(unittest.TestCase): 106 | def test_many_consumers( 107 | self, enqueue_threads_number=16, repeats=100, dequeue_threads_number=64 108 | ): 109 | queue = libtorchbeast.BatchingQueue(batch_dim=0) 110 | 111 | lock = threading.Lock() 112 | total_batches_consumed = 0 113 | 114 | def enqueue_target(i): 115 | for _ in range(repeats): 116 | queue.enqueue(torch.full((1, 2, 3), i)) 117 | 118 | def dequeue_target(): 119 | nonlocal total_batches_consumed 120 | for batch in queue: 121 | batch_size, *_ = batch.shape 122 | with lock: 123 | total_batches_consumed += batch_size 124 | 125 | enqueue_threads = [] 126 | for i in range(enqueue_threads_number): 127 | enqueue_threads.append( 128 | threading.Thread( 129 | target=enqueue_target, name=f"enqueue-thread-{i}", args=(i,) 130 | ) 131 | ) 132 | 133 | dequeue_threads = [] 134 | for i in range(dequeue_threads_number): 135 | dequeue_threads.append( 136 | threading.Thread(target=dequeue_target, name=f"dequeue-thread-{i}") 137 | ) 138 | 139 | for t in enqueue_threads + dequeue_threads: 140 | t.start() 141 | 142 | for t in enqueue_threads: 143 | t.join() 144 | 145 | queue.close() 146 | 147 | for t in dequeue_threads: 148 | t.join() 149 | 150 | self.assertEqual(total_batches_consumed, repeats * enqueue_threads_number) 151 | 152 | 153 | if __name__ == "__main__": 154 | unittest.main() 155 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/contiguous_arrays_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Mock environment for the test contiguous_arrays_test.py.""" 15 | 16 | import numpy as np 17 | import libtorchbeast 18 | 19 | 20 | class Env: 21 | def __init__(self): 22 | self.frame = np.arange(3 * 4 * 5) 23 | self.frame = self.frame.reshape(3, 4, 5) 24 | self.frame = self.frame.transpose(2, 1, 0) 25 | assert not self.frame.flags.c_contiguous 26 | 27 | def reset(self): 28 | return self.frame 29 | 30 | def step(self, action): 31 | return self.frame, 0.0, False, {} 32 | 33 | 34 | if __name__ == "__main__": 35 | server_address = "unix:/tmp/contiguous_arrays_test" 36 | server = libtorchbeast.Server(Env, server_address=server_address) 37 | server.run() 38 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/contiguous_arrays_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Test that non-contiguous arrays are handled properly.""" 15 | 16 | import subprocess 17 | import threading 18 | import unittest 19 | 20 | import numpy as np 21 | 22 | import torch 23 | 24 | import libtorchbeast 25 | 26 | 27 | class ContiguousArraysTest(unittest.TestCase): 28 | def setUp(self): 29 | self.server_proc = subprocess.Popen( 30 | ["python", "tests/contiguous_arrays_env.py"] 31 | ) 32 | 33 | server_address = ["unix:/tmp/contiguous_arrays_test"] 34 | self.learner_queue = libtorchbeast.BatchingQueue( 35 | batch_dim=1, minimum_batch_size=1, maximum_batch_size=10, check_inputs=True 36 | ) 37 | self.inference_batcher = libtorchbeast.DynamicBatcher( 38 | batch_dim=1, 39 | minimum_batch_size=1, 40 | maximum_batch_size=10, 41 | timeout_ms=100, 42 | check_outputs=True, 43 | ) 44 | actor = libtorchbeast.ActorPool( 45 | unroll_length=1, 46 | learner_queue=self.learner_queue, 47 | inference_batcher=self.inference_batcher, 48 | env_server_addresses=server_address, 49 | initial_agent_state=(), 50 | ) 51 | 52 | def run(): 53 | actor.run() 54 | 55 | self.actor_thread = threading.Thread(target=run) 56 | self.actor_thread.start() 57 | 58 | self.target = np.arange(3 * 4 * 5) 59 | self.target = self.target.reshape(3, 4, 5) 60 | self.target = self.target.transpose(2, 1, 0) 61 | 62 | def check_inference_inputs(self): 63 | batch = next(self.inference_batcher) 64 | batched_env_outputs, _ = batch.get_inputs() 65 | frame, *_ = batched_env_outputs 66 | self.assertTrue(np.array_equal(frame.shape, (1, 1, 5, 4, 3))) 67 | frame = frame.reshape(5, 4, 3) 68 | self.assertTrue(np.array_equal(frame, self.target)) 69 | # Set an arbitrary output. 70 | batch.set_outputs(((torch.ones(1, 1),), ())) 71 | 72 | def test_contiguous_arrays(self): 73 | self.check_inference_inputs() 74 | # Stop actor thread. 75 | self.inference_batcher.close() 76 | self.learner_queue.close() 77 | self.actor_thread.join() 78 | 79 | def tearDown(self): 80 | self.server_proc.terminate() 81 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/core_agent_state_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Mock environment for the test core_agent_state_test.py.""" 15 | 16 | import numpy as np 17 | 18 | import libtorchbeast 19 | 20 | 21 | class Env: 22 | def __init__(self): 23 | self.frame = np.zeros((1, 1)) 24 | self.count = 0 25 | self.done_after = 5 26 | 27 | def reset(self): 28 | self.frame = np.zeros((1, 1)) 29 | return self.frame 30 | 31 | def step(self, action): 32 | self.frame += 1 33 | done = self.frame.item() == self.done_after 34 | return self.frame, 0.0, done, {} 35 | 36 | 37 | if __name__ == "__main__": 38 | server_address = "unix:/tmp/core_agent_state_test" 39 | server = libtorchbeast.Server(Env, server_address=server_address) 40 | server.run() 41 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/core_agent_state_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Test that the core state is handled correctly by the batching mechanism.""" 15 | 16 | import unittest 17 | import threading 18 | import subprocess 19 | 20 | import torch 21 | from torch import nn 22 | 23 | import libtorchbeast 24 | 25 | 26 | class Net(nn.Module): 27 | def __init__(self): 28 | super(Net, self).__init__() 29 | 30 | def initial_state(self): 31 | return torch.zeros(1, 1) 32 | 33 | def forward(self, inputs, core_state): 34 | x = inputs["frame"] 35 | notdone = (~inputs["done"]).float() 36 | T, B, *_ = x.shape 37 | 38 | for nd in notdone.unbind(): 39 | nd.view(1, -1) 40 | core_state = nd * core_state 41 | core_state = core_state + 1 42 | # Arbitrarily return action 1. 43 | action = torch.ones((T, B), dtype=torch.int32) 44 | return (action,), core_state 45 | 46 | 47 | class CoreAgentStateTest(unittest.TestCase): 48 | def setUp(self): 49 | self.server_proc = subprocess.Popen(["python", "tests/core_agent_state_env.py"]) 50 | 51 | self.B = 2 52 | self.T = 3 53 | self.model = Net() 54 | server_address = ["unix:/tmp/core_agent_state_test"] 55 | self.learner_queue = libtorchbeast.BatchingQueue( 56 | batch_dim=1, 57 | minimum_batch_size=self.B, 58 | maximum_batch_size=self.B, 59 | check_inputs=True, 60 | ) 61 | self.inference_batcher = libtorchbeast.DynamicBatcher( 62 | batch_dim=1, 63 | minimum_batch_size=1, 64 | maximum_batch_size=1, 65 | timeout_ms=100, 66 | check_outputs=True, 67 | ) 68 | self.actor = libtorchbeast.ActorPool( 69 | unroll_length=self.T, 70 | learner_queue=self.learner_queue, 71 | inference_batcher=self.inference_batcher, 72 | env_server_addresses=server_address, 73 | initial_agent_state=self.model.initial_state(), 74 | ) 75 | 76 | def inference(self): 77 | for batch in self.inference_batcher: 78 | batched_env_outputs, agent_state = batch.get_inputs() 79 | frame, _, done, *_ = batched_env_outputs 80 | # Check that when done is set we reset the environment. 81 | # Since we only have one actor producing experience we will always 82 | # have batch_size == 1, hence we can safely use item(). 83 | if done.item(): 84 | self.assertEqual(frame.item(), 0.0) 85 | outputs = self.model(dict(frame=frame, done=done), agent_state) 86 | batch.set_outputs(outputs) 87 | 88 | def learn(self): 89 | for i, tensors in enumerate(self.learner_queue): 90 | batch, initial_agent_state = tensors 91 | env_outputs, actor_outputs = batch 92 | frame, _, done, *_ = env_outputs 93 | # Make sure the last env_outputs of a rollout equals the first of the 94 | # following one. 95 | # This is guaranteed to be true if there is only one actor filling up 96 | # the learner queue. 97 | self.assertEqual(frame[self.T][0].item(), frame[0][1].item()) 98 | self.assertEqual(done[self.T][0].item(), done[0][1].item()) 99 | 100 | # Make sure the initial state equals the value of the frame at the beginning 101 | # of the rollout. This has to be the case in our test since: 102 | # - every call to forward increments the core state by one. 103 | # - every call to step increments the value in the frame by one (modulo 5). 104 | env_done_after = 5 # Matches self.done_after in core_agent_state_env.py. 105 | self.assertEqual( 106 | frame[0][0].item(), initial_agent_state[0][0].item() % env_done_after 107 | ) 108 | self.assertEqual( 109 | frame[0][1].item(), initial_agent_state[0][1].item() % env_done_after 110 | ) 111 | 112 | if i >= 10: 113 | # Stop execution. 114 | self.learner_queue.close() 115 | self.inference_batcher.close() 116 | 117 | def test_core_agent_state(self): 118 | def run(): 119 | self.actor.run() 120 | 121 | threads = [ 122 | threading.Thread(target=self.inference), 123 | threading.Thread(target=run), 124 | ] 125 | 126 | # Start actor and inference thread. 127 | for thread in threads: 128 | thread.start() 129 | 130 | self.learn() 131 | 132 | for thread in threads: 133 | thread.join() 134 | 135 | def tearDown(self): 136 | self.server_proc.terminate() 137 | self.server_proc.wait() 138 | 139 | 140 | if __name__ == "__main__": 141 | unittest.main() 142 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/dynamic_batcher_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for actorpool.DynamicBatcher.""" 15 | 16 | import threading 17 | import time 18 | import unittest 19 | 20 | import numpy as np 21 | import torch 22 | import libtorchbeast 23 | 24 | 25 | _BROKEN_PROMISE_MESSAGE = ( 26 | "The associated promise has been destructed prior" 27 | " to the associated state becoming ready." 28 | ) 29 | 30 | 31 | class DynamicBatcherTest(unittest.TestCase): 32 | def test_simple_run(self): 33 | batcher = libtorchbeast.DynamicBatcher( 34 | batch_dim=0, minimum_batch_size=1, maximum_batch_size=1 35 | ) 36 | 37 | inputs = torch.zeros(1, 2, 3) 38 | outputs = torch.ones(1, 42, 3) 39 | 40 | def target(): 41 | np.testing.assert_array_equal(batcher.compute(inputs), outputs) 42 | 43 | t = threading.Thread(target=target, name="compute-thread") 44 | t.start() 45 | 46 | batch = next(batcher) 47 | np.testing.assert_array_equal(batch.get_inputs(), inputs) 48 | batch.set_outputs(outputs) 49 | 50 | t.join() 51 | 52 | def test_timeout(self): 53 | timeout_ms = 300 54 | batcher = libtorchbeast.DynamicBatcher( 55 | batch_dim=0, 56 | minimum_batch_size=5, 57 | maximum_batch_size=5, 58 | timeout_ms=timeout_ms, 59 | ) 60 | 61 | inputs = torch.zeros(1, 2, 3) 62 | outputs = torch.ones(1, 42, 3) 63 | 64 | def compute_target(): 65 | batcher.compute(inputs) 66 | 67 | compute_thread = threading.Thread(target=compute_target, name="compute-thread") 68 | compute_thread.start() 69 | 70 | start_waiting_time = time.time() 71 | # Wait until approximately timeout_ms. 72 | batch = next(batcher) 73 | waiting_time_ms = (time.time() - start_waiting_time) * 1000 74 | # Timeout has expired and the batch of size 1 (< minimum_batch_size) 75 | # has been consumed. 76 | batch.set_outputs(outputs) 77 | 78 | compute_thread.join() 79 | 80 | self.assertTrue(timeout_ms <= waiting_time_ms <= timeout_ms + timeout_ms / 10) 81 | 82 | def test_batched_run(self, batch_size=10): 83 | batcher = libtorchbeast.DynamicBatcher( 84 | batch_dim=0, minimum_batch_size=batch_size, maximum_batch_size=batch_size 85 | ) 86 | 87 | inputs = [torch.full((1, 2, 3), i) for i in range(batch_size)] 88 | outputs = torch.ones(batch_size, 42, 3) 89 | 90 | def target(i): 91 | while batcher.size() < i: 92 | # Make sure thread i calls compute before thread i + 1. 93 | time.sleep(0.05) 94 | 95 | np.testing.assert_array_equal( 96 | batcher.compute(inputs[i]), outputs[i : i + 1] 97 | ) 98 | 99 | threads = [] 100 | for i in range(batch_size): 101 | threads.append( 102 | threading.Thread(target=target, name=f"compute-thread-{i}", args=(i,)) 103 | ) 104 | 105 | for t in threads: 106 | t.start() 107 | 108 | batch = next(batcher) 109 | 110 | batched_inputs = batch.get_inputs() 111 | np.testing.assert_array_equal(batched_inputs, torch.cat(inputs)) 112 | batch.set_outputs(outputs) 113 | 114 | for t in threads: 115 | t.join() 116 | 117 | def test_dropped_batch(self): 118 | batcher = libtorchbeast.DynamicBatcher( 119 | batch_dim=0, minimum_batch_size=1, maximum_batch_size=1 120 | ) 121 | 122 | inputs = torch.zeros(1, 2, 3) 123 | 124 | def target(): 125 | with self.assertRaisesRegex( 126 | libtorchbeast.AsyncError, _BROKEN_PROMISE_MESSAGE 127 | ): 128 | batcher.compute(inputs) 129 | 130 | t = threading.Thread(target=target, name="compute-thread") 131 | t.start() 132 | 133 | next(batcher) # Retrieves but doesn't keep the batch object. 134 | t.join() 135 | 136 | def test_check_outputs1(self): 137 | batcher = libtorchbeast.DynamicBatcher( 138 | batch_dim=2, minimum_batch_size=1, maximum_batch_size=1 139 | ) 140 | 141 | inputs = torch.zeros(1, 2, 3) 142 | 143 | def target(): 144 | batcher.compute(inputs) 145 | 146 | t = threading.Thread(target=target, name="compute-thread") 147 | t.start() 148 | 149 | batch = next(batcher) 150 | 151 | with self.assertRaisesRegex(ValueError, "output shape must have at least"): 152 | outputs = torch.ones(1) 153 | batch.set_outputs(outputs) 154 | 155 | # Set correct outputs so the thread can join. 156 | batch.set_outputs(torch.ones(1, 1, 1)) 157 | t.join() 158 | 159 | def test_check_outputs2(self): 160 | batcher = libtorchbeast.DynamicBatcher( 161 | batch_dim=2, minimum_batch_size=1, maximum_batch_size=1 162 | ) 163 | 164 | inputs = torch.zeros(1, 2, 3) 165 | 166 | def target(): 167 | batcher.compute(inputs) 168 | 169 | t = threading.Thread(target=target, name="compute-thread") 170 | t.start() 171 | 172 | batch = next(batcher) 173 | 174 | with self.assertRaisesRegex( 175 | ValueError, 176 | "Output shape must have the same batch dimension as the input batch size.", 177 | ): 178 | # Dimenstion two of the outputs is != from the size of the batch (3 != 1). 179 | batch.set_outputs(torch.ones(1, 42, 3)) 180 | 181 | # Set correct outputs so the thread can join. 182 | batch.set_outputs(torch.ones(1, 1, 1)) 183 | t.join() 184 | 185 | def test_multiple_set_outputs_calls(self): 186 | batcher = libtorchbeast.DynamicBatcher( 187 | batch_dim=0, minimum_batch_size=1, maximum_batch_size=1 188 | ) 189 | 190 | inputs = torch.zeros(1, 2, 3) 191 | outputs = torch.ones(1, 42, 3) 192 | 193 | def target(): 194 | batcher.compute(inputs) 195 | 196 | t = threading.Thread(target=target, name="compute-thread") 197 | t.start() 198 | 199 | batch = next(batcher) 200 | batch.set_outputs(outputs) 201 | with self.assertRaisesRegex(RuntimeError, "set_outputs called twice"): 202 | batch.set_outputs(outputs) 203 | 204 | t.join() 205 | 206 | 207 | class DynamicBatcherProducerConsumerTest(unittest.TestCase): 208 | def test_many_consumers( 209 | self, 210 | minimum_batch_size=1, 211 | compute_thread_number=64, 212 | repeats=100, 213 | consume_thread_number=16, 214 | ): 215 | batcher = libtorchbeast.DynamicBatcher( 216 | batch_dim=0, minimum_batch_size=minimum_batch_size 217 | ) 218 | 219 | lock = threading.Lock() 220 | total_batches_consumed = 0 221 | 222 | def compute_thread_target(i): 223 | for _ in range(repeats): 224 | inputs = torch.full((1, 2, 3), i) 225 | batcher.compute(inputs) 226 | 227 | def consume_thread_target(): 228 | nonlocal total_batches_consumed 229 | for batch in batcher: 230 | inputs = batch.get_inputs() 231 | batch_size, *_ = inputs.shape 232 | batch.set_outputs(torch.ones_like(inputs)) 233 | with lock: 234 | total_batches_consumed += batch_size 235 | 236 | compute_threads = [] 237 | for i in range(compute_thread_number): 238 | compute_threads.append( 239 | threading.Thread( 240 | target=compute_thread_target, name=f"compute-thread-{i}", args=(i,) 241 | ) 242 | ) 243 | 244 | consume_threads = [] 245 | for i in range(consume_thread_number): 246 | consume_threads.append( 247 | threading.Thread( 248 | target=consume_thread_target, name=f"consume-thread-{i}" 249 | ) 250 | ) 251 | 252 | for t in compute_threads + consume_threads: 253 | t.start() 254 | 255 | for t in compute_threads: 256 | t.join() 257 | 258 | # Stop iteration in all consume_threads. 259 | batcher.close() 260 | 261 | for t in consume_threads: 262 | t.join() 263 | 264 | self.assertEqual(total_batches_consumed, compute_thread_number * repeats) 265 | 266 | 267 | if __name__ == "__main__": 268 | unittest.main() 269 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/inference_speed_profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | import sys 18 | import threading 19 | import time 20 | import timeit 21 | 22 | import torch 23 | 24 | sys.path.append("..") 25 | import experiment # noqa: E402 26 | 27 | logging.basicConfig( 28 | format=( 29 | "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" 30 | ), 31 | level=0, 32 | ) 33 | 34 | batch_size = int(sys.argv[1]) if len(sys.argv) > 1 else 4 35 | num_inference_threads = int(sys.argv[2]) if len(sys.argv) > 2 else 2 36 | 37 | 38 | def main(): 39 | filename = "inference_speed_test.json" 40 | with torch.autograd.profiler.profile() as prof: 41 | run() 42 | logging.info("Collecting trace and writing to '%s.gz'", filename) 43 | prof.export_chrome_trace(filename) 44 | os.system("gzip %s" % filename) 45 | 46 | 47 | def run(): 48 | size = (4, 84, 84) 49 | num_actions = 6 50 | 51 | if torch.cuda.is_available(): 52 | device = torch.device("cuda:0") 53 | else: 54 | device = torch.device("cpu") 55 | 56 | model = experiment.Net(observation_size=size, num_actions=num_actions) 57 | model = model.to(device=device) 58 | 59 | should_stop = threading.Event() 60 | 61 | step = 0 62 | 63 | def stream_inference(frame): 64 | nonlocal step 65 | 66 | T, B, *_ = frame.shape 67 | stream = torch.cuda.Stream() 68 | 69 | with torch.no_grad(): 70 | with torch.cuda.stream(stream): 71 | while not should_stop.is_set(): 72 | input = frame.pin_memory() 73 | input = frame.to(device, non_blocking=True) 74 | outputs = model(input) 75 | outputs = [t.cpu() for t in outputs] 76 | stream.synchronize() 77 | step += B 78 | 79 | def inference(frame, lock=threading.Lock()): # noqa: B008 80 | nonlocal step 81 | 82 | T, B, *_ = frame.shape 83 | with torch.no_grad(): 84 | while not should_stop.is_set(): 85 | input = frame.to(device) 86 | with lock: 87 | outputs = model(input) 88 | step += B 89 | outputs = [t.cpu() for t in outputs] 90 | 91 | def direct_inference(frame): 92 | nonlocal step 93 | frame = frame.to(device) 94 | 95 | T, B, *_ = frame.shape 96 | with torch.no_grad(): 97 | while not should_stop.is_set(): 98 | model(frame) 99 | step += B 100 | 101 | frame = 255 * torch.rand((1, batch_size) + size) 102 | 103 | work_threads = [ 104 | threading.Thread(target=stream_inference, args=(frame,)) 105 | for _ in range(num_inference_threads) 106 | ] 107 | for thread in work_threads: 108 | thread.start() 109 | 110 | try: 111 | while step < 10000: 112 | start_time = timeit.default_timer() 113 | start_step = step 114 | time.sleep(3) 115 | end_step = step 116 | 117 | logging.info( 118 | "Step %i @ %.1f SPS.", 119 | end_step, 120 | (end_step - start_step) / (timeit.default_timer() - start_time), 121 | ) 122 | except KeyboardInterrupt: 123 | pass 124 | 125 | should_stop.set() 126 | for thread in work_threads: 127 | thread.join() 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/lint_changed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # This shell script lints only the things that changed in the most recent change. 8 | # It also ignores deleted files, so that black and flake8 don't explode. 9 | 10 | set -e 11 | 12 | CMD="flake8" 13 | CHANGED_FILES="$(git diff --diff-filter=d --name-only master... | grep '\.py$' | grep -v "torchbeast/atari_wrappers.py" | tr '\n' ' ')" 14 | while getopts bi opt; do 15 | case $opt in 16 | b) 17 | CMD="black" 18 | esac 19 | 20 | done 21 | 22 | if [ "$CHANGED_FILES" != "" ] 23 | then 24 | if [[ "$CMD" == "black" ]] 25 | then 26 | command -v black >/dev/null || \ 27 | ( echo "Please install black." && false ) 28 | # Only output if something needs to change. 29 | black --check $CHANGED_FILES 30 | else 31 | flake8 --version | grep '^3\.[6-9]\.' >/dev/null || \ 32 | ( echo "Please install flake8 >=3.6.0." && false ) 33 | 34 | # Soft complaint on too-long-lines. 35 | flake8 --select=E501 --show-source $CHANGED_FILES 36 | # Hard complaint on really long lines. 37 | exec flake8 --max-line-length=127 --show-source $CHANGED_FILES 38 | fi 39 | fi 40 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/polybeast_inference_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for polybeast inference implementation.""" 15 | 16 | import unittest 17 | import warnings 18 | from unittest import mock 19 | 20 | import torch 21 | from torchbeast import polybeast_learner as polybeast 22 | 23 | 24 | class InferenceTest(unittest.TestCase): 25 | def setUp(self): 26 | self.unroll_length = 1 # Inference called for every step. 27 | self.batch_size = 4 # Arbitrary. 28 | self.frame_dimension = 84 # Has to match what expected by the model. 29 | self.num_actions = 6 # Specific to each environment. 30 | self.num_channels = 4 # Has to match with the first conv layer of the net. 31 | self.core_output_size = 256 # Has to match what expected by the model. 32 | self.num_lstm_layers = 1 # As in the model. 33 | 34 | self.frame = torch.ones( 35 | self.unroll_length, 36 | self.batch_size, 37 | self.num_channels, 38 | self.frame_dimension, 39 | self.frame_dimension, 40 | ) 41 | self.rewards = torch.ones(self.unroll_length, self.batch_size) 42 | self.done = torch.zeros(self.unroll_length, self.batch_size, dtype=torch.uint8) 43 | self.episode_return = torch.ones( 44 | self.unroll_length, self.batch_size 45 | ) # Not used in the current implemenation of inference. 46 | self.episode_step = torch.ones( 47 | self.unroll_length, self.batch_size 48 | ) # Not used in the current implemenation of inference. 49 | 50 | self.mock_batch = mock.Mock() 51 | # Set the mock inference batcher to be iterable and return a mock_batch. 52 | self.mock_inference_batcher = mock.MagicMock() 53 | self.mock_inference_batcher.__iter__.return_value = iter([self.mock_batch]) 54 | 55 | def _test_inference(self, use_lstm, device): 56 | model = polybeast.Net(num_actions=self.num_actions, use_lstm=use_lstm) 57 | model.to(device) 58 | agent_state = model.initial_state() 59 | 60 | inputs = ( 61 | ( 62 | self.frame, 63 | self.rewards, 64 | self.done, 65 | self.episode_return, 66 | self.episode_return, 67 | ), 68 | agent_state, 69 | ) 70 | # Set the behaviour of the methods of the mock batch. 71 | self.mock_batch.get_inputs = mock.Mock(return_value=inputs) 72 | self.mock_batch.set_outputs = mock.Mock() 73 | 74 | # Preparing the mock flags. Could do with just a dict but using 75 | # a Mock object for consistency. 76 | mock_flags = mock.Mock() 77 | mock_flags.actor_device = device 78 | mock_flags.use_lstm = use_lstm 79 | 80 | polybeast.inference(mock_flags, self.mock_inference_batcher, model) 81 | 82 | # Assert the batch is used only once. 83 | self.mock_batch.get_inputs.assert_called_once() 84 | self.mock_batch.set_outputs.assert_called_once() 85 | # Check that set_outputs has been called with paramaters with the expected shape. 86 | batch_args, batch_kwargs = self.mock_batch.set_outputs.call_args 87 | self.assertEqual(batch_kwargs, {}) 88 | model_outputs, *other_args = batch_args 89 | self.assertEqual(other_args, []) 90 | 91 | (action, policy_logits, baseline), core_state = model_outputs 92 | self.assertSequenceEqual(action.shape, (self.unroll_length, self.batch_size)) 93 | self.assertSequenceEqual( 94 | policy_logits.shape, (self.unroll_length, self.batch_size, self.num_actions) 95 | ) 96 | self.assertSequenceEqual(baseline.shape, (self.unroll_length, self.batch_size)) 97 | 98 | for tensor in (action, policy_logits, baseline) + core_state: 99 | self.assertEqual(tensor.device, torch.device("cpu")) 100 | 101 | self.assertEqual(len(core_state), 2 if use_lstm else 0) 102 | for core_state_element in core_state: 103 | self.assertSequenceEqual( 104 | core_state_element.shape, 105 | (self.num_lstm_layers, self.batch_size, self.core_output_size), 106 | ) 107 | 108 | def test_inference_cpu_no_lstm(self): 109 | self._test_inference(use_lstm=False, device=torch.device("cpu")) 110 | 111 | def test_inference_cuda_no_lstm(self): 112 | if not torch.cuda.is_available(): 113 | warnings.warn("Not testing cuda as it's not available") 114 | return 115 | self._test_inference(use_lstm=False, device=torch.device("cuda")) 116 | 117 | def test_inference_cpu_with_lstm(self): 118 | self._test_inference(use_lstm=True, device=torch.device("cpu")) 119 | 120 | def test_inference_cuda_with_lstm(self): 121 | if not torch.cuda.is_available(): 122 | warnings.warn("Not testing cuda as it's not available") 123 | return 124 | self._test_inference(use_lstm=True, device=torch.device("cuda")) 125 | 126 | 127 | if __name__ == "__main__": 128 | unittest.main() 129 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/polybeast_learn_function_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for polybeast learn function implementation.""" 15 | 16 | import copy 17 | import unittest 18 | from unittest import mock 19 | 20 | import numpy as np 21 | import torch 22 | from torchbeast import polybeast_learner as polybeast 23 | 24 | 25 | def _state_dict_to_numpy(state_dict): 26 | return {key: value.numpy() for key, value in state_dict.items()} 27 | 28 | 29 | class LearnTest(unittest.TestCase): 30 | def setUp(self): 31 | unroll_length = 2 # Arbitrary. 32 | batch_size = 4 # Arbitrary. 33 | frame_dimension = 84 # Has to match what expected by the model. 34 | num_actions = 6 # Specific to each environment. 35 | num_channels = 4 # Has to match with the first conv layer of the net. 36 | 37 | # The following hyperparamaters are arbitrary. 38 | self.lr = 0.1 39 | total_steps = 100000 40 | 41 | # Set the random seed manually to get reproducible results. 42 | torch.manual_seed(0) 43 | 44 | self.model = polybeast.Net(num_actions=num_actions, use_lstm=False) 45 | self.actor_model = polybeast.Net(num_actions=num_actions, use_lstm=False) 46 | self.initial_model_dict = copy.deepcopy(self.model.state_dict()) 47 | self.initial_actor_model_dict = copy.deepcopy(self.actor_model.state_dict()) 48 | 49 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr) 50 | 51 | scheduler = torch.optim.lr_scheduler.StepLR( 52 | optimizer, step_size=total_steps // 10 53 | ) 54 | 55 | self.stats = {} 56 | 57 | # The call to plogger.log will not perform any action. 58 | plogger = mock.Mock() 59 | plogger.log = mock.Mock() 60 | 61 | # Mock flags. 62 | mock_flags = mock.Mock() 63 | mock_flags.learner_device = torch.device("cpu") 64 | mock_flags.reward_clipping = "abs_one" # Default value from cmd. 65 | mock_flags.discounting = 0.99 # Default value from cmd. 66 | mock_flags.baseline_cost = 0.5 # Default value from cmd. 67 | mock_flags.entropy_cost = 0.0006 # Default value from cmd. 68 | mock_flags.unroll_length = unroll_length 69 | mock_flags.batch_size = batch_size 70 | mock_flags.grad_norm_clipping = 40 71 | 72 | # Prepare content for mock_learner_queue. 73 | frame = torch.ones( 74 | unroll_length, batch_size, num_channels, frame_dimension, frame_dimension 75 | ) 76 | rewards = torch.ones(unroll_length, batch_size) 77 | done = torch.zeros(unroll_length, batch_size, dtype=torch.uint8) 78 | episode_step = torch.ones(unroll_length, batch_size) 79 | episode_return = torch.ones(unroll_length, batch_size) 80 | 81 | env_outputs = (frame, rewards, done, episode_step, episode_return) 82 | actor_outputs = ( 83 | # Actions taken. 84 | torch.randint(low=0, high=num_actions, size=(unroll_length, batch_size)), 85 | # Logits. 86 | torch.randn(unroll_length, batch_size, num_actions), 87 | # Baseline. 88 | torch.rand(unroll_length, batch_size), 89 | ) 90 | initial_agent_state = () # No lstm. 91 | tensors = ((env_outputs, actor_outputs), initial_agent_state) 92 | 93 | # Mock learner_queue. 94 | mock_learner_queue = mock.MagicMock() 95 | mock_learner_queue.__iter__.return_value = iter([tensors]) 96 | 97 | self.learn_args = ( 98 | mock_flags, 99 | mock_learner_queue, 100 | self.model, 101 | self.actor_model, 102 | optimizer, 103 | scheduler, 104 | self.stats, 105 | plogger, 106 | ) 107 | 108 | def test_parameters_copied_to_actor_model(self): 109 | """Check that the learner model copies the parameters to the actor model.""" 110 | # Reset models. 111 | self.model.load_state_dict(self.initial_model_dict) 112 | self.actor_model.load_state_dict(self.initial_actor_model_dict) 113 | 114 | polybeast.learn(*self.learn_args) 115 | 116 | np.testing.assert_equal( 117 | _state_dict_to_numpy(self.actor_model.state_dict()), 118 | _state_dict_to_numpy(self.model.state_dict()), 119 | ) 120 | 121 | def test_weights_update(self): 122 | """Check that trainable parameters get updated after one iteration.""" 123 | # Reset models. 124 | self.model.load_state_dict(self.initial_model_dict) 125 | self.actor_model.load_state_dict(self.initial_actor_model_dict) 126 | 127 | polybeast.learn(*self.learn_args) 128 | 129 | model_state_dict = self.model.state_dict(keep_vars=True) 130 | actor_model_state_dict = self.actor_model.state_dict(keep_vars=True) 131 | for key, initial_tensor in self.initial_model_dict.items(): 132 | model_tensor = model_state_dict[key] 133 | actor_model_tensor = actor_model_state_dict[key] 134 | # Assert that the gradient is not zero for the learner. 135 | self.assertGreater(torch.norm(model_tensor.grad), 0.0) 136 | # Assert actor has no gradient. 137 | # Note that even though actor model tensors have no gradient, 138 | # they have requires_grad == True. No gradients are ever calculated 139 | # for these tensors because the inference function in polybeast.py 140 | # (that performs forward passes with the actor_model) uses torch.no_grad 141 | # context manager. 142 | self.assertIsNone(actor_model_tensor.grad) 143 | # Assert that the weights are updated in the expected way. 144 | # We manually perform a gradient descent step, 145 | # and check that they are the same as the calculated ones 146 | # (ignoring floating point errors). 147 | expected_tensor = ( 148 | initial_tensor.detach().numpy() - self.lr * model_tensor.grad.numpy() 149 | ) 150 | np.testing.assert_almost_equal( 151 | model_tensor.detach().numpy(), expected_tensor 152 | ) 153 | np.testing.assert_almost_equal( 154 | actor_model_tensor.detach().numpy(), expected_tensor 155 | ) 156 | 157 | def test_gradients_update(self): 158 | """Check that gradients get updated after one iteration.""" 159 | # Reset models. 160 | self.model.load_state_dict(self.initial_model_dict) 161 | self.actor_model.load_state_dict(self.initial_actor_model_dict) 162 | 163 | # There should be no calculated gradient yet. 164 | for p in self.model.parameters(): 165 | self.assertIsNone(p.grad) 166 | for p in self.actor_model.parameters(): 167 | self.assertIsNone(p.grad) 168 | 169 | polybeast.learn(*self.learn_args) 170 | 171 | # Check that every parameter for the learner model has a gradient, and that 172 | # there is at least some non-zero gradient for each set of paramaters. 173 | for p in self.model.parameters(): 174 | self.assertIsNotNone(p.grad) 175 | self.assertFalse(torch.equal(p.grad, torch.zeros_like(p.grad))) 176 | 177 | # Check that the actor model has no gradients associated with it. 178 | for p in self.actor_model.parameters(): 179 | self.assertIsNone(p.grad) 180 | 181 | def test_non_zero_loss(self): 182 | """Check that the loss is not zero after one iteration.""" 183 | # Reset models. 184 | self.model.load_state_dict(self.initial_model_dict) 185 | self.actor_model.load_state_dict(self.initial_actor_model_dict) 186 | 187 | polybeast.learn(*self.learn_args) 188 | 189 | self.assertNotEqual(self.stats["total_loss"], 0.0) 190 | self.assertNotEqual(self.stats["pg_loss"], 0.0) 191 | self.assertNotEqual(self.stats["baseline_loss"], 0.0) 192 | self.assertNotEqual(self.stats["entropy_loss"], 0.0) 193 | 194 | 195 | if __name__ == "__main__": 196 | unittest.main() 197 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/polybeast_loss_functions_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for polybeast loss functions implementation.""" 15 | 16 | import unittest 17 | 18 | import numpy as np 19 | import torch 20 | from torch.nn import functional as F 21 | from torchbeast import polybeast_learner as polybeast 22 | 23 | 24 | def _softmax(logits): 25 | """Applies softmax non-linearity on inputs.""" 26 | return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) 27 | 28 | 29 | def _softmax_grad(logits): 30 | """Compute the gradient of softmax function.""" 31 | s = np.expand_dims(_softmax(logits), 0) 32 | return s.T * (np.eye(s.size) - s) 33 | 34 | 35 | def assert_allclose(actual, desired): 36 | return np.testing.assert_allclose(actual, desired, rtol=1e-06, atol=1e-05) 37 | 38 | 39 | class ComputeBaselineLossTest(unittest.TestCase): 40 | def setUp(self): 41 | # Floating point constants are randomly generated. 42 | self.advantages = np.array([1.4, 3.43, 5.2, 0.33]) 43 | 44 | def test_compute_baseline_loss(self): 45 | ground_truth_value = 0.5 * np.sum(self.advantages ** 2) 46 | assert_allclose( 47 | ground_truth_value, 48 | polybeast.compute_baseline_loss(torch.from_numpy(self.advantages)), 49 | ) 50 | 51 | def test_compute_baseline_loss_grad(self): 52 | advantages_tensor = torch.from_numpy(self.advantages) 53 | advantages_tensor.requires_grad_() 54 | calculated_value = polybeast.compute_baseline_loss(advantages_tensor) 55 | calculated_value.backward() 56 | 57 | # Manually computed gradients: 58 | # 0.5 * d(xˆ2)/dx == x 59 | # hence the expected gradient is the same as self.advantages. 60 | assert_allclose(advantages_tensor.grad, self.advantages) 61 | 62 | 63 | class ComputeEntropyLossTest(unittest.TestCase): 64 | def setUp(self): 65 | # Floating point constants are randomly generated. 66 | self.logits = np.array([0.0012, 0.321, 0.523, 0.109, 0.416]) 67 | 68 | def test_compute_entropy_loss(self): 69 | # Calculate entropy with: 70 | # H(s) = - sum(prob(x) * ln(prob(x)) for each x in s) 71 | softmax_logits = _softmax(self.logits) 72 | ground_truth_value = np.sum(softmax_logits * np.log(softmax_logits)) 73 | calculated_value = polybeast.compute_entropy_loss(torch.from_numpy(self.logits)) 74 | 75 | assert_allclose(ground_truth_value, calculated_value) 76 | 77 | def test_compute_entropy_loss_grad(self): 78 | logits_tensor = torch.from_numpy(self.logits) 79 | logits_tensor.requires_grad_() 80 | calculated_value = polybeast.compute_entropy_loss(logits_tensor) 81 | calculated_value.backward() 82 | 83 | expected_grad = np.matmul( 84 | np.ones_like(self.logits), 85 | np.matmul( 86 | np.diag(1 + np.log(_softmax(self.logits))), _softmax_grad(self.logits) 87 | ), 88 | ) 89 | 90 | assert_allclose(logits_tensor.grad, expected_grad) 91 | 92 | 93 | class ComputePolicyGradientLossTest(unittest.TestCase): 94 | def setUp(self): 95 | # Floating point constants are randomly generated. 96 | self.logits = np.array( 97 | [ 98 | [ 99 | [0.206, 0.738, 0.125, 0.484, 0.332], 100 | [0.168, 0.504, 0.523, 0.496, 0.626], 101 | [0.236, 0.186, 0.627, 0.441, 0.533], 102 | ], 103 | [ 104 | [0.015, 0.904, 0.583, 0.651, 0.855], 105 | [0.811, 0.292, 0.061, 0.597, 0.590], 106 | [0.999, 0.504, 0.464, 0.077, 0.143], 107 | ], 108 | ] 109 | ) 110 | self.actions = np.array([[3, 0, 1], [4, 2, 2]]) 111 | self.advantages = np.array([[1.4, 0.31, 0.75], [2.1, 1.5, 0.03]]) 112 | 113 | def test_compute_policy_gradient_loss(self): 114 | T, B, N = self.logits.shape 115 | 116 | # Calculate the the cross entropy loss, with the formula: 117 | # loss = -sum_over_j(y_j * log(p_j)) 118 | # Where: 119 | # - `y_j` is whether the action corrisponding to index j has been taken or not, 120 | # (hence y is a one-hot-array of size == number of actions). 121 | # - `p_j` is the value of the sofmax logit corresponding to the jth action. 122 | # In our implementation, we also multiply for the advantages. 123 | labels = F.one_hot(torch.from_numpy(self.actions), num_classes=N).numpy() 124 | cross_entropy_loss = -labels * np.log(_softmax(self.logits)) 125 | ground_truth_value = np.sum( 126 | cross_entropy_loss * self.advantages.reshape(T, B, 1) 127 | ) 128 | 129 | calculated_value = polybeast.compute_policy_gradient_loss( 130 | torch.from_numpy(self.logits), 131 | torch.from_numpy(self.actions), 132 | torch.from_numpy(self.advantages), 133 | ) 134 | assert_allclose(ground_truth_value, calculated_value.item()) 135 | 136 | def test_compute_policy_gradient_loss_grad(self): 137 | T, B, N = self.logits.shape 138 | 139 | logits_tensor = torch.from_numpy(self.logits) 140 | logits_tensor.requires_grad_() 141 | 142 | calculated_value = polybeast.compute_policy_gradient_loss( 143 | logits_tensor, 144 | torch.from_numpy(self.actions), 145 | torch.from_numpy(self.advantages), 146 | ) 147 | 148 | self.assertSequenceEqual(calculated_value.shape, []) 149 | calculated_value.backward() 150 | 151 | # The gradient of the cross entropy loss function for the jth logit 152 | # can be expressed as: 153 | # p_j - y_j 154 | # where: 155 | # - `p_j` is the value of the softmax logit corresponding to the jth action. 156 | # - `y_j` is whether the action corrisponding to index j has been taken, 157 | # (hence y is a one-hot-array of size == number of actions). 158 | # In our implementation, we also multiply for the advantages. 159 | softmax = _softmax(self.logits) 160 | labels = F.one_hot(torch.from_numpy(self.actions), num_classes=N).numpy() 161 | expected_grad = (softmax - labels) * self.advantages.reshape(T, B, 1) 162 | 163 | assert_allclose(logits_tensor.grad, expected_grad) 164 | 165 | def test_compute_policy_gradient_loss_grad_flow(self): 166 | logits_tensor = torch.from_numpy(self.logits) 167 | logits_tensor.requires_grad_() 168 | advantages_tensor = torch.from_numpy(self.advantages) 169 | advantages_tensor.requires_grad_() 170 | 171 | loss = polybeast.compute_policy_gradient_loss( 172 | logits_tensor, torch.from_numpy(self.actions), advantages_tensor 173 | ) 174 | loss.backward() 175 | 176 | self.assertIsNotNone(logits_tensor.grad) 177 | self.assertIsNone(advantages_tensor.grad) 178 | 179 | 180 | if __name__ == "__main__": 181 | unittest.main() 182 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/polybeast_net_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for polybeast Net class implementation.""" 15 | 16 | import unittest 17 | 18 | import torch 19 | from torchbeast import polybeast_learner as polybeast 20 | 21 | 22 | class NetTest(unittest.TestCase): 23 | def setUp(self): 24 | self.unroll_length = 4 # Arbitrary. 25 | self.batch_size = 4 # Arbitrary. 26 | self.frame_dimension = 84 # Has to match what expected by the model. 27 | self.num_actions = 6 # Specific to each environment. 28 | self.num_channels = 4 # Has to match with the first conv layer of the net. 29 | self.core_output_size = 256 # Has to match what expected by the model. 30 | self.num_lstm_layers = 1 # As in the model. 31 | 32 | self.inputs = dict( 33 | frame=torch.ones( 34 | self.unroll_length, 35 | self.batch_size, 36 | self.num_channels, 37 | self.frame_dimension, 38 | self.frame_dimension, 39 | ), 40 | reward=torch.ones(self.batch_size, self.unroll_length), 41 | done=torch.zeros(self.batch_size, self.unroll_length, dtype=torch.uint8), 42 | ) 43 | 44 | def test_forward_return_signature_no_lstm(self): 45 | model = polybeast.Net(num_actions=self.num_actions, use_lstm=False) 46 | core_state = () 47 | 48 | (action, policy_logits, baseline), core_state = model(self.inputs, core_state) 49 | self.assertSequenceEqual(action.shape, (self.batch_size, self.unroll_length)) 50 | self.assertSequenceEqual( 51 | policy_logits.shape, (self.batch_size, self.unroll_length, self.num_actions) 52 | ) 53 | self.assertSequenceEqual(baseline.shape, (self.batch_size, self.unroll_length)) 54 | self.assertSequenceEqual(core_state, ()) 55 | 56 | def test_forward_return_signature_with_lstm(self): 57 | model = polybeast.Net(num_actions=self.num_actions, use_lstm=True) 58 | core_state = model.initial_state(self.batch_size) 59 | 60 | (action, policy_logits, baseline), core_state = model(self.inputs, core_state) 61 | self.assertSequenceEqual(action.shape, (self.batch_size, self.unroll_length)) 62 | self.assertSequenceEqual( 63 | policy_logits.shape, (self.batch_size, self.unroll_length, self.num_actions) 64 | ) 65 | self.assertSequenceEqual(baseline.shape, (self.batch_size, self.unroll_length)) 66 | self.assertEqual(len(core_state), 2) 67 | for core_state_element in core_state: 68 | self.assertSequenceEqual( 69 | core_state_element.shape, 70 | (self.num_lstm_layers, self.batch_size, self.core_output_size), 71 | ) 72 | 73 | def test_initial_state(self): 74 | model_no_lstm = polybeast.Net(num_actions=self.num_actions, use_lstm=False) 75 | initial_state_no_lstm = model_no_lstm.initial_state(self.batch_size) 76 | self.assertSequenceEqual(initial_state_no_lstm, ()) 77 | 78 | model_with_lstm = polybeast.Net(num_actions=self.num_actions, use_lstm=True) 79 | initial_state_with_lstm = model_with_lstm.initial_state(self.batch_size) 80 | self.assertEqual(len(initial_state_with_lstm), 2) 81 | for core_state_element in initial_state_with_lstm: 82 | self.assertSequenceEqual( 83 | core_state_element.shape, 84 | (self.num_lstm_layers, self.batch_size, self.core_output_size), 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | unittest.main() 90 | -------------------------------------------------------------------------------- /reinforcement_learning/tests/vtrace_test.py: -------------------------------------------------------------------------------- 1 | # This file taken from 2 | # https://github.com/deepmind/scalable_agent/blob/ 3 | # d24bd74bd53d454b7222b7f0bea57a358e4ca33e/vtrace_test.py 4 | # and modified. 5 | 6 | # Copyright 2018 Google LLC 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # https://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | """Tests for V-trace. 21 | 22 | For details and theory see: 23 | 24 | "IMPALA: Scalable Distributed Deep-RL with 25 | Importance Weighted Actor-Learner Architectures" 26 | by Espeholt, Soyer, Munos et al. 27 | """ 28 | 29 | import unittest 30 | 31 | import numpy as np 32 | import torch 33 | from torchbeast.core import vtrace 34 | 35 | 36 | def _shaped_arange(*shape): 37 | """Runs np.arange, converts to float and reshapes.""" 38 | return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) 39 | 40 | 41 | def _softmax(logits): 42 | """Applies softmax non-linearity on inputs.""" 43 | return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) 44 | 45 | 46 | def _ground_truth_calculation( 47 | discounts, 48 | log_rhos, 49 | rewards, 50 | values, 51 | bootstrap_value, 52 | clip_rho_threshold, 53 | clip_pg_rho_threshold, 54 | ): 55 | """Calculates the ground truth for V-trace in Python/Numpy.""" 56 | vs = [] 57 | seq_len = len(discounts) 58 | rhos = np.exp(log_rhos) 59 | cs = np.minimum(rhos, 1.0) 60 | clipped_rhos = rhos 61 | if clip_rho_threshold: 62 | clipped_rhos = np.minimum(rhos, clip_rho_threshold) 63 | clipped_pg_rhos = rhos 64 | if clip_pg_rho_threshold: 65 | clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold) 66 | 67 | # This is a very inefficient way to calculate the V-trace ground truth. 68 | # We calculate it this way because it is close to the mathematical notation 69 | # of V-trace. 70 | # v_s = V(x_s) 71 | # + \sum^{T-1}_{t=s} \gamma^{t-s} 72 | # * \prod_{i=s}^{t-1} c_i 73 | # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t)) 74 | # Note that when we take the product over c_i, we write `s:t` as the 75 | # notation of the paper is inclusive of the `t-1`, but Python is exclusive. 76 | # Also note that np.prod([]) == 1. 77 | values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0) 78 | for s in range(seq_len): 79 | v_s = np.copy(values[s]) # Very important copy. 80 | for t in range(s, seq_len): 81 | v_s += ( 82 | np.prod(discounts[s:t], axis=0) 83 | * np.prod(cs[s:t], axis=0) 84 | * clipped_rhos[t] 85 | * (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - values[t]) 86 | ) 87 | vs.append(v_s) 88 | vs = np.stack(vs, axis=0) 89 | pg_advantages = clipped_pg_rhos * ( 90 | rewards 91 | + discounts * np.concatenate([vs[1:], bootstrap_value[None, :]], axis=0) 92 | - values 93 | ) 94 | 95 | return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages) 96 | 97 | 98 | def assert_allclose(actual, desired): 99 | return np.testing.assert_allclose(actual, desired, rtol=1e-06, atol=1e-05) 100 | 101 | 102 | class ActionLogProbsTest(unittest.TestCase): 103 | def test_action_log_probs(self, batch_size=2): 104 | seq_len = 7 105 | num_actions = 3 106 | 107 | policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10 108 | actions = np.random.randint( 109 | 0, num_actions, size=(seq_len, batch_size), dtype=np.int64 110 | ) 111 | 112 | action_log_probs_tensor = vtrace.action_log_probs( 113 | torch.from_numpy(policy_logits), torch.from_numpy(actions) 114 | ) 115 | 116 | # Ground Truth 117 | # Using broadcasting to create a mask that indexes action logits 118 | action_index_mask = actions[..., None] == np.arange(num_actions) 119 | 120 | def index_with_mask(array, mask): 121 | return array[mask].reshape(*array.shape[:-1]) 122 | 123 | # Note: Normally log(softmax) is not a good idea because it's not 124 | # numerically stable. However, in this test we have well-behaved values. 125 | ground_truth_v = index_with_mask( 126 | np.log(_softmax(policy_logits)), action_index_mask 127 | ) 128 | 129 | assert_allclose(ground_truth_v, action_log_probs_tensor) 130 | 131 | def test_action_log_probs_batch_1(self): 132 | self.test_action_log_probs(1) 133 | 134 | 135 | class VtraceTest(unittest.TestCase): 136 | def test_vtrace(self, batch_size=5): 137 | """Tests V-trace against ground truth data calculated in python.""" 138 | seq_len = 5 139 | 140 | # Create log_rhos such that rho will span from near-zero to above the 141 | # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5), 142 | # so that rho is in approx [0.08, 12.2). 143 | log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len) 144 | log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5). 145 | values = { 146 | "log_rhos": log_rhos, 147 | # T, B where B_i: [0.9 / (i+1)] * T 148 | "discounts": np.array( 149 | [[0.9 / (b + 1) for b in range(batch_size)] for _ in range(seq_len)], 150 | dtype=np.float32, 151 | ), 152 | "rewards": _shaped_arange(seq_len, batch_size), 153 | "values": _shaped_arange(seq_len, batch_size) / batch_size, 154 | "bootstrap_value": _shaped_arange(batch_size) + 1.0, 155 | "clip_rho_threshold": 3.7, 156 | "clip_pg_rho_threshold": 2.2, 157 | } 158 | 159 | ground_truth = _ground_truth_calculation(**values) 160 | 161 | values = {key: torch.tensor(value) for key, value in values.items()} 162 | output = vtrace.from_importance_weights(**values) 163 | 164 | for a, b in zip(ground_truth, output): 165 | assert_allclose(a, b) 166 | 167 | def test_vtrace_batch_1(self): 168 | self.test_vtrace(1) 169 | 170 | def test_vtrace_from_logits(self, batch_size=2): 171 | """Tests V-trace calculated from logits.""" 172 | seq_len = 5 173 | num_actions = 3 174 | clip_rho_threshold = None # No clipping. 175 | clip_pg_rho_threshold = None # No clipping. 176 | 177 | values = { 178 | "behavior_policy_logits": _shaped_arange(seq_len, batch_size, num_actions), 179 | "target_policy_logits": _shaped_arange(seq_len, batch_size, num_actions), 180 | "actions": np.random.randint( 181 | 0, num_actions - 1, size=(seq_len, batch_size) 182 | ), 183 | "discounts": np.array( # T, B where B_i: [0.9 / (i+1)] * T 184 | [[0.9 / (b + 1) for b in range(batch_size)] for _ in range(seq_len)], 185 | dtype=np.float32, 186 | ), 187 | "rewards": _shaped_arange(seq_len, batch_size), 188 | "values": _shaped_arange(seq_len, batch_size) / batch_size, 189 | "bootstrap_value": _shaped_arange(batch_size) + 1.0, # B 190 | } 191 | values = {k: torch.from_numpy(v) for k, v in values.items()} 192 | 193 | from_logits_output = vtrace.from_logits( 194 | clip_rho_threshold=clip_rho_threshold, 195 | clip_pg_rho_threshold=clip_pg_rho_threshold, 196 | **values, 197 | ) 198 | 199 | target_log_probs = vtrace.action_log_probs( 200 | values["target_policy_logits"], values["actions"] 201 | ) 202 | behavior_log_probs = vtrace.action_log_probs( 203 | values["behavior_policy_logits"], values["actions"] 204 | ) 205 | log_rhos = target_log_probs - behavior_log_probs 206 | 207 | # Calculate V-trace using the ground truth logits. 208 | from_iw = vtrace.from_importance_weights( 209 | log_rhos=log_rhos, 210 | discounts=values["discounts"], 211 | rewards=values["rewards"], 212 | values=values["values"], 213 | bootstrap_value=values["bootstrap_value"], 214 | clip_rho_threshold=clip_rho_threshold, 215 | clip_pg_rho_threshold=clip_pg_rho_threshold, 216 | ) 217 | 218 | assert_allclose(from_iw.vs, from_logits_output.vs) 219 | assert_allclose(from_iw.pg_advantages, from_logits_output.pg_advantages) 220 | assert_allclose( 221 | behavior_log_probs, from_logits_output.behavior_action_log_probs 222 | ) 223 | assert_allclose(target_log_probs, from_logits_output.target_action_log_probs) 224 | assert_allclose(log_rhos, from_logits_output.log_rhos) 225 | 226 | def test_vtrace_from_logits_batch_1(self): 227 | self.test_vtrace_from_logits(1) 228 | 229 | def test_higher_rank_inputs_for_importance_weights(self): 230 | """Checks support for additional dimensions in inputs.""" 231 | T = 3 # pylint: disable=invalid-name 232 | B = 2 # pylint: disable=invalid-name 233 | values = { 234 | "log_rhos": torch.zeros(T, B, 1), 235 | "discounts": torch.zeros(T, B, 1), 236 | "rewards": torch.zeros(T, B, 42), 237 | "values": torch.zeros(T, B, 42), 238 | "bootstrap_value": torch.zeros(B, 42), 239 | } 240 | output = vtrace.from_importance_weights(**values) 241 | self.assertSequenceEqual(output.vs.shape, (T, B, 42)) 242 | 243 | def test_inconsistent_rank_inputs_for_importance_weights(self): 244 | """Test one of many possible errors in shape of inputs.""" 245 | T = 3 # pylint: disable=invalid-name 246 | B = 2 # pylint: disable=invalid-name 247 | 248 | values = { 249 | "log_rhos": torch.zeros(T, B, 1), 250 | "discounts": torch.zeros(T, B, 1), 251 | "rewards": torch.zeros(T, B, 42), 252 | "values": torch.zeros(T, B, 42), 253 | # Should be [B, 42]. 254 | "bootstrap_value": torch.zeros(B), 255 | } 256 | 257 | with self.assertRaisesRegex( 258 | RuntimeError, "same number of dimensions: got 3 and 2" 259 | ): 260 | vtrace.from_importance_weights(**values) 261 | 262 | 263 | if __name__ == "__main__": 264 | unittest.main() 265 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast/core/environment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The environment class for MonoBeast.""" 15 | 16 | import torch 17 | 18 | 19 | def _format_frame(frame): 20 | frame = torch.from_numpy(frame) 21 | return frame.view((1, 1) + frame.shape) # (...) -> (T,B,...). 22 | 23 | 24 | class Environment: 25 | def __init__(self, gym_env): 26 | self.gym_env = gym_env 27 | self.episode_return = None 28 | self.episode_step = None 29 | 30 | def initial(self): 31 | initial_reward = torch.zeros(1, 1) 32 | # This supports only single-tensor actions ATM. 33 | initial_last_action = torch.zeros(1, 1, dtype=torch.int64) 34 | self.episode_return = torch.zeros(1, 1) 35 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 36 | initial_done = torch.ones(1, 1, dtype=torch.uint8) 37 | initial_frame = _format_frame(self.gym_env.reset()) 38 | return dict( 39 | frame=initial_frame, 40 | reward=initial_reward, 41 | done=initial_done, 42 | episode_return=self.episode_return, 43 | episode_step=self.episode_step, 44 | last_action=initial_last_action, 45 | ) 46 | 47 | def step(self, action): 48 | frame, reward, done, unused_info = self.gym_env.step(action.item()) 49 | self.episode_step += 1 50 | self.episode_return += reward 51 | episode_step = self.episode_step 52 | episode_return = self.episode_return 53 | if done: 54 | frame = self.gym_env.reset() 55 | self.episode_return = torch.zeros(1, 1) 56 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 57 | 58 | frame = _format_frame(frame) 59 | reward = torch.tensor(reward).view(1, 1) 60 | done = torch.tensor(done).view(1, 1) 61 | 62 | return dict( 63 | frame=frame, 64 | reward=reward, 65 | done=done, 66 | episode_return=episode_return, 67 | episode_step=episode_step, 68 | last_action=action, 69 | ) 70 | 71 | def close(self): 72 | self.gym_env.close() 73 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast/core/file_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import copy 17 | import csv 18 | import datetime 19 | import json 20 | import logging 21 | import os 22 | import time 23 | from typing import Dict 24 | 25 | 26 | 27 | def gather_metadata() -> Dict: 28 | date_start = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") 29 | # Gathering git metadata. 30 | try: 31 | import git 32 | 33 | try: 34 | repo = git.Repo(search_parent_directories=True) 35 | git_sha = repo.commit().hexsha 36 | git_data = dict( 37 | commit=git_sha, 38 | branch=None if repo.head.is_detached else repo.active_branch.name, 39 | is_dirty=repo.is_dirty(), 40 | path=repo.git_dir, 41 | ) 42 | except git.InvalidGitRepositoryError: 43 | git_data = None 44 | except ImportError: 45 | git_data = None 46 | # Gathering slurm metadata. 47 | if "SLURM_JOB_ID" in os.environ: 48 | slurm_env_keys = [k for k in os.environ if k.startswith("SLURM")] 49 | slurm_data = {} 50 | for k in slurm_env_keys: 51 | d_key = k.replace("SLURM_", "").replace("SLURMD_", "").lower() 52 | slurm_data[d_key] = os.environ[k] 53 | else: 54 | slurm_data = None 55 | return dict( 56 | date_start=date_start, 57 | date_end=None, 58 | successful=False, 59 | git=git_data, 60 | slurm=slurm_data, 61 | env=os.environ.copy(), 62 | ) 63 | 64 | 65 | class FileWriter: 66 | def __init__( 67 | self, 68 | xpid: str = None, 69 | xp_args: dict = None, 70 | rootdir: str = "~/logs", 71 | symlink_to_latest: bool = True, 72 | ): 73 | if not xpid: 74 | # Make unique id. 75 | xpid = "{proc}_{unixtime}".format( 76 | proc=os.getpid(), unixtime=int(time.time()) 77 | ) 78 | self.xpid = xpid 79 | self._tick = 0 80 | 81 | # Metadata gathering. 82 | if xp_args is None: 83 | xp_args = {} 84 | self.metadata = gather_metadata() 85 | # We need to copy the args, otherwise when we close the file writer 86 | # (and rewrite the args) we might have non-serializable objects (or 87 | # other unwanted side-effects). 88 | self.metadata["args"] = copy.deepcopy(xp_args) 89 | self.metadata["xpid"] = self.xpid 90 | 91 | formatter = logging.Formatter("%(message)s") 92 | self._logger = logging.getLogger("logs/out") 93 | 94 | # To stdout handler. 95 | shandle = logging.StreamHandler() 96 | shandle.setFormatter(formatter) 97 | self._logger.addHandler(shandle) 98 | self._logger.setLevel(logging.INFO) 99 | 100 | rootdir = os.path.expandvars(os.path.expanduser(rootdir)) 101 | # To file handler. 102 | self.basepath = os.path.join(rootdir, self.xpid) 103 | if not os.path.exists(self.basepath): 104 | self._logger.info("Creating log directory: %s", self.basepath) 105 | os.makedirs(self.basepath, exist_ok=True) 106 | else: 107 | self._logger.info("Found log directory: %s", self.basepath) 108 | 109 | if symlink_to_latest: 110 | # Add 'latest' as symlink unless it exists and is no symlink. 111 | symlink = os.path.join(rootdir, "latest") 112 | try: 113 | if os.path.islink(symlink): 114 | os.remove(symlink) 115 | if not os.path.exists(symlink): 116 | os.symlink(self.basepath, symlink) 117 | self._logger.info("Symlinked log directory: %s", symlink) 118 | except OSError: 119 | # os.remove() or os.symlink() raced. Don't do anything. 120 | pass 121 | 122 | self.paths = dict( 123 | msg="{base}/out.log".format(base=self.basepath), 124 | logs="{base}/logs.csv".format(base=self.basepath), 125 | fields="{base}/fields.csv".format(base=self.basepath), 126 | meta="{base}/meta.json".format(base=self.basepath), 127 | ) 128 | 129 | self._logger.info("Saving arguments to %s", self.paths["meta"]) 130 | if os.path.exists(self.paths["meta"]): 131 | self._logger.warning( 132 | "Path to meta file already exists. " "Not overriding meta." 133 | ) 134 | else: 135 | self._save_metadata() 136 | 137 | self._logger.info("Saving messages to %s", self.paths["msg"]) 138 | if os.path.exists(self.paths["msg"]): 139 | self._logger.warning( 140 | "Path to message file already exists. " "New data will be appended." 141 | ) 142 | 143 | fhandle = logging.FileHandler(self.paths["msg"]) 144 | fhandle.setFormatter(formatter) 145 | self._logger.addHandler(fhandle) 146 | 147 | self._logger.info("Saving logs data to %s", self.paths["logs"]) 148 | self._logger.info("Saving logs' fields to %s", self.paths["fields"]) 149 | self.fieldnames = ["_tick", "_time"] 150 | if os.path.exists(self.paths["logs"]): 151 | self._logger.warning( 152 | "Path to log file already exists. " "New data will be appended." 153 | ) 154 | # Override default fieldnames. 155 | with open(self.paths["fields"], "r") as csvfile: 156 | reader = csv.reader(csvfile) 157 | lines = list(reader) 158 | if len(lines) > 0: 159 | self.fieldnames = lines[-1] 160 | # Override default tick: use the last tick from the logs file plus 1. 161 | with open(self.paths["logs"], "r") as csvfile: 162 | reader = csv.reader(csvfile) 163 | lines = list(reader) 164 | # Need at least two lines in order to read the last tick: 165 | # the first is the csv header and the second is the first line 166 | # of data. 167 | if len(lines) > 1: 168 | self._tick = int(lines[-1][0]) + 1 169 | 170 | self._fieldfile = open(self.paths["fields"], "a") 171 | self._fieldwriter = csv.writer(self._fieldfile) 172 | self._logfile = open(self.paths["logs"], "a") 173 | self._logwriter = csv.DictWriter(self._logfile, fieldnames=self.fieldnames) 174 | 175 | def log(self, to_log: Dict, tick: int = None, verbose: bool = False) -> None: 176 | if tick is not None: 177 | raise NotImplementedError 178 | else: 179 | to_log["_tick"] = self._tick 180 | self._tick += 1 181 | to_log["_time"] = time.time() 182 | 183 | old_len = len(self.fieldnames) 184 | for k in to_log: 185 | if k not in self.fieldnames: 186 | self.fieldnames.append(k) 187 | if old_len != len(self.fieldnames): 188 | self._fieldwriter.writerow(self.fieldnames) 189 | self._logger.info("Updated log fields: %s", self.fieldnames) 190 | 191 | if to_log["_tick"] == 0: 192 | self._logfile.write("# %s\n" % ",".join(self.fieldnames)) 193 | 194 | if verbose: 195 | self._logger.info( 196 | "LOG | %s", 197 | ", ".join(["{}: {}".format(k, to_log[k]) for k in sorted(to_log)]), 198 | ) 199 | 200 | self._logwriter.writerow(to_log) 201 | self._logfile.flush() 202 | 203 | def close(self, successful: bool = True) -> None: 204 | self.metadata["date_end"] = datetime.datetime.now().strftime( 205 | "%Y-%m-%d %H:%M:%S.%f" 206 | ) 207 | self.metadata["successful"] = successful 208 | self._save_metadata() 209 | 210 | for f in [self._logfile, self._fieldfile]: 211 | f.close() 212 | 213 | def _save_metadata(self) -> None: 214 | with open(self.paths["meta"], "w") as jsonfile: 215 | json.dump(self.metadata, jsonfile, indent=4, sort_keys=True) 216 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast/core/prof.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Naive profiling using timeit. (Used in MonoBeast.)""" 15 | 16 | import collections 17 | import timeit 18 | 19 | 20 | class Timings: 21 | """Not thread-safe.""" 22 | 23 | def __init__(self): 24 | self._means = collections.defaultdict(int) 25 | self._vars = collections.defaultdict(int) 26 | self._counts = collections.defaultdict(int) 27 | self.reset() 28 | 29 | def reset(self): 30 | self.last_time = timeit.default_timer() 31 | 32 | def time(self, name): 33 | """Save an update for event `name`. 34 | 35 | Nerd alarm: We could just store a 36 | collections.defaultdict(list) 37 | and compute means and standard deviations at the end. But thanks to the 38 | clever math in Sutton-Barto 39 | (http://www.incompleteideas.net/book/first/ebook/node19.html) and 40 | https://math.stackexchange.com/a/103025/5051 we can update both the 41 | means and the stds online. O(1) FTW! 42 | """ 43 | now = timeit.default_timer() 44 | x = now - self.last_time 45 | self.last_time = now 46 | 47 | n = self._counts[name] 48 | 49 | mean = self._means[name] + (x - self._means[name]) / (n + 1) 50 | var = ( 51 | n * self._vars[name] + n * (self._means[name] - mean) ** 2 + (x - mean) ** 2 52 | ) / (n + 1) 53 | 54 | self._means[name] = mean 55 | self._vars[name] = var 56 | self._counts[name] += 1 57 | 58 | def means(self): 59 | return self._means 60 | 61 | def vars(self): 62 | return self._vars 63 | 64 | def stds(self): 65 | return {k: v ** 0.5 for k, v in self._vars.items()} 66 | 67 | def summary(self, prefix=""): 68 | means = self.means() 69 | stds = self.stds() 70 | total = sum(means.values()) 71 | 72 | result = prefix 73 | for k in sorted(means, key=means.get, reverse=True): 74 | result += f"\n %s: %.6fms +- %.6fms (%.2f%%) " % ( 75 | k, 76 | 1000 * means[k], 77 | 1000 * stds[k], 78 | 100 * means[k] / total, 79 | ) 80 | result += "\nTotal: %.6fms" % (1000 * total) 81 | return result 82 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast/core/vtrace.py: -------------------------------------------------------------------------------- 1 | # This file taken from 2 | # https://github.com/deepmind/scalable_agent/blob/ 3 | # cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py 4 | # and modified. 5 | 6 | # Copyright 2018 Google LLC 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # https://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | """Functions to compute V-trace off-policy actor critic targets. 20 | 21 | For details and theory see: 22 | 23 | "IMPALA: Scalable Distributed Deep-RL with 24 | Importance Weighted Actor-Learner Architectures" 25 | by Espeholt, Soyer, Munos et al. 26 | 27 | See https://arxiv.org/abs/1802.01561 for the full paper. 28 | """ 29 | 30 | import collections 31 | 32 | import torch 33 | import torch.nn.functional as F 34 | 35 | 36 | VTraceFromLogitsReturns = collections.namedtuple( 37 | "VTraceFromLogitsReturns", 38 | [ 39 | "vs", 40 | "pg_advantages", 41 | "log_rhos", 42 | "behavior_action_log_probs", 43 | "target_action_log_probs", 44 | ], 45 | ) 46 | 47 | VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") 48 | 49 | 50 | def action_log_probs(policy_logits, actions): 51 | return -F.nll_loss( 52 | F.log_softmax(torch.flatten(policy_logits, 0, -2), dim=-1), 53 | torch.flatten(actions), 54 | reduction="none", 55 | ).view_as(actions) 56 | 57 | 58 | def from_logits( 59 | behavior_policy_logits, 60 | target_policy_logits, 61 | actions, 62 | discounts, 63 | rewards, 64 | values, 65 | bootstrap_value, 66 | clip_rho_threshold=1.0, 67 | clip_pg_rho_threshold=1.0, 68 | ): 69 | """V-trace for softmax policies.""" 70 | 71 | target_action_log_probs = action_log_probs(target_policy_logits, actions) 72 | behavior_action_log_probs = action_log_probs(behavior_policy_logits, actions) 73 | log_rhos = target_action_log_probs - behavior_action_log_probs 74 | vtrace_returns = from_importance_weights( 75 | log_rhos=log_rhos, 76 | discounts=discounts, 77 | rewards=rewards, 78 | values=values, 79 | bootstrap_value=bootstrap_value, 80 | clip_rho_threshold=clip_rho_threshold, 81 | clip_pg_rho_threshold=clip_pg_rho_threshold, 82 | ) 83 | return VTraceFromLogitsReturns( 84 | log_rhos=log_rhos, 85 | behavior_action_log_probs=behavior_action_log_probs, 86 | target_action_log_probs=target_action_log_probs, 87 | **vtrace_returns._asdict(), 88 | ) 89 | 90 | 91 | @torch.no_grad() 92 | def from_importance_weights( 93 | log_rhos, 94 | discounts, 95 | rewards, 96 | values, 97 | bootstrap_value, 98 | clip_rho_threshold=1.0, 99 | clip_pg_rho_threshold=1.0, 100 | ): 101 | """V-trace from log importance weights.""" 102 | with torch.no_grad(): 103 | rhos = torch.exp(log_rhos) 104 | if clip_rho_threshold is not None: 105 | clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold) 106 | else: 107 | clipped_rhos = rhos 108 | 109 | cs = torch.clamp(rhos, max=1.0) 110 | # Append bootstrapped value to get [v1, ..., v_t+1] 111 | values_t_plus_1 = torch.cat( 112 | [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0 113 | ) 114 | deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) 115 | 116 | acc = torch.zeros_like(bootstrap_value) 117 | result = [] 118 | for t in range(discounts.shape[0] - 1, -1, -1): 119 | acc = deltas[t] + discounts[t] * cs[t] * acc 120 | result.append(acc) 121 | result.reverse() 122 | vs_minus_v_xs = torch.stack(result) 123 | 124 | # Add V(x_s) to get v_s. 125 | vs = torch.add(vs_minus_v_xs, values) 126 | 127 | # Advantage for policy gradient. 128 | broadcasted_bootstrap_values = torch.ones_like(vs[0]) * bootstrap_value 129 | vs_t_plus_1 = torch.cat( 130 | [vs[1:], broadcasted_bootstrap_values.unsqueeze(0)], dim=0 131 | ) 132 | if clip_pg_rho_threshold is not None: 133 | clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) 134 | else: 135 | clipped_pg_rhos = rhos 136 | pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values) 137 | 138 | # Make sure no gradients backpropagated through the returned values. 139 | return VTraceReturns(vs=vs, pg_advantages=pg_advantages) 140 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast/model.py: -------------------------------------------------------------------------------- 1 | import nest 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from torchbeast.layer import DeltaNetLayer 8 | from torchbeast.layer import LinearTransformerLayer 9 | from torchbeast.layer import FastFFRecUpdateTanhLayer 10 | from torchbeast.layer import FastRNNModelLayer 11 | from torchbeast.layer import DeltaDeltaNetLayer 12 | 13 | 14 | # Baseline model from torchbeast 15 | class Net(nn.Module): 16 | def __init__(self, num_actions, use_lstm=False): 17 | super(Net, self).__init__() 18 | self.num_actions = num_actions 19 | self.use_lstm = use_lstm 20 | 21 | self.feat_convs = [] 22 | self.resnet1 = [] 23 | self.resnet2 = [] 24 | 25 | self.convs = [] 26 | 27 | input_channels = 4 28 | for num_ch in [16, 32, 32]: 29 | feats_convs = [] 30 | feats_convs.append( 31 | nn.Conv2d( 32 | in_channels=input_channels, 33 | out_channels=num_ch, 34 | kernel_size=3, 35 | stride=1, 36 | padding=1, 37 | ) 38 | ) 39 | feats_convs.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 40 | self.feat_convs.append(nn.Sequential(*feats_convs)) 41 | 42 | input_channels = num_ch 43 | 44 | for i in range(2): 45 | resnet_block = [] 46 | resnet_block.append(nn.ReLU()) 47 | resnet_block.append( 48 | nn.Conv2d( 49 | in_channels=input_channels, 50 | out_channels=num_ch, 51 | kernel_size=3, 52 | stride=1, 53 | padding=1, 54 | ) 55 | ) 56 | resnet_block.append(nn.ReLU()) 57 | resnet_block.append( 58 | nn.Conv2d( 59 | in_channels=input_channels, 60 | out_channels=num_ch, 61 | kernel_size=3, 62 | stride=1, 63 | padding=1, 64 | ) 65 | ) 66 | if i == 0: 67 | self.resnet1.append(nn.Sequential(*resnet_block)) 68 | else: 69 | self.resnet2.append(nn.Sequential(*resnet_block)) 70 | 71 | self.feat_convs = nn.ModuleList(self.feat_convs) 72 | self.resnet1 = nn.ModuleList(self.resnet1) 73 | self.resnet2 = nn.ModuleList(self.resnet2) 74 | 75 | self.fc = nn.Linear(3872, 256) 76 | 77 | # FC output size + last reward. 78 | core_output_size = self.fc.out_features + 1 79 | 80 | if use_lstm: 81 | self.core = nn.LSTM(core_output_size, 256, num_layers=1) 82 | core_output_size = 256 83 | 84 | self.policy = nn.Linear(core_output_size, self.num_actions) 85 | self.baseline = nn.Linear(core_output_size, 1) 86 | 87 | def initial_state(self, batch_size=1): 88 | if not self.use_lstm: 89 | return tuple() 90 | return tuple( 91 | torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size) 92 | for _ in range(2) 93 | ) 94 | 95 | def forward(self, inputs, core_state): 96 | x = inputs["frame"] 97 | T, B, *_ = x.shape 98 | x = torch.flatten(x, 0, 1) # Merge time and batch. 99 | x = x.float() / 255.0 100 | 101 | res_input = None 102 | for i, fconv in enumerate(self.feat_convs): 103 | x = fconv(x) 104 | res_input = x 105 | x = self.resnet1[i](x) 106 | x += res_input 107 | res_input = x 108 | x = self.resnet2[i](x) 109 | x += res_input 110 | 111 | x = F.relu(x) 112 | x = x.view(T * B, -1) 113 | x = F.relu(self.fc(x)) 114 | 115 | clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1) 116 | core_input = torch.cat([x, clipped_reward], dim=-1) 117 | 118 | if self.use_lstm: 119 | core_input = core_input.view(T, B, -1) 120 | core_output_list = [] 121 | notdone = (~inputs["done"]).float() 122 | for input, nd in zip(core_input.unbind(), notdone.unbind()): 123 | # Reset core state to zero whenever an episode ended. 124 | # Make `done` broadcastable with (num_layers, B, hidden_size) 125 | # states: 126 | nd = nd.view(1, -1, 1) 127 | core_state = nest.map(nd.mul, core_state) 128 | output, core_state = self.core(input.unsqueeze(0), core_state) 129 | core_output_list.append(output) 130 | core_output = torch.flatten(torch.cat(core_output_list), 0, 1) 131 | else: 132 | core_output = core_input 133 | 134 | policy_logits = self.policy(core_output) 135 | baseline = self.baseline(core_output) 136 | 137 | if self.training: 138 | action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) 139 | else: 140 | # Don't sample when testing. 141 | action = torch.argmax(policy_logits, dim=1) 142 | 143 | policy_logits = policy_logits.view(T, B, self.num_actions) 144 | baseline = baseline.view(T, B) 145 | action = action.view(T, B) 146 | 147 | return (action, policy_logits, baseline), core_state 148 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast/polybeast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import multiprocessing as mp 17 | 18 | import numpy as np 19 | 20 | from torchbeast import polybeast_learner 21 | from torchbeast import polybeast_env 22 | 23 | 24 | def run_env(flags, actor_id): 25 | np.random.seed() # Get new random seed in forked process. 26 | polybeast_env.main(flags) 27 | 28 | 29 | def run_learner(flags): 30 | polybeast_learner.main(flags) 31 | 32 | 33 | def main(): 34 | flags = argparse.Namespace() 35 | flags, argv = polybeast_learner.parser.parse_known_args(namespace=flags) 36 | flags, argv = polybeast_env.parser.parse_known_args(args=argv, namespace=flags) 37 | if argv: 38 | # Produce an error message. 39 | polybeast_learner.parser.print_usage() 40 | print("") 41 | polybeast_env.parser.print_usage() 42 | print("Unkown args:", " ".join(argv)) 43 | return -1 44 | 45 | flags.num_servers = flags.num_actors 46 | env_processes = [] 47 | for actor_id in range(1): 48 | p = mp.Process(target=run_env, args=(flags, actor_id)) 49 | p.start() 50 | env_processes.append(p) 51 | 52 | run_learner(flags) 53 | 54 | for p in env_processes: 55 | # p.terminate() 56 | p.join() 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_atari/polybeast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import multiprocessing as mp 17 | 18 | import numpy as np 19 | 20 | from torchbeast_atari import polybeast_learner 21 | from torchbeast_atari import polybeast_env 22 | 23 | 24 | def run_env(flags, actor_id): 25 | np.random.seed() # Get new random seed in forked process. 26 | polybeast_env.main(flags) 27 | 28 | 29 | def run_learner(flags): 30 | polybeast_learner.main(flags) 31 | 32 | 33 | def main(): 34 | flags = argparse.Namespace() 35 | flags, argv = polybeast_learner.parser.parse_known_args(namespace=flags) 36 | flags, argv = polybeast_env.parser.parse_known_args(args=argv, namespace=flags) 37 | if argv: 38 | # Produce an error message. 39 | polybeast_learner.parser.print_usage() 40 | print("") 41 | polybeast_env.parser.print_usage() 42 | print("Unkown args:", " ".join(argv)) 43 | return -1 44 | 45 | flags.num_servers = flags.num_actors 46 | env_processes = [] 47 | for actor_id in range(1): 48 | p = mp.Process(target=run_env, args=(flags, actor_id)) 49 | p.start() 50 | env_processes.append(p) 51 | 52 | run_learner(flags) 53 | 54 | for p in env_processes: 55 | # p.terminate() 56 | p.join() 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_atari/polybeast_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import argparse 17 | import multiprocessing as mp 18 | import threading 19 | import time 20 | 21 | import numpy as np 22 | import libtorchbeast 23 | from torchbeast_atari import atari_wrappers 24 | 25 | 26 | # yapf: disable 27 | parser = argparse.ArgumentParser(description='Remote Environment Server') 28 | 29 | parser.add_argument("--pipes_basename", default="unix:/tmp/polybeast", 30 | help="Basename for the pipes for inter-process communication. " 31 | "Has to be of the type unix:/some/path.") 32 | parser.add_argument('--num_servers', default=4, type=int, metavar='N', 33 | help='Number of environment servers.') 34 | parser.add_argument('--env', type=str, default='PongNoFrameskip-v4', 35 | help='Gym environment.') 36 | parser.add_argument('--allow_oov', action="store_true", 37 | help='Allow action space larger than the env specific one.' 38 | ' All out-of-vocab action will be mapped to NoOp.') 39 | # yapf: enable 40 | 41 | 42 | class Env: 43 | def reset(self): 44 | print("reset called") 45 | return np.ones((4, 84, 84), dtype=np.uint8) 46 | 47 | def step(self, action): 48 | frame = np.zeros((4, 84, 84), dtype=np.uint8) 49 | return frame, 0.0, False, {} # First three mandatory. 50 | 51 | 52 | def create_env(env_name, oov, lock=threading.Lock()): 53 | with lock: # Atari isn't threadsafe at construction time. 54 | return atari_wrappers.wrap_pytorch( 55 | atari_wrappers.wrap_deepmind( 56 | atari_wrappers.make_atari(env_name), 57 | clip_rewards=False, 58 | frame_stack=True, 59 | scale=False, 60 | allow_oov_action=oov, 61 | ) 62 | ) 63 | 64 | 65 | def serve(env_name, oov, server_address): 66 | init = Env if env_name == "Mock" else lambda: create_env(env_name, oov) 67 | server = libtorchbeast.Server(init, server_address=server_address) 68 | server.run() 69 | 70 | 71 | def main(flags): 72 | if not flags.pipes_basename.startswith("unix:"): 73 | raise Exception("--pipes_basename has to be of the form unix:/some/path.") 74 | 75 | processes = [] 76 | for i in range(flags.num_servers): 77 | p = mp.Process( 78 | target=serve, args=(flags.env, flags.allow_oov, f"{flags.pipes_basename}.{i}"), daemon=True 79 | ) 80 | p.start() 81 | processes.append(p) 82 | 83 | try: 84 | # We are only here to listen to the interrupt. 85 | while True: 86 | time.sleep(10) 87 | except KeyboardInterrupt: 88 | pass 89 | 90 | 91 | if __name__ == "__main__": 92 | flags = parser.parse_args() 93 | print(f"Env: {flags.env}") 94 | main(flags) 95 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/core/.history.kazuki: -------------------------------------------------------------------------------- 1 | 2021-05-09.23-47-46.v03 less environment.py 2 | 2021-05-09.23-47-49.v03 less file_writer.py 3 | 2021-05-09.23-47-53.v03 less vtrace.py 4 | 2021-05-09.23-47-55.v03 less prof.py 5 | 2021-05-09.23-48-10.v03 mv environment.py _environment.py 6 | 2021-05-09.23-48-11.v03 rp / 7 | 2021-05-09.23-48-13.v03 rp . 8 | 2021-05-09.23-48-55.v03 vimdiff environment.py _environment.py 9 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/core/_environment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The environment class for MonoBeast.""" 15 | 16 | import torch 17 | 18 | 19 | def _format_frame(frame): 20 | frame = torch.from_numpy(frame) 21 | return frame.view((1, 1) + frame.shape) # (...) -> (T,B,...). 22 | 23 | 24 | class Environment: 25 | def __init__(self, gym_env): 26 | self.gym_env = gym_env 27 | self.episode_return = None 28 | self.episode_step = None 29 | 30 | def initial(self): 31 | initial_reward = torch.zeros(1, 1) 32 | # This supports only single-tensor actions ATM. 33 | initial_last_action = torch.zeros(1, 1, dtype=torch.int64) 34 | self.episode_return = torch.zeros(1, 1) 35 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 36 | initial_done = torch.ones(1, 1, dtype=torch.uint8) 37 | initial_frame = _format_frame(self.gym_env.reset()) 38 | return dict( 39 | frame=initial_frame, 40 | reward=initial_reward, 41 | done=initial_done, 42 | episode_return=self.episode_return, 43 | episode_step=self.episode_step, 44 | last_action=initial_last_action, 45 | ) 46 | 47 | def step(self, action): 48 | frame, reward, done, unused_info = self.gym_env.step(action.item()) 49 | self.episode_step += 1 50 | self.episode_return += reward 51 | episode_step = self.episode_step 52 | episode_return = self.episode_return 53 | if done: 54 | frame = self.gym_env.reset() 55 | self.episode_return = torch.zeros(1, 1) 56 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 57 | 58 | frame = _format_frame(frame) 59 | reward = torch.tensor(reward).view(1, 1) 60 | done = torch.tensor(done).view(1, 1) 61 | 62 | return dict( 63 | frame=frame, 64 | reward=reward, 65 | done=done, 66 | episode_return=episode_return, 67 | episode_step=episode_step, 68 | last_action=action, 69 | ) 70 | 71 | def close(self): 72 | self.gym_env.close() 73 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/core/environment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The environment class for MonoBeast.""" 15 | 16 | import torch 17 | import numpy as np 18 | 19 | 20 | ########## The set of actions that can go as input to the DMLAB step 21 | DEFAULT_ACTION_SET = ( 22 | (0, 0, 0, 1, 0, 0, 0), # Forward 23 | (0, 0, 0, -1, 0, 0, 0), # Backward 24 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left 25 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right 26 | (-20, 0, 0, 0, 0, 0, 0), # Look Left 27 | (20, 0, 0, 0, 0, 0, 0), # Look Right 28 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward 29 | (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward 30 | (0, 0, 0, 0, 1, 0, 0), # Fire. 31 | ) 32 | 33 | 34 | def _format_frame(frame): 35 | frame = torch.from_numpy(frame) 36 | return frame.view((1, 1) + frame.shape) # (...) -> (T,B,...). 37 | 38 | 39 | class Environment: 40 | def __init__(self, gym_env): 41 | self.gym_env = gym_env 42 | self.episode_return = None 43 | self.episode_step = None 44 | 45 | def initial(self): 46 | initial_reward = torch.zeros(1, 1) 47 | # This supports only single-tensor actions ATM. 48 | initial_last_action = torch.zeros(1, 1, dtype=torch.int64) 49 | self.episode_return = torch.zeros(1, 1) 50 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 51 | initial_done = torch.ones(1, 1, dtype=torch.uint8) 52 | ###### changed reset to inital to match the format of createDMLab 53 | initial_frame = _format_frame(self.gym_env.initial()) 54 | return dict( 55 | frame=initial_frame, 56 | reward=initial_reward, 57 | done=initial_done, 58 | episode_return=self.episode_return, 59 | episode_step=self.episode_step, 60 | last_action=initial_last_action, 61 | ) 62 | 63 | def step(self, action): 64 | ######## changed actions from int to one of the default_action_set above 65 | raw_action = np.array(DEFAULT_ACTION_SET[action],dtype= np.intc) 66 | frame, reward, done = self.gym_env.step(raw_action) 67 | self.episode_step += 1 68 | self.episode_return += reward 69 | episode_step = self.episode_step 70 | episode_return = self.episode_return 71 | if done: 72 | frame = self.gym_env.initial() 73 | self.episode_return = torch.zeros(1, 1) 74 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 75 | 76 | frame = _format_frame(frame) 77 | reward = torch.tensor(reward).view(1, 1) 78 | done = torch.tensor(done).view(1, 1) 79 | 80 | return dict( 81 | frame=frame, 82 | reward=reward, 83 | done=done, 84 | episode_return=episode_return, 85 | episode_step=episode_step, 86 | last_action=action, 87 | ) 88 | 89 | def close(self): 90 | self.gym_env.close() 91 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/core/file_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import copy 17 | import csv 18 | import datetime 19 | import json 20 | import logging 21 | import os 22 | import time 23 | from typing import Dict 24 | 25 | 26 | 27 | def gather_metadata() -> Dict: 28 | date_start = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") 29 | # Gathering git metadata. 30 | try: 31 | import git 32 | 33 | try: 34 | repo = git.Repo(search_parent_directories=True) 35 | git_sha = repo.commit().hexsha 36 | git_data = dict( 37 | commit=git_sha, 38 | branch=None if repo.head.is_detached else repo.active_branch.name, 39 | is_dirty=repo.is_dirty(), 40 | path=repo.git_dir, 41 | ) 42 | except git.InvalidGitRepositoryError: 43 | git_data = None 44 | except ImportError: 45 | git_data = None 46 | # Gathering slurm metadata. 47 | if "SLURM_JOB_ID" in os.environ: 48 | slurm_env_keys = [k for k in os.environ if k.startswith("SLURM")] 49 | slurm_data = {} 50 | for k in slurm_env_keys: 51 | d_key = k.replace("SLURM_", "").replace("SLURMD_", "").lower() 52 | slurm_data[d_key] = os.environ[k] 53 | else: 54 | slurm_data = None 55 | return dict( 56 | date_start=date_start, 57 | date_end=None, 58 | successful=False, 59 | git=git_data, 60 | slurm=slurm_data, 61 | env=os.environ.copy(), 62 | ) 63 | 64 | 65 | class FileWriter: 66 | def __init__( 67 | self, 68 | xpid: str = None, 69 | xp_args: dict = None, 70 | rootdir: str = "~/logs", 71 | symlink_to_latest: bool = True, 72 | ): 73 | if not xpid: 74 | # Make unique id. 75 | xpid = "{proc}_{unixtime}".format( 76 | proc=os.getpid(), unixtime=int(time.time()) 77 | ) 78 | self.xpid = xpid 79 | self._tick = 0 80 | 81 | # Metadata gathering. 82 | if xp_args is None: 83 | xp_args = {} 84 | self.metadata = gather_metadata() 85 | # We need to copy the args, otherwise when we close the file writer 86 | # (and rewrite the args) we might have non-serializable objects (or 87 | # other unwanted side-effects). 88 | self.metadata["args"] = copy.deepcopy(xp_args) 89 | self.metadata["xpid"] = self.xpid 90 | 91 | formatter = logging.Formatter("%(message)s") 92 | self._logger = logging.getLogger("logs/out") 93 | 94 | # To stdout handler. 95 | shandle = logging.StreamHandler() 96 | shandle.setFormatter(formatter) 97 | self._logger.addHandler(shandle) 98 | self._logger.setLevel(logging.INFO) 99 | 100 | rootdir = os.path.expandvars(os.path.expanduser(rootdir)) 101 | # To file handler. 102 | self.basepath = os.path.join(rootdir, self.xpid) 103 | if not os.path.exists(self.basepath): 104 | self._logger.info("Creating log directory: %s", self.basepath) 105 | os.makedirs(self.basepath, exist_ok=True) 106 | else: 107 | self._logger.info("Found log directory: %s", self.basepath) 108 | 109 | if symlink_to_latest: 110 | # Add 'latest' as symlink unless it exists and is no symlink. 111 | symlink = os.path.join(rootdir, "latest") 112 | try: 113 | if os.path.islink(symlink): 114 | os.remove(symlink) 115 | if not os.path.exists(symlink): 116 | os.symlink(self.basepath, symlink) 117 | self._logger.info("Symlinked log directory: %s", symlink) 118 | except OSError: 119 | # os.remove() or os.symlink() raced. Don't do anything. 120 | pass 121 | 122 | self.paths = dict( 123 | msg="{base}/out.log".format(base=self.basepath), 124 | logs="{base}/logs.csv".format(base=self.basepath), 125 | fields="{base}/fields.csv".format(base=self.basepath), 126 | meta="{base}/meta.json".format(base=self.basepath), 127 | ) 128 | 129 | self._logger.info("Saving arguments to %s", self.paths["meta"]) 130 | if os.path.exists(self.paths["meta"]): 131 | self._logger.warning( 132 | "Path to meta file already exists. " "Not overriding meta." 133 | ) 134 | else: 135 | self._save_metadata() 136 | 137 | self._logger.info("Saving messages to %s", self.paths["msg"]) 138 | if os.path.exists(self.paths["msg"]): 139 | self._logger.warning( 140 | "Path to message file already exists. " "New data will be appended." 141 | ) 142 | 143 | fhandle = logging.FileHandler(self.paths["msg"]) 144 | fhandle.setFormatter(formatter) 145 | self._logger.addHandler(fhandle) 146 | 147 | self._logger.info("Saving logs data to %s", self.paths["logs"]) 148 | self._logger.info("Saving logs' fields to %s", self.paths["fields"]) 149 | self.fieldnames = ["_tick", "_time"] 150 | if os.path.exists(self.paths["logs"]): 151 | self._logger.warning( 152 | "Path to log file already exists. " "New data will be appended." 153 | ) 154 | # Override default fieldnames. 155 | with open(self.paths["fields"], "r") as csvfile: 156 | reader = csv.reader(csvfile) 157 | lines = list(reader) 158 | if len(lines) > 0: 159 | self.fieldnames = lines[-1] 160 | # Override default tick: use the last tick from the logs file plus 1. 161 | with open(self.paths["logs"], "r") as csvfile: 162 | reader = csv.reader(csvfile) 163 | lines = list(reader) 164 | # Need at least two lines in order to read the last tick: 165 | # the first is the csv header and the second is the first line 166 | # of data. 167 | if len(lines) > 1: 168 | self._tick = int(lines[-1][0]) + 1 169 | 170 | self._fieldfile = open(self.paths["fields"], "a") 171 | self._fieldwriter = csv.writer(self._fieldfile) 172 | self._logfile = open(self.paths["logs"], "a") 173 | self._logwriter = csv.DictWriter(self._logfile, fieldnames=self.fieldnames) 174 | 175 | def log(self, to_log: Dict, tick: int = None, verbose: bool = False) -> None: 176 | if tick is not None: 177 | raise NotImplementedError 178 | else: 179 | to_log["_tick"] = self._tick 180 | self._tick += 1 181 | to_log["_time"] = time.time() 182 | 183 | old_len = len(self.fieldnames) 184 | for k in to_log: 185 | if k not in self.fieldnames: 186 | self.fieldnames.append(k) 187 | if old_len != len(self.fieldnames): 188 | self._fieldwriter.writerow(self.fieldnames) 189 | self._logger.info("Updated log fields: %s", self.fieldnames) 190 | 191 | if to_log["_tick"] == 0: 192 | self._logfile.write("# %s\n" % ",".join(self.fieldnames)) 193 | 194 | if verbose: 195 | self._logger.info( 196 | "LOG | %s", 197 | ", ".join(["{}: {}".format(k, to_log[k]) for k in sorted(to_log)]), 198 | ) 199 | 200 | self._logwriter.writerow(to_log) 201 | self._logfile.flush() 202 | 203 | def close(self, successful: bool = True) -> None: 204 | self.metadata["date_end"] = datetime.datetime.now().strftime( 205 | "%Y-%m-%d %H:%M:%S.%f" 206 | ) 207 | self.metadata["successful"] = successful 208 | self._save_metadata() 209 | 210 | for f in [self._logfile, self._fieldfile]: 211 | f.close() 212 | 213 | def _save_metadata(self) -> None: 214 | with open(self.paths["meta"], "w") as jsonfile: 215 | json.dump(self.metadata, jsonfile, indent=4, sort_keys=True) 216 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/core/prof.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Naive profiling using timeit. (Used in MonoBeast.)""" 15 | 16 | import collections 17 | import timeit 18 | 19 | 20 | class Timings: 21 | """Not thread-safe.""" 22 | 23 | def __init__(self): 24 | self._means = collections.defaultdict(int) 25 | self._vars = collections.defaultdict(int) 26 | self._counts = collections.defaultdict(int) 27 | self.reset() 28 | 29 | def reset(self): 30 | self.last_time = timeit.default_timer() 31 | 32 | def time(self, name): 33 | """Save an update for event `name`. 34 | 35 | Nerd alarm: We could just store a 36 | collections.defaultdict(list) 37 | and compute means and standard deviations at the end. But thanks to the 38 | clever math in Sutton-Barto 39 | (http://www.incompleteideas.net/book/first/ebook/node19.html) and 40 | https://math.stackexchange.com/a/103025/5051 we can update both the 41 | means and the stds online. O(1) FTW! 42 | """ 43 | now = timeit.default_timer() 44 | x = now - self.last_time 45 | self.last_time = now 46 | 47 | n = self._counts[name] 48 | 49 | mean = self._means[name] + (x - self._means[name]) / (n + 1) 50 | var = ( 51 | n * self._vars[name] + n * (self._means[name] - mean) ** 2 + (x - mean) ** 2 52 | ) / (n + 1) 53 | 54 | self._means[name] = mean 55 | self._vars[name] = var 56 | self._counts[name] += 1 57 | 58 | def means(self): 59 | return self._means 60 | 61 | def vars(self): 62 | return self._vars 63 | 64 | def stds(self): 65 | return {k: v ** 0.5 for k, v in self._vars.items()} 66 | 67 | def summary(self, prefix=""): 68 | means = self.means() 69 | stds = self.stds() 70 | total = sum(means.values()) 71 | 72 | result = prefix 73 | for k in sorted(means, key=means.get, reverse=True): 74 | result += f"\n %s: %.6fms +- %.6fms (%.2f%%) " % ( 75 | k, 76 | 1000 * means[k], 77 | 1000 * stds[k], 78 | 100 * means[k] / total, 79 | ) 80 | result += "\nTotal: %.6fms" % (1000 * total) 81 | return result 82 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/core/vtrace.py: -------------------------------------------------------------------------------- 1 | # This file taken from 2 | # https://github.com/deepmind/scalable_agent/blob/ 3 | # cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py 4 | # and modified. 5 | 6 | # Copyright 2018 Google LLC 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # https://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | """Functions to compute V-trace off-policy actor critic targets. 20 | 21 | For details and theory see: 22 | 23 | "IMPALA: Scalable Distributed Deep-RL with 24 | Importance Weighted Actor-Learner Architectures" 25 | by Espeholt, Soyer, Munos et al. 26 | 27 | See https://arxiv.org/abs/1802.01561 for the full paper. 28 | """ 29 | 30 | import collections 31 | 32 | import torch 33 | import torch.nn.functional as F 34 | 35 | 36 | VTraceFromLogitsReturns = collections.namedtuple( 37 | "VTraceFromLogitsReturns", 38 | [ 39 | "vs", 40 | "pg_advantages", 41 | "log_rhos", 42 | "behavior_action_log_probs", 43 | "target_action_log_probs", 44 | ], 45 | ) 46 | 47 | VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") 48 | 49 | 50 | def action_log_probs(policy_logits, actions): 51 | return -F.nll_loss( 52 | F.log_softmax(torch.flatten(policy_logits, 0, -2), dim=-1), 53 | torch.flatten(actions), 54 | reduction="none", 55 | ).view_as(actions) 56 | 57 | 58 | def from_logits( 59 | behavior_policy_logits, 60 | target_policy_logits, 61 | actions, 62 | discounts, 63 | rewards, 64 | values, 65 | bootstrap_value, 66 | clip_rho_threshold=1.0, 67 | clip_pg_rho_threshold=1.0, 68 | ): 69 | """V-trace for softmax policies.""" 70 | 71 | target_action_log_probs = action_log_probs(target_policy_logits, actions) 72 | behavior_action_log_probs = action_log_probs(behavior_policy_logits, actions) 73 | log_rhos = target_action_log_probs - behavior_action_log_probs 74 | vtrace_returns = from_importance_weights( 75 | log_rhos=log_rhos, 76 | discounts=discounts, 77 | rewards=rewards, 78 | values=values, 79 | bootstrap_value=bootstrap_value, 80 | clip_rho_threshold=clip_rho_threshold, 81 | clip_pg_rho_threshold=clip_pg_rho_threshold, 82 | ) 83 | return VTraceFromLogitsReturns( 84 | log_rhos=log_rhos, 85 | behavior_action_log_probs=behavior_action_log_probs, 86 | target_action_log_probs=target_action_log_probs, 87 | **vtrace_returns._asdict(), 88 | ) 89 | 90 | 91 | @torch.no_grad() 92 | def from_importance_weights( 93 | log_rhos, 94 | discounts, 95 | rewards, 96 | values, 97 | bootstrap_value, 98 | clip_rho_threshold=1.0, 99 | clip_pg_rho_threshold=1.0, 100 | ): 101 | """V-trace from log importance weights.""" 102 | with torch.no_grad(): 103 | rhos = torch.exp(log_rhos) 104 | if clip_rho_threshold is not None: 105 | clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold) 106 | else: 107 | clipped_rhos = rhos 108 | 109 | cs = torch.clamp(rhos, max=1.0) 110 | # Append bootstrapped value to get [v1, ..., v_t+1] 111 | values_t_plus_1 = torch.cat( 112 | [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0 113 | ) 114 | deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) 115 | 116 | acc = torch.zeros_like(bootstrap_value) 117 | result = [] 118 | for t in range(discounts.shape[0] - 1, -1, -1): 119 | acc = deltas[t] + discounts[t] * cs[t] * acc 120 | result.append(acc) 121 | result.reverse() 122 | vs_minus_v_xs = torch.stack(result) 123 | 124 | # Add V(x_s) to get v_s. 125 | vs = torch.add(vs_minus_v_xs, values) 126 | 127 | # Advantage for policy gradient. 128 | broadcasted_bootstrap_values = torch.ones_like(vs[0]) * bootstrap_value 129 | vs_t_plus_1 = torch.cat( 130 | [vs[1:], broadcasted_bootstrap_values.unsqueeze(0)], dim=0 131 | ) 132 | if clip_pg_rho_threshold is not None: 133 | clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) 134 | else: 135 | clipped_pg_rhos = rhos 136 | pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values) 137 | 138 | # Make sure no gradients backpropagated through the returned values. 139 | return VTraceReturns(vs=vs, pg_advantages=pg_advantages) 140 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/dmlab30.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for DMLab-30.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | 27 | LEVEL_MAPPING = collections.OrderedDict([ 28 | ('rooms_collect_good_objects_train', 'rooms_collect_good_objects_test'), 29 | ('rooms_exploit_deferred_effects_train', 30 | 'rooms_exploit_deferred_effects_test'), 31 | ('rooms_select_nonmatching_object', 'rooms_select_nonmatching_object'), 32 | ('rooms_watermaze', 'rooms_watermaze'), 33 | ('rooms_keys_doors_puzzle', 'rooms_keys_doors_puzzle'), 34 | ('language_select_described_object', 'language_select_described_object'), 35 | ('language_select_located_object', 'language_select_located_object'), 36 | ('language_execute_random_task', 'language_execute_random_task'), 37 | ('language_answer_quantitative_question', 38 | 'language_answer_quantitative_question'), 39 | ('lasertag_one_opponent_small', 'lasertag_one_opponent_small'), 40 | ('lasertag_three_opponents_small', 'lasertag_three_opponents_small'), 41 | ('lasertag_one_opponent_large', 'lasertag_one_opponent_large'), 42 | ('lasertag_three_opponents_large', 'lasertag_three_opponents_large'), 43 | ('natlab_fixed_large_map', 'natlab_fixed_large_map'), 44 | ('natlab_varying_map_regrowth', 'natlab_varying_map_regrowth'), 45 | ('natlab_varying_map_randomized', 'natlab_varying_map_randomized'), 46 | ('skymaze_irreversible_path_hard', 'skymaze_irreversible_path_hard'), 47 | ('skymaze_irreversible_path_varied', 'skymaze_irreversible_path_varied'), 48 | ('psychlab_arbitrary_visuomotor_mapping', 49 | 'psychlab_arbitrary_visuomotor_mapping'), 50 | ('psychlab_continuous_recognition', 'psychlab_continuous_recognition'), 51 | ('psychlab_sequential_comparison', 'psychlab_sequential_comparison'), 52 | ('psychlab_visual_search', 'psychlab_visual_search'), 53 | ('explore_object_locations_small', 'explore_object_locations_small'), 54 | ('explore_object_locations_large', 'explore_object_locations_large'), 55 | ('explore_obstructed_goals_small', 'explore_obstructed_goals_small'), 56 | ('explore_obstructed_goals_large', 'explore_obstructed_goals_large'), 57 | ('explore_goal_locations_small', 'explore_goal_locations_small'), 58 | ('explore_goal_locations_large', 'explore_goal_locations_large'), 59 | ('explore_object_rewards_few', 'explore_object_rewards_few'), 60 | ('explore_object_rewards_many', 'explore_object_rewards_many'), 61 | ]) 62 | 63 | HUMAN_SCORES = { 64 | 'rooms_collect_good_objects_test': 10, 65 | 'rooms_exploit_deferred_effects_test': 85.65, 66 | 'rooms_select_nonmatching_object': 65.9, 67 | 'rooms_watermaze': 54, 68 | 'rooms_keys_doors_puzzle': 53.8, 69 | 'language_select_described_object': 389.5, 70 | 'language_select_located_object': 280.7, 71 | 'language_execute_random_task': 254.05, 72 | 'language_answer_quantitative_question': 184.5, 73 | 'lasertag_one_opponent_small': 12.65, 74 | 'lasertag_three_opponents_small': 18.55, 75 | 'lasertag_one_opponent_large': 18.6, 76 | 'lasertag_three_opponents_large': 31.5, 77 | 'natlab_fixed_large_map': 36.9, 78 | 'natlab_varying_map_regrowth': 24.45, 79 | 'natlab_varying_map_randomized': 42.35, 80 | 'skymaze_irreversible_path_hard': 100, 81 | 'skymaze_irreversible_path_varied': 100, 82 | 'psychlab_arbitrary_visuomotor_mapping': 58.75, 83 | 'psychlab_continuous_recognition': 58.3, 84 | 'psychlab_sequential_comparison': 39.5, 85 | 'psychlab_visual_search': 78.5, 86 | 'explore_object_locations_small': 74.45, 87 | 'explore_object_locations_large': 65.65, 88 | 'explore_obstructed_goals_small': 206, 89 | 'explore_obstructed_goals_large': 119.5, 90 | 'explore_goal_locations_small': 267.5, 91 | 'explore_goal_locations_large': 194.5, 92 | 'explore_object_rewards_few': 77.7, 93 | 'explore_object_rewards_many': 106.7, 94 | } 95 | 96 | RANDOM_SCORES = { 97 | 'rooms_collect_good_objects_test': 0.073, 98 | 'rooms_exploit_deferred_effects_test': 8.501, 99 | 'rooms_select_nonmatching_object': 0.312, 100 | 'rooms_watermaze': 4.065, 101 | 'rooms_keys_doors_puzzle': 4.135, 102 | 'language_select_described_object': -0.07, 103 | 'language_select_located_object': 1.929, 104 | 'language_execute_random_task': -5.913, 105 | 'language_answer_quantitative_question': -0.33, 106 | 'lasertag_one_opponent_small': -0.224, 107 | 'lasertag_three_opponents_small': -0.214, 108 | 'lasertag_one_opponent_large': -0.083, 109 | 'lasertag_three_opponents_large': -0.102, 110 | 'natlab_fixed_large_map': 2.173, 111 | 'natlab_varying_map_regrowth': 2.989, 112 | 'natlab_varying_map_randomized': 7.346, 113 | 'skymaze_irreversible_path_hard': 0.1, 114 | 'skymaze_irreversible_path_varied': 14.4, 115 | 'psychlab_arbitrary_visuomotor_mapping': 0.163, 116 | 'psychlab_continuous_recognition': 0.224, 117 | 'psychlab_sequential_comparison': 0.129, 118 | 'psychlab_visual_search': 0.085, 119 | 'explore_object_locations_small': 3.575, 120 | 'explore_object_locations_large': 4.673, 121 | 'explore_obstructed_goals_small': 6.76, 122 | 'explore_obstructed_goals_large': 2.61, 123 | 'explore_goal_locations_small': 7.66, 124 | 'explore_goal_locations_large': 3.14, 125 | 'explore_object_rewards_few': 2.073, 126 | 'explore_object_rewards_many': 2.438, 127 | } 128 | 129 | ALL_LEVELS = frozenset([ 130 | 'rooms_collect_good_objects_train', 131 | 'rooms_collect_good_objects_test', 132 | 'rooms_exploit_deferred_effects_train', 133 | 'rooms_exploit_deferred_effects_test', 134 | 'rooms_select_nonmatching_object', 135 | 'rooms_watermaze', 136 | 'rooms_keys_doors_puzzle', 137 | 'language_select_described_object', 138 | 'language_select_located_object', 139 | 'language_execute_random_task', 140 | 'language_answer_quantitative_question', 141 | 'lasertag_one_opponent_small', 142 | 'lasertag_three_opponents_small', 143 | 'lasertag_one_opponent_large', 144 | 'lasertag_three_opponents_large', 145 | 'natlab_fixed_large_map', 146 | 'natlab_varying_map_regrowth', 147 | 'natlab_varying_map_randomized', 148 | 'skymaze_irreversible_path_hard', 149 | 'skymaze_irreversible_path_varied', 150 | 'psychlab_arbitrary_visuomotor_mapping', 151 | 'psychlab_continuous_recognition', 152 | 'psychlab_sequential_comparison', 153 | 'psychlab_visual_search', 154 | 'explore_object_locations_small', 155 | 'explore_object_locations_large', 156 | 'explore_obstructed_goals_small', 157 | 'explore_obstructed_goals_large', 158 | 'explore_goal_locations_small', 159 | 'explore_goal_locations_large', 160 | 'explore_object_rewards_few', 161 | 'explore_object_rewards_many', 162 | ]) 163 | 164 | 165 | def _transform_level_returns(level_returns): 166 | """Converts training level names to test level names.""" 167 | new_level_returns = {} 168 | for level_name, returns in level_returns.items(): 169 | new_level_returns[LEVEL_MAPPING.get(level_name, level_name)] = returns 170 | 171 | test_set = set(LEVEL_MAPPING.values()) 172 | diff = test_set - set(new_level_returns.keys()) 173 | if diff: 174 | raise ValueError('Missing levels: %s' % list(diff)) 175 | 176 | for level_name, returns in new_level_returns.items(): 177 | if level_name in test_set: 178 | if not returns: 179 | raise ValueError('Missing returns for level: \'%s\': ' % level_name) 180 | else: 181 | tf.logging.info('Skipping level %s for calculation.', level_name) 182 | 183 | return new_level_returns 184 | 185 | 186 | def compute_human_normalized_score(level_returns, per_level_cap): 187 | """Computes human normalized score. 188 | Levels that have different training and test versions, will use the returns 189 | for the training level to calculate the score. E.g. 190 | 'rooms_collect_good_objects_train' will be used for 191 | 'rooms_collect_good_objects_test'. All returns for levels not in DmLab-30 192 | will be ignored. 193 | Args: 194 | level_returns: A dictionary from level to list of episode returns. 195 | per_level_cap: A percentage cap (e.g. 100.) on the per level human 196 | normalized score. If None, no cap is applied. 197 | Returns: 198 | A float with the human normalized score in percentage. 199 | Raises: 200 | ValueError: If a level is missing from `level_returns` or has no returns. 201 | """ 202 | new_level_returns = _transform_level_returns(level_returns) 203 | 204 | def human_normalized_score(level_name, returns): 205 | score = np.mean(returns) 206 | human = HUMAN_SCORES[level_name] 207 | random = RANDOM_SCORES[level_name] 208 | human_normalized_score = (score - random) / (human - random) * 100 209 | if per_level_cap is not None: 210 | human_normalized_score = min(human_normalized_score, per_level_cap) 211 | return human_normalized_score 212 | 213 | return np.mean( 214 | [human_normalized_score(k, v) for k, v in new_level_returns.items()]) -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/dmlab_wrappers.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # 3 | # Copyright (c) 2017 OpenAI (http://openai.com) 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in 13 | # all copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | # THE SOFTWARE. 22 | 23 | import numpy as np 24 | from collections import deque 25 | import deepmind_lab 26 | 27 | 28 | # 9 actions, each encoded by a 7-dim vector. 29 | DEFAULT_ACTION_SET_TUPLE = ( 30 | (0, 0, 0, 1, 0, 0, 0), # Forward 31 | (0, 0, 0, -1, 0, 0, 0), # Backward 32 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left 33 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right 34 | (-20, 0, 0, 0, 0, 0, 0), # Look Left 35 | (20, 0, 0, 0, 0, 0, 0), # Look Right 36 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward 37 | (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward 38 | (0, 0, 0, 0, 1, 0, 0), # Fire. 39 | ) 40 | 41 | DEFAULT_ACTION_SET = np.array(DEFAULT_ACTION_SET_TUPLE, dtype=np.intc) 42 | 43 | 44 | # Based on https://github.com/deepmind/scalable_agent/blob/master/environments.py 45 | class create_env_dmlab(object): 46 | """Wrapper around DMLab-30 env.""" 47 | 48 | def __init__(self, level, config, seed, skip=4, 49 | runfiles_path=None, level_cache=None): 50 | self._skip = skip 51 | self.obs_shape = (3, 72, 96) # reshape from (72, 96, 3) RGB_INTERLEAVED 52 | self._random_state = np.random.RandomState(seed=seed) 53 | if runfiles_path: 54 | deepmind_lab.set_runfiles_path(runfiles_path) 55 | config = {k: str(v) for k, v in config.items()} 56 | self._observation_spec = ['RGB_INTERLEAVED'] 57 | self._env = deepmind_lab.Lab( 58 | level=level, 59 | observations=self._observation_spec, 60 | config=config, 61 | level_cache=level_cache, 62 | ) 63 | self._obs_buffer = np.zeros((2,)+self.obs_shape, dtype=np.uint8) 64 | 65 | # Minimum required (see structure/expected return object): 66 | # def reset(self): 67 | # print("reset called") 68 | # return np.ones((4, 84, 84), dtype=np.uint8) 69 | 70 | # def step(self, action): 71 | # frame = np.zeros((4, 84, 84), dtype=np.uint8) 72 | # return frame, 0.0, False, {} # First three mandatory. 73 | 74 | def reset(self): 75 | # return observation! 76 | self._env.reset(seed=self._random_state.randint(0, 2 ** 31 - 1)) 77 | d = self._env.observations() 78 | return d['RGB_INTERLEAVED'].transpose(2, 0, 1) 79 | 80 | def get_obs(self): 81 | d = self._env.observations() 82 | return d['RGB_INTERLEAVED'].transpose(2, 0, 1) 83 | 84 | # atari style max pooling-based frame skipping 85 | def step(self, action): 86 | # `action` is an index here 87 | action_code = DEFAULT_ACTION_SET[action] 88 | total_reward = 0.0 89 | done = None 90 | 91 | for i in range(self._skip): 92 | reward = self._env.step(action_code) 93 | total_reward += reward 94 | done = not self._env.is_running() 95 | if done: 96 | max_frame = self.reset() 97 | return max_frame, total_reward, done, {} 98 | 99 | obs = np.array(self.get_obs(),dtype=np.uint8) 100 | if i == self._skip - 2: self._obs_buffer[0] = obs 101 | if i == self._skip - 1: self._obs_buffer[1] = obs 102 | 103 | max_frame = self._obs_buffer.max(axis=0) 104 | 105 | return max_frame, total_reward, done, {} 106 | 107 | # def step(self, action): 108 | # # `action` is an index here 109 | # action_code = DEFAULT_ACTION_SET[action] 110 | # reward = self._env.step(action_code) 111 | # 112 | # done = not self._env.is_running() 113 | # if done: 114 | # self.reset() 115 | # 116 | # observation = np.array(self.get_obs(),dtype=np.uint8) 117 | # return observation, reward, done, {} 118 | 119 | def close(self): 120 | self._env.close() 121 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/polybeast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import multiprocessing as mp 17 | 18 | import numpy as np 19 | 20 | from torchbeast_dmlab import polybeast_learner 21 | from torchbeast_dmlab import polybeast_env 22 | 23 | 24 | def run_env(flags, actor_id): 25 | np.random.seed() # Get new random seed in forked process. 26 | polybeast_env.main(flags) 27 | 28 | 29 | def run_learner(flags): 30 | polybeast_learner.main(flags) 31 | 32 | 33 | def main(): 34 | flags = argparse.Namespace() 35 | flags, argv = polybeast_learner.parser.parse_known_args(namespace=flags) 36 | flags, argv = polybeast_env.parser.parse_known_args(args=argv, namespace=flags) 37 | if argv: 38 | # Produce an error message. 39 | polybeast_learner.parser.print_usage() 40 | print("") 41 | polybeast_env.parser.print_usage() 42 | print("Unkown args:", " ".join(argv)) 43 | return -1 44 | 45 | flags.num_servers = flags.num_actors 46 | env_processes = [] 47 | for actor_id in range(1): 48 | p = mp.Process(target=run_env, args=(flags, actor_id)) 49 | p.start() 50 | env_processes.append(p) 51 | 52 | run_learner(flags) 53 | 54 | for p in env_processes: 55 | p.join() 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_dmlab/polybeast_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import argparse 17 | import multiprocessing as mp 18 | import threading 19 | import time 20 | 21 | import numpy as np 22 | import libtorchbeast 23 | from torchbeast_dmlab import dmlab_wrappers 24 | 25 | from torchbeast_dmlab import atari_wrappers 26 | 27 | 28 | # yapf: disable 29 | parser = argparse.ArgumentParser(description='Remote Environment Server') 30 | 31 | parser.add_argument("--pipes_basename", default="unix:/tmp/polybeast", 32 | help="Basename for the pipes for inter-process communication. " 33 | "Has to be of the type unix:/some/path.") 34 | parser.add_argument('--num_servers', default=4, type=int, metavar='N', 35 | help='Number of environment servers.') 36 | parser.add_argument('--env', type=str, default='psychlab_arbitrary_visuomotor_mapping', 37 | help='DMlab environment.') 38 | parser.add_argument('--seed', default=1, type=int, metavar='N', 39 | help='seed.') 40 | parser.add_argument('--allow_oov', action="store_true", 41 | help='Allow action space larger than the env specific one.' 42 | ' All out-of-vocab action will be mapped to NoOp.') 43 | # yapf: enable 44 | 45 | 46 | class Env: 47 | def reset(self): 48 | print("reset called") 49 | return np.ones((4, 84, 84), dtype=np.uint8) 50 | 51 | def step(self, action): 52 | frame = np.zeros((4, 84, 84), dtype=np.uint8) 53 | return frame, 0.0, False, {} # First three mandatory. 54 | 55 | 56 | def create_env(env_name, seed=1, lock=threading.Lock()): 57 | level_name = 'contributed/dmlab30/' + env_name 58 | config = { 59 | 'width': 96, 60 | 'height': 72, 61 | 'logLevel': 'WARN', 62 | } 63 | with lock: 64 | return dmlab_wrappers.create_env_dmlab(level_name, config, seed) 65 | 66 | 67 | def create_test_env(env_name, seed=1): 68 | level_name = 'contributed/dmlab30/' + env_name 69 | config = { 70 | 'width': 96, 71 | 'height': 72, 72 | 'logLevel': 'WARN', 73 | } 74 | return dmlab_wrappers.create_env_dmlab(level_name, config, seed) 75 | 76 | 77 | def serve(env_name, server_address, seed=1): 78 | init = Env if env_name == "Mock" else lambda: create_env(env_name, seed) 79 | server = libtorchbeast.Server(init, server_address=server_address) 80 | server.run() 81 | 82 | 83 | def main(flags): 84 | if not flags.pipes_basename.startswith("unix:"): 85 | raise Exception("--pipes_basename has to be of the form unix:/some/path.") 86 | 87 | processes = [] 88 | for i in range(flags.num_servers): 89 | p = mp.Process( 90 | target=serve, args=(flags.env, f"{flags.pipes_basename}.{i}", flags.seed), daemon=True 91 | ) 92 | p.start() 93 | processes.append(p) 94 | 95 | try: 96 | # We are only here to listen to the interrupt. 97 | while True: 98 | time.sleep(10) 99 | except KeyboardInterrupt: 100 | pass 101 | 102 | 103 | if __name__ == "__main__": 104 | flags = parser.parse_args() 105 | print(f"Env: {flags.env}") 106 | main(flags) 107 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_procgen/polybeast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import multiprocessing as mp 17 | 18 | import numpy as np 19 | 20 | from torchbeast_procgen import polybeast_learner 21 | from torchbeast_procgen import polybeast_env 22 | 23 | 24 | def run_env(flags, actor_id): 25 | np.random.seed() # Get new random seed in forked process. 26 | polybeast_env.main(flags) 27 | 28 | 29 | def run_learner(flags): 30 | polybeast_learner.main(flags) 31 | 32 | 33 | def main(): 34 | flags = argparse.Namespace() 35 | flags, argv = polybeast_learner.parser.parse_known_args(namespace=flags) 36 | flags, argv = polybeast_env.parser.parse_known_args(args=argv, namespace=flags) 37 | if argv: 38 | # Produce an error message. 39 | polybeast_learner.parser.print_usage() 40 | print("") 41 | polybeast_env.parser.print_usage() 42 | print("Unkown args:", " ".join(argv)) 43 | return -1 44 | 45 | flags.num_servers = flags.num_actors 46 | env_processes = [] 47 | for actor_id in range(1): 48 | p = mp.Process(target=run_env, args=(flags, actor_id)) 49 | p.start() 50 | env_processes.append(p) 51 | 52 | run_learner(flags) 53 | 54 | for p in env_processes: 55 | # p.terminate() 56 | p.join() 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /reinforcement_learning/torchbeast_procgen/polybeast_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import argparse 17 | import multiprocessing as mp 18 | import threading 19 | import time 20 | 21 | import numpy as np 22 | import libtorchbeast 23 | from torchbeast_procgen import procgen_wrappers 24 | import gym 25 | 26 | 27 | list_procgen_env_lex = [ 28 | 'bigfish', 29 | 'bossfight', 30 | 'caveflyer', 31 | 'chaser', 32 | 'climber', 33 | 'coinrun', 34 | 'dodgeball', 35 | 'fruitbot', 36 | 'heist', 37 | 'jumper', 38 | 'leaper', 39 | 'maze', 40 | 'miner', 41 | 'ninja', 42 | 'plunder', 43 | 'starpilot', 44 | ] 45 | 46 | # interesting first 47 | list_procgen_env = [ 48 | 'bigfish', 49 | 'fruitbot', 50 | 'maze', 51 | 'leaper', 52 | 'plunder', 53 | 'starpilot', 54 | 'miner', 55 | 'bossfight', 56 | 'caveflyer', 57 | 'chaser', 58 | 'climber', 59 | 'coinrun', 60 | 'dodgeball', 61 | 'heist', 62 | 'jumper', 63 | 'ninja', 64 | ] 65 | 66 | # env with memory mode extension 67 | list_procgen_env_mem = [ 68 | 'dodgeball', 69 | 'heist', 70 | 'maze', 71 | 'miner', 72 | 'caveflyer', 73 | 'jumper', 74 | ] 75 | 76 | 77 | # yapf: disable 78 | parser = argparse.ArgumentParser(description='Remote Environment Server') 79 | 80 | parser.add_argument("--pipes_basename", default="unix:/tmp/polybeast", 81 | help="Basename for the pipes for inter-process communication. " 82 | "Has to be of the type unix:/some/path.") 83 | parser.add_argument('--num_servers', default=4, type=int, metavar='N', 84 | help='Number of environment servers.') 85 | parser.add_argument('--env', type=str, default='PongNoFrameskip-v4', 86 | help='Gym environment.') 87 | parser.add_argument('--multi_env', default=1, type=int, metavar='N', 88 | help='number of env to jointly train on.') 89 | parser.add_argument('--allow_oov', action="store_true", 90 | help='Allow action space larger than the env specific one.' 91 | ' All out-of-vocab action will be mapped to NoOp.') 92 | parser.add_argument('--num_levels', default=0, type=int, metavar='N', 93 | help='Procgen num_levels.') 94 | parser.add_argument('--start_level', default=0, type=int, metavar='N', 95 | help='Procgen start_level.') 96 | parser.add_argument('--distribution_mode', type=str, default='hard', 97 | choices=[ 98 | 'easy', 'hard', 'extreme', 'memory', 'exploration'], 99 | help='distribution mode.') 100 | # yapf: enable 101 | 102 | 103 | class Env: 104 | def reset(self): 105 | print("reset called") 106 | return np.ones((4, 84, 84), dtype=np.uint8) 107 | 108 | def step(self, action): 109 | frame = np.zeros((4, 84, 84), dtype=np.uint8) 110 | return frame, 0.0, False, {} # First three mandatory. 111 | 112 | 113 | def create_env(env_name, num_levels=0, start_level=0, distribution_mode="hard", 114 | rand_seed=None, lock=threading.Lock()): 115 | with lock: # Atari isn't threadsafe at construction time. 116 | return procgen_wrappers.wrap_pytorch( 117 | procgen_wrappers.wrap_deepmind( 118 | gym.make(env_name, 119 | num_levels=num_levels, 120 | start_level=start_level, 121 | distribution_mode=distribution_mode, 122 | rand_seed=rand_seed), 123 | clip_rewards=False, 124 | ) 125 | ) 126 | 127 | 128 | def serve(env_name, num_levels, start_level, distribution_mode, 129 | server_address): 130 | init = Env if env_name == "Mock" else lambda: create_env( 131 | env_name, num_levels, start_level, distribution_mode) 132 | server = libtorchbeast.Server(init, server_address=server_address) 133 | server.run() 134 | 135 | 136 | def main(flags): 137 | if not flags.pipes_basename.startswith("unix:"): 138 | raise Exception( 139 | "--pipes_basename has to be of the form unix:/some/path.") 140 | 141 | if flags.distribution_mode == 'memory': # for multi_env training 142 | list_env = list_procgen_env_mem 143 | else: 144 | list_env = list_procgen_env 145 | 146 | processes = [] 147 | for i in range(flags.num_servers): 148 | if flags.multi_env > 1: 149 | env_name = list_env[i % flags.multi_env] 150 | env_name = f"procgen:procgen-{env_name}-v0" 151 | else: 152 | env_name = flags.env 153 | print(f"Server {i} on {env_name}") 154 | # distributed mode and rand_seed left to default. 155 | p = mp.Process( 156 | target=serve, args=( 157 | env_name, flags.num_levels, flags.start_level, 158 | flags.distribution_mode, f"{flags.pipes_basename}.{i}"), 159 | daemon=True 160 | ) 161 | p.start() 162 | processes.append(p) 163 | 164 | try: 165 | # We are only here to listen to the interrupt. 166 | while True: 167 | time.sleep(10) 168 | except KeyboardInterrupt: 169 | pass 170 | 171 | 172 | if __name__ == "__main__": 173 | flags = parser.parse_args() 174 | print(f"Env: {flags.env}") 175 | main(flags) 176 | --------------------------------------------------------------------------------