├── .gitignore ├── README.md ├── act ├── __init__.py ├── configuration.py ├── data.py ├── models.py ├── test_models.py ├── train.py └── utils.py └── run_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .cache 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Computation Time 2 | 3 | This is an implementation of [Adaptive Computation Time](https://arxiv.org/abs/1603.08983) (Graves, 2016) in PyTorch. 4 | 5 | ### Introduction 6 | 7 | *Adaptive Computation Time* is a drop-in replacement for RNN structures that allows the model to perform a variable number of computation ("ponder") steps on a single input token. More information can be found in the paper, or in this [blog post](http://jasonphang.com/adaptive-computation-time.html). 8 | 9 | ### Requirements 10 | 11 | * Python 3.6 12 | * PyTorch 0.3.0 13 | * `numpy`, `attrs`, `matplotlib`, `argparse` 14 | 15 | ### Experiments 16 | 17 | I am still in the process of replicating the experiments described in the paper. 18 | 19 | - [x] Bit Parity 20 | - [x] Logical Gates 21 | - [ ] Addition 22 | - [ ] Sorting 23 | - [ ] Word Prediction 24 | 25 | ### Usage 26 | 27 | 1. Git clone this repository 28 | 2. Train/Evaluate the model on a given task/parameter setting: 29 | * E.g.
30 | 31 | ```bash 32 | python run_train.py \ 33 | --task=parity \ 34 | --use_act=False \ 35 | --model_save_path="outputs/models/parity/rnn" 36 | ``` 37 | 38 | ```bash 39 | python run_train.py \ 40 | --task=parity \ 41 | --use_act=True \ 42 | --act_ponder_penalty=0.001 \ 43 | --model_save_path="outputs/models/parity/act_0.001" 44 | ``` -------------------------------------------------------------------------------- /act/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zphang/adaptive-computation-time-pytorch/938bb797eb7ba095afd3e4df7788fc74b736636b/act/__init__.py -------------------------------------------------------------------------------- /act/configuration.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import torch 3 | import argparse 4 | 5 | 6 | def argparse_attr(default=attr.NOTHING, validator=None, 7 | repr=True, cmp=True, hash=True, init=True, 8 | convert=None, opt_string=None, 9 | **argparse_kwargs): 10 | if opt_string is None: 11 | opt_string_ls = [] 12 | elif isinstance(opt_string, str): 13 | opt_string_ls = [opt_string] 14 | else: 15 | opt_string_ls = opt_string 16 | 17 | if argparse_kwargs.get("type", None) is bool: 18 | argparse_kwargs["choices"] = {True, False} 19 | argparse_kwargs["type"] = _is_true 20 | 21 | return attr.attr( 22 | default=default, 23 | validator=validator, 24 | repr=repr, 25 | cmp=cmp, 26 | hash=hash, 27 | init=init, 28 | convert=convert, 29 | metadata={ 30 | "opt_string_ls": opt_string_ls, 31 | "argparse_kwargs": argparse_kwargs, 32 | } 33 | ) 34 | 35 | 36 | def update_parser(parser, class_with_attributes): 37 | for attribute in class_with_attributes.__attrs_attrs__: 38 | if "argparse_kwargs" in attribute.metadata: 39 | argparse_kwargs = attribute.metadata["argparse_kwargs"] 40 | opt_string_ls = attribute.metadata["opt_string_ls"] 41 | if attribute.default is attr.NOTHING: 42 | argparse_kwargs = 
argparse_kwargs.copy() 43 | argparse_kwargs["required"] = True 44 | else: 45 | argparse_kwargs["default"] = attribute.default 46 | parser.add_argument( 47 | f"--{attribute.name}", *opt_string_ls, 48 | **argparse_kwargs 49 | ) 50 | 51 | 52 | def read_parser(parser, class_with_attributes, skip_non_class_attributes=False): 53 | attribute_name_set = { 54 | attribute.name 55 | for attribute in class_with_attributes.__attrs_attrs__ 56 | } 57 | 58 | kwargs = dict() 59 | leftover_kwargs = dict() 60 | 61 | for k, v in vars(parser.parse_args()).items(): 62 | if k in attribute_name_set: 63 | kwargs[k] = v 64 | else: 65 | if not skip_non_class_attributes: 66 | raise RuntimeError(f"Unknown attribute {k}") 67 | leftover_kwargs[k] = v 68 | 69 | instance = class_with_attributes(**kwargs) 70 | if skip_non_class_attributes: 71 | return instance, leftover_kwargs 72 | else: 73 | return instance 74 | 75 | 76 | def _is_true(x): 77 | return x == "True" 78 | 79 | 80 | @attr.s 81 | class Config: 82 | 83 | # Global configuration 84 | cuda = argparse_attr( 85 | default=torch.has_cudnn, type=bool, 86 | help="Whether to use cuda", 87 | ) 88 | seed = argparse_attr( 89 | default=1234, type=int, 90 | help="Seed", 91 | ) 92 | batch_size = argparse_attr( 93 | default=128 * 16, type=int, 94 | help="Batch size for model", 95 | ) 96 | 97 | # ACT configuration 98 | use_act = argparse_attr( 99 | default=True, type=bool, 100 | help="Whether to use ACT", 101 | ) 102 | act_max_ponder = argparse_attr( 103 | default=100, type=int, 104 | help="Maximum number of ponder steps", 105 | ) 106 | act_epsilon = argparse_attr( 107 | default=0.01, type=int, 108 | help="Epsilon margin for halting", 109 | ) 110 | act_ponder_penalty = argparse_attr( 111 | default=0.0001, type=float, 112 | help="Weight for ponder cost", 113 | ) 114 | 115 | # Task configuration 116 | task = argparse_attr( 117 | default=None, type=str, 118 | help="Experiment Task (parity|logic)", 119 | ) 120 | 121 | # Train configuration 122 | learning_rate 
= argparse_attr( 123 | default=0.1 ** 4 * 16, type=float, 124 | help="Learning rate", 125 | ) 126 | train_log = argparse_attr( 127 | default=True, type=bool, 128 | help="Whether to have verbose training logs", 129 | ) 130 | train_log_interval = argparse_attr( 131 | default=10, type=int, 132 | help="How often to output training log messages", 133 | ) 134 | num_epochs = argparse_attr( 135 | default=64, type=int, 136 | help="Number of training epochs", 137 | ) 138 | 139 | # Test configuration 140 | test_percentage = argparse_attr( 141 | default=0.1 / 16, type=float, 142 | help="Size of test set, as percentage of training set. " 143 | "For synthetic tasks", 144 | ) 145 | 146 | # Task: Parity configuration 147 | parity_data_len = argparse_attr( 148 | default=100000 * 16, type=int, 149 | help="Samples in training epoch", 150 | ) 151 | parity_input_size = argparse_attr( 152 | default=64, type=int, 153 | help="Size of parity input", 154 | ) 155 | parity_rnn_size = argparse_attr( 156 | default=128, type=int, 157 | help="Hidden size of RNN", 158 | ) 159 | parity_rnn_type = argparse_attr( 160 | default="RNN", type=str, 161 | help="RNN type (RNN/LSTM)", 162 | ) 163 | 164 | # Task: Logic configuration 165 | logic_data_len = argparse_attr( 166 | default=10000 * 16, type=int, 167 | help="Samples in training epoch", 168 | ) 169 | logic_input_size = argparse_attr( 170 | default=102, type=int, 171 | help="Size of logic input", 172 | ) 173 | logic_rnn_size = argparse_attr( 174 | default=128, type=int, 175 | help="Hidden size of RNN", 176 | ) 177 | logic_rnn_type = argparse_attr( 178 | default="LSTM", type=str, 179 | help="RNN type (RNN/LSTM)", 180 | ) 181 | 182 | # Saving 183 | model_save_path = argparse_attr( 184 | default=None, type=str, 185 | help="Folder to save models", 186 | required=True, 187 | ) 188 | model_save_interval = argparse_attr( 189 | default=1, type=int, 190 | help="Epoch intervals to save model" 191 | ) 192 | 193 | @classmethod 194 | def parse_configuration(cls, 
prog=None, description=None): 195 | parser = argparse.ArgumentParser( 196 | prog=prog, 197 | description=description, 198 | ) 199 | update_parser( 200 | parser=parser, 201 | class_with_attributes=cls, 202 | ) 203 | return read_parser( 204 | parser=parser, 205 | class_with_attributes=cls, 206 | ) 207 | -------------------------------------------------------------------------------- /act/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | 5 | 6 | class MultiDataset(torch.utils.data.Dataset): 7 | def __init__(self, *data_list): 8 | assert len(data_list) > 0 9 | self.data_length = len(data_list[0]) 10 | for data in data_list[1:]: 11 | assert len(data) == self.data_length 12 | self.data_list = data_list 13 | 14 | def __getitem__(self, index): 15 | return [ 16 | data[index] 17 | for data in self.data_list 18 | ] 19 | 20 | def __len__(self): 21 | return self.data_length 22 | 23 | 24 | class DataManager: 25 | @classmethod 26 | def create_data(cls, *args, **kwargs): 27 | raise NotImplementedError 28 | 29 | @classmethod 30 | def _get_length(cls, config): 31 | raise NotImplementedError 32 | 33 | @classmethod 34 | def _get_dataloader(cls, data, batch_size): 35 | data_x, data_y = data 36 | return torch.utils.data.DataLoader( 37 | MultiDataset(data_x, data_y), 38 | batch_size=batch_size, 39 | shuffle=True, 40 | ) 41 | 42 | @classmethod 43 | def create_dataloader(cls, config, mode="train"): 44 | length = cls._get_length(config) 45 | if mode == "train": 46 | pass 47 | elif mode == "test": 48 | length = int(config.test_percentage * length) 49 | else: 50 | raise KeyError(mode) 51 | data = cls.create_data(length=length) 52 | return cls._get_dataloader(data=data, batch_size=config.batch_size) 53 | 54 | 55 | class ParityDataManager(DataManager): 56 | @classmethod 57 | def create_data(cls, length): 58 | parity_x = np.random.randint(2, size=(length, 64)).astype( 59 | np.float32) 
* 2 - 1 60 | zero_out = np.random.randint(1, 64, size=length) 61 | for i in range(length): 62 | parity_x[i, zero_out[i]:] = 0. 63 | parity_y = (np.sum(parity_x == 1, axis=1) % 2).astype(np.float32) 64 | return np.expand_dims(parity_x, 1), parity_y 65 | 66 | @classmethod 67 | def _get_length(cls, config): 68 | return config.parity_data_len 69 | 70 | 71 | class LogicDataManager(DataManager): 72 | LOGIC_TABLE = np.array([ 73 | [[1, 0], [0, 0]], # NOR 74 | [[0, 1], [0, 0]], # Xq 75 | [[0, 0], [1, 0]], # ABJ 76 | [[0, 1], [1, 0]], # XOR 77 | [[1, 1], [1, 0]], # NAND 78 | [[0, 0], [0, 1]], # AND 79 | [[1, 0], [0, 1]], # XNOR 80 | [[1, 1], [0, 1]], # if/then 81 | [[1, 0], [1, 1]], # then/if 82 | [[0, 1], [1, 1]], # OR 83 | ]) 84 | 85 | @classmethod 86 | def create_data(cls, length): 87 | p_and_q = np.random.randint(2, size=(length, 10, 2)) 88 | p_and_q[:, 1:, 1] = 0 89 | 90 | operations = np.random.randint(0, 10, size=(length, 10, 10)) 91 | num_operations = np.random.randint(1, 11, size=(length, 10)) 92 | for i in range(length): 93 | for t in range(10): 94 | operations[i, t, num_operations[i, t]:] = -1 95 | one_hot_operations = np.zeros((length, 10, 100)) 96 | 97 | logic_y = np.empty(shape=(length, 10)) 98 | for row_index, (row_p_and_q, row_operations) in enumerate( 99 | zip(p_and_q, operations)): 100 | b_0 = row_p_and_q[0, 0] 101 | for t in range(10): 102 | for op_index, operation in enumerate(row_operations[t]): 103 | if operation == -1: 104 | break 105 | one_hot_operations[ 106 | row_index, t, op_index * 10 + operation] = 1 107 | 108 | result = cls._resolve_logic(b_0, row_p_and_q[t, 0], 109 | row_operations[t]) 110 | logic_y[row_index, t] = result 111 | b_0 = result 112 | 113 | logic_x = np.concatenate([p_and_q, one_hot_operations], axis=2) 114 | return ( 115 | logic_x.astype(np.float32), 116 | np.expand_dims(logic_y.astype(np.float32), 2), 117 | ) 118 | 119 | @classmethod 120 | def _get_length(cls, config): 121 | return config.logic_data_len 122 | 123 | @classmethod 
124 | def _resolve_logic(cls, p, q, op_list): 125 | for op in op_list: 126 | if op == -1: 127 | break 128 | p, q = q, cls.LOGIC_TABLE[op][p][q] 129 | return q 130 | 131 | 132 | def resolve_data_manager(config): 133 | if config.task == "parity": 134 | return ParityDataManager 135 | elif config.task == "logic": 136 | return LogicDataManager 137 | else: 138 | raise KeyError(config.task) 139 | -------------------------------------------------------------------------------- /act/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class ParityRNNModel(nn.Module): 8 | IS_ACT = False 9 | 10 | def __init__(self, config): 11 | super(ParityRNNModel, self).__init__() 12 | self.rnn_size = config.parity_rnn_size 13 | if config.parity_rnn_type == "RNN": 14 | rnn_class = nn.RNN 15 | elif config.parity_rnn_type == "LSTM": 16 | rnn_class = nn.LSTM 17 | else: 18 | raise KeyError("rnn_class") 19 | 20 | self.rnn = rnn_class( 21 | input_size=config.parity_input_size, 22 | hidden_size=self.rnn_size, 23 | batch_first=True, 24 | ) 25 | self.fc1 = nn.Linear(self.rnn_size, 1) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | self.rnn.reset_parameters() 30 | self.fc1.reset_parameters() 31 | 32 | def forward(self, x, compute_ponder_cost=True): 33 | rnn_out, hidden = self.rnn(x) 34 | return self.fc1(rnn_out).squeeze(1).squeeze(1), None 35 | 36 | 37 | class ParityACTModel(nn.Module): 38 | IS_ACT = True 39 | 40 | def __init__(self, config): 41 | super(ParityACTModel, self).__init__() 42 | self.rnn_size = config.parity_rnn_size 43 | if config.parity_rnn_type == "RNN": 44 | rnn_cell_class = nn.RNNCell 45 | elif config.parity_rnn_type == "LSTM": 46 | rnn_cell_class = nn.LSTMCell 47 | else: 48 | raise KeyError("rnn_class") 49 | rnn_cell = rnn_cell_class( 50 | input_size=config.parity_input_size + 1, # +1 for 
flag 51 | hidden_size=self.rnn_size, 52 | ) 53 | self.rnn = ACTFromCell( 54 | rnn_cell=rnn_cell, 55 | max_ponder=config.act_max_ponder, 56 | epsilon=config.act_epsilon, 57 | batch_first=True, 58 | ) 59 | self.fc1 = nn.Linear(self.rnn_size, 1) 60 | self.reset_parameters() 61 | 62 | def reset_parameters(self): 63 | self.rnn.reset_parameters() 64 | self.fc1.reset_parameters() 65 | 66 | def forward(self, x, compute_ponder_cost=True): 67 | rnn_out, hidden, ponder_cost = self.rnn( 68 | input_=x, 69 | compute_ponder_cost=compute_ponder_cost, 70 | ) 71 | return self.fc1(rnn_out).squeeze(1).squeeze(1), ponder_cost 72 | 73 | 74 | class LogicRNNModel(nn.Module): 75 | IS_ACT = False 76 | 77 | def __init__(self, config): 78 | super(LogicRNNModel, self).__init__() 79 | self.rnn_size = config.logic_rnn_size 80 | if config.logic_rnn_type == "RNN": 81 | rnn_class = nn.RNN 82 | elif config.logic_rnn_type == "LSTM": 83 | rnn_class = nn.LSTM 84 | else: 85 | raise KeyError("rnn_class") 86 | 87 | self.rnn = rnn_class( 88 | input_size=config.logic_input_size, 89 | hidden_size=self.rnn_size, 90 | batch_first=True, 91 | ) 92 | self.fc1 = nn.Linear(self.rnn_size, 1) 93 | self.reset_parameters() 94 | 95 | def reset_parameters(self): 96 | self.rnn.reset_parameters() 97 | self.fc1.reset_parameters() 98 | 99 | def forward(self, x, compute_ponder_cost=True): 100 | rnn_out, hidden = self.rnn(x) 101 | return self.fc1(rnn_out).squeeze(1).squeeze(1), None 102 | 103 | 104 | class LogicACTModel(nn.Module): 105 | IS_ACT = True 106 | 107 | def __init__(self, config): 108 | super(LogicACTModel, self).__init__() 109 | self.rnn_size = config.logic_rnn_size 110 | if config.logic_rnn_type == "RNN": 111 | rnn_cell_class = nn.RNNCell 112 | elif config.logic_rnn_type == "LSTM": 113 | rnn_cell_class = nn.LSTMCell 114 | else: 115 | raise KeyError("rnn_class") 116 | rnn_cell = rnn_cell_class( 117 | input_size=config.logic_input_size + 1, # +1 for flag 118 | hidden_size=self.rnn_size, 119 | ) 120 | self.rnn = 
ACTFromCell( 121 | rnn_cell=rnn_cell, 122 | max_ponder=config.act_max_ponder, 123 | epsilon=config.act_epsilon, 124 | batch_first=True, 125 | ) 126 | self.fc1 = nn.Linear(self.rnn_size, 1) 127 | self.reset_parameters() 128 | 129 | def reset_parameters(self): 130 | self.rnn.reset_parameters() 131 | self.fc1.reset_parameters() 132 | 133 | def forward(self, x, compute_ponder_cost=True): 134 | rnn_out, hidden, ponder_cost = self.rnn( 135 | input_=x, 136 | compute_ponder_cost=compute_ponder_cost, 137 | ) 138 | return self.fc1(rnn_out).squeeze(1).squeeze(1), ponder_cost 139 | 140 | 141 | class RNNFromCell(nn.Module): 142 | def __init__(self, rnn_cell, batch_first=False): 143 | super(RNNFromCell, self).__init__() 144 | self.rnn_cell = rnn_cell 145 | self.batch_first = batch_first 146 | self._is_lstm = isinstance(self.rnn_cell, nn.LSTMCell) 147 | 148 | def forward(self, input_, hx=None): 149 | if self.batch_first: 150 | # Move batch to second 151 | input_ = input_.transpose(0, 1) 152 | if hx is None: 153 | hx = Variable(input_.data.new( 154 | input_.size(1), self.rnn_cell.hidden_size 155 | ).zero_()) 156 | if self._is_lstm: 157 | hx = [hx, hx] 158 | 159 | hx_list = [] 160 | for input_row in input_: 161 | hx = self.rnn_cell(input_row, hx) 162 | hx_list.append(hx) 163 | 164 | if self._is_lstm: 165 | all_hx = [ 166 | torch.stack([_[0] for _ in hx_list]), 167 | torch.stack([_[1] for _ in hx_list]), 168 | ] 169 | hx = hx[0].unsqueeze(0), hx[1].unsqueeze(1) 170 | else: 171 | all_hx = torch.stack(hx_list) 172 | hx = hx.unsqueeze(0) 173 | 174 | if self.batch_first: 175 | # Move batch to first 176 | if self._is_lstm: 177 | all_hx = all_hx[0].transpose(0, 1) 178 | else: 179 | all_hx = all_hx.transpose(0, 1) 180 | 181 | return all_hx, hx 182 | 183 | 184 | class ACTFromCell(nn.Module): 185 | def __init__(self, rnn_cell, max_ponder=100, epsilon=0.01, 186 | batch_first=False): 187 | super(ACTFromCell, self).__init__() 188 | self.rnn_cell = rnn_cell 189 | self.batch_first = batch_first 
190 | self.max_ponder = max_ponder 191 | self.epsilon = epsilon 192 | 193 | self._is_lstm = isinstance(self.rnn_cell, nn.LSTMCell) 194 | self.ponder_linear = nn.Linear(self.rnn_cell.hidden_size, 1) 195 | 196 | def forward(self, input_, hx=None, compute_ponder_cost=True): 197 | if self.batch_first: 198 | # Move batch to second 199 | input_ = input_.transpose(0, 1) 200 | if hx is None: 201 | hx = Variable(input_.data.new( 202 | input_.size(1), self.rnn_cell.hidden_size 203 | ).zero_()) 204 | if self._is_lstm: 205 | cx = hx 206 | 207 | # Pre-allocate variables 208 | time_size, batch_size, input_dim_size = input_.size() 209 | selector = input_.data.new(batch_size).byte() 210 | hx_list, cx_list = [], [] 211 | ponder_cost = 0 212 | ponder_times = [] 213 | 214 | # For each t 215 | for input_row in input_: 216 | 217 | accum_h = Variable(input_.data.new(batch_size).zero_()) 218 | accum_hx = Variable(input_.data.new( 219 | input_.size(1), self.rnn_cell.hidden_size 220 | ).zero_()) 221 | if self._is_lstm: 222 | accum_cx = Variable(input_.data.new( 223 | input_.size(1), self.rnn_cell.hidden_size 224 | ).zero_()) 225 | selector = selector.fill_(1) 226 | 227 | if self._is_lstm: 228 | accum_cx = accum_cx.zero_() 229 | step_count = Variable(input_.data.new(batch_size).zero_()) 230 | input_row_with_flag = torch.cat([ 231 | input_row, 232 | Variable(input_row.data.new(batch_size, 1).zero_()) 233 | ], dim=1) 234 | if compute_ponder_cost: 235 | step_ponder_cost = Variable(input_.data.new(batch_size).zero_()) 236 | 237 | for act_step in range(self.max_ponder): 238 | idx = bool_to_idx(selector) 239 | if compute_ponder_cost: 240 | # Weird but matches formulation 241 | step_ponder_cost[idx] = -accum_h[idx] 242 | 243 | if self._is_lstm: 244 | hx[idx], cx[idx] = self.rnn_cell( 245 | input_row_with_flag[idx], (hx[idx], cx[idx])) 246 | else: 247 | hx[idx] = self.rnn_cell(input_row_with_flag[idx], hx[idx]) 248 | accum_hx[idx] += hx[idx] 249 | h = 
F.sigmoid(self.ponder_linear(hx[idx]).squeeze(1)) 250 | accum_h[idx] += h 251 | p = h - (accum_h[idx] - 1).clamp(min=0) 252 | accum_hx[idx] += p.unsqueeze(1) * hx[idx] 253 | if self._is_lstm: 254 | accum_cx[idx] += p.unsqueeze(1) * cx[idx] 255 | step_count[idx] += 1 256 | selector = (accum_h < 1 - self.epsilon).data 257 | if not selector.any(): 258 | break 259 | input_row_with_flag[:, input_dim_size] = 1 260 | 261 | ponder_times.append(step_count.data.cpu().numpy()) 262 | if compute_ponder_cost: 263 | ponder_cost += step_ponder_cost 264 | 265 | hx = accum_hx / step_count.clone().unsqueeze(1) 266 | hx_list.append(hx) 267 | 268 | if self._is_lstm: 269 | cx = accum_cx / step_count.clone().unsqueeze(1) 270 | cx_list.append(cx) 271 | 272 | if self._is_lstm: 273 | all_hx = [ 274 | torch.stack(hx_list), 275 | torch.stack(cx_list), 276 | ] 277 | hx, cx = hx.unsqueeze(0), cx.unsqueeze(1) 278 | else: 279 | all_hx = torch.stack(hx_list) 280 | hx = hx.unsqueeze(0) 281 | 282 | if self.batch_first: 283 | # Move batch to first 284 | if self._is_lstm: 285 | all_hx = all_hx[0].transpose(0, 1) 286 | else: 287 | all_hx = all_hx.transpose(0, 1) 288 | 289 | return all_hx, hx, { 290 | "ponder_cost": ponder_cost, 291 | "ponder_times": ponder_times, 292 | } 293 | 294 | def reset_parameters(self): 295 | self.rnn_cell.reset_parameters() 296 | self.ponder_linear.reset_parameters() 297 | self.ponder_linear.bias.data.fill_(1) 298 | 299 | 300 | def bool_to_idx(idx): 301 | return idx.nonzero().squeeze(1) 302 | 303 | 304 | def resolve_model(config): 305 | if config.use_act: 306 | model_class = ParityACTModel 307 | else: 308 | model_class = ParityRNNModel 309 | 310 | model = model_class(config) 311 | if config.cuda: 312 | model = model.cuda() 313 | return model 314 | -------------------------------------------------------------------------------- /act/test_models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import 
torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import torch 5 | import numpy as np 6 | 7 | import models 8 | 9 | 10 | def test_rnn_and_rnn_from_cell_match(): 11 | i_size = 5 12 | h_size = 6 13 | 14 | x = np.random.normal(size=(1, 10, i_size)) 15 | x_var = Variable(torch.Tensor(x)) 16 | 17 | rnn1 = nn.RNN( 18 | input_size=i_size, 19 | hidden_size=h_size, 20 | batch_first=True, 21 | ) 22 | rnn2_cell = nn.RNNCell( 23 | input_size=i_size, 24 | hidden_size=h_size, 25 | ) 26 | _copy_rnn_params(rnn1, rnn2_cell) 27 | rnn2 = models.RNNFromCell(rnn2_cell, batch_first=True) 28 | 29 | outputs1, hiddens1 = rnn1(x_var) 30 | outputs2, hiddens2 = rnn2(x_var) 31 | 32 | assert outputs1.equal(outputs2) 33 | assert hiddens1.equal(hiddens2) 34 | 35 | 36 | def test_lstm_and_lstm_from_cell_match(): 37 | i_size = 5 38 | h_size = 6 39 | 40 | x = np.random.normal(size=(1, 10, i_size)) 41 | x_var = Variable(torch.Tensor(x)) 42 | 43 | lstm1 = nn.LSTM( 44 | input_size=i_size, 45 | hidden_size=h_size, 46 | batch_first=True, 47 | ) 48 | lstm2_cell = nn.LSTMCell( 49 | input_size=i_size, 50 | hidden_size=h_size, 51 | ) 52 | _copy_rnn_params(lstm1, lstm2_cell) 53 | lstm2 = models.RNNFromCell(lstm2_cell, batch_first=True) 54 | 55 | outputs1, (h1, c1) = lstm1(x_var) 56 | outputs2, (h2, c2) = lstm2(x_var) 57 | 58 | assert outputs1.equal(outputs2) 59 | assert h1.equal(h2) 60 | assert c1.equal(c2) 61 | 62 | 63 | def _copy_rnn_params(rnn_from, rnn_to): 64 | rnn_to.weight_ih.data = rnn_from.weight_ih_l0.data 65 | rnn_to.weight_hh.data = rnn_from.weight_hh_l0.data 66 | rnn_to.bias_ih.data = rnn_from.bias_ih_l0.data 67 | rnn_to.bias_hh.data = rnn_from.bias_hh_l0.data 68 | -------------------------------------------------------------------------------- /act/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pathlib 3 | 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import torch.optim as 
optim 7 | import torch 8 | 9 | from . import utils 10 | 11 | 12 | def train(config, model, data_manager): 13 | optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) 14 | 15 | for epoch in range(1, config.num_epochs + 1): 16 | train_data_loader = data_manager.create_dataloader(config) 17 | test_data_loader = data_manager.create_dataloader(config, mode="test") 18 | train_epoch( 19 | epoch=epoch, config=config, model=model, 20 | data_loader=train_data_loader, 21 | optimizer=optimizer, 22 | ) 23 | test_result = test_epoch( 24 | config=config, model=model, 25 | data_loader=test_data_loader, 26 | ) 27 | print( 28 | 'Epoch: {}, Average loss: {:.4f}, ' 29 | 'Accuracy: {}/{} ({:.0f}%), PT: {}'.format( 30 | epoch, 31 | test_result["loss"], test_result["num_correct"], 32 | test_result["length"], 33 | test_result["num_correct"] * 100 / test_result["length"], 34 | f'{test_result["mean_ponder_time"]:.1f}' 35 | if test_result["mean_ponder_time"] else "N/A", 36 | ) 37 | ) 38 | if epoch % config.model_save_interval == 0: 39 | model_save_path = pathlib.Path(config.model_save_path) 40 | model_save_path.mkdir(exist_ok=True, parents=True) 41 | model_save_file_path = ( 42 | model_save_path / f"epoch_{epoch}.pt" 43 | ) 44 | print(f"Saving checkpoint to {model_save_file_path}") 45 | torch.save(model.state_dict(), model_save_file_path) 46 | 47 | 48 | def train_epoch(epoch, config, model, data_loader, optimizer): 49 | model.train() 50 | 51 | loss_func = nn.BCEWithLogitsLoss() 52 | 53 | for batch_idx, (x, y) in enumerate(data_loader): 54 | x_var = utils.maybe_cuda_var(x, cuda=config.cuda) 55 | y_var = Variable(y, requires_grad=False) 56 | if config.cuda: 57 | y_var = y_var.cuda() 58 | 59 | y_hat, ponder_dict = model(x_var) 60 | loss = loss_func(y_hat, y_var) 61 | if ponder_dict: 62 | loss += ( 63 | config.act_ponder_penalty * ponder_dict["ponder_cost"].mean() 64 | ) 65 | 66 | optimizer.zero_grad() 67 | loss.backward() 68 | optimizer.step() 69 | 70 | if config.train_log and 
batch_idx % config.train_log_interval == 0: 71 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 72 | epoch, batch_idx * len(x), len(data_loader.dataset), 73 | 100. * batch_idx / len(data_loader), loss.data[0]) 74 | ) 75 | 76 | 77 | def test_epoch(config, model, data_loader): 78 | model.eval() 79 | test_loss = 0 80 | correct = 0 81 | 82 | loss_func = nn.BCEWithLogitsLoss(size_average=False) 83 | 84 | ponder_times_ls = [] 85 | for batch_idx, (x, y) in enumerate(data_loader): 86 | x_var = utils.maybe_cuda_var(x, cuda=config.cuda) 87 | y_var = Variable(y, volatile=True) 88 | if config.cuda: 89 | y_var = y_var.cuda() 90 | 91 | y_hat, ponder_dict = model(x_var, compute_ponder_cost=False) 92 | test_loss += loss_func(y_hat, y_var).data[0] 93 | y_pred = (y_hat.data > 0.5).float() 94 | correct += y_pred.eq(y_var.data).cpu().numpy()\ 95 | .reshape(y.shape[0], -1).all(axis=1).sum() 96 | 97 | if ponder_dict: 98 | ponder_times_ls.append(np.array(ponder_dict["ponder_times"]).T) 99 | 100 | test_loss /= len(data_loader.dataset) 101 | 102 | if ponder_times_ls: 103 | mean_ponder_time = np.mean(np.vstack(ponder_times_ls)) 104 | else: 105 | mean_ponder_time = None 106 | 107 | return { 108 | "loss": test_loss, 109 | "num_correct": correct, 110 | "length": len(data_loader.dataset), 111 | "mean_ponder_time": mean_ponder_time, 112 | } 113 | -------------------------------------------------------------------------------- /act/utils.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | 3 | 4 | def maybe_cuda_var(x, cuda): 5 | """Helper for converting to a Variable""" 6 | x = Variable(x) 7 | if cuda: 8 | x = x.cuda() 9 | return x 10 | -------------------------------------------------------------------------------- /run_train.py: -------------------------------------------------------------------------------- 1 | from act.configuration import Config 2 | from act.train import train 3 | from act.models 
import resolve_model 4 | from act.data import resolve_data_manager 5 | 6 | 7 | if __name__ == "__main__": 8 | config = Config.parse_configuration() 9 | train( 10 | config=config, 11 | model=resolve_model(config), 12 | data_manager=resolve_data_manager(config) 13 | ) 14 | --------------------------------------------------------------------------------