├── __init__.py
├── LICENSE
├── README.md
└── gatel0rd.py

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Autonomous Learning Group

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GateL0RD
This is a lightweight PyTorch implementation of GateL0RD, our RNN presented in ["Sparsely Changing Latent States for Prediction and Planning in Partially Observable Domains"](https://arxiv.org/abs/2110.15949).


We provide two variants of GateL0RD: `GateL0RD` can be used like a regular PyTorch `RNN`, whereas `GateL0RDCell` can be used like a PyTorch `RNNCell`. To install, put `gatel0rd.py` into your working directory.

Generic example using `GateL0RD`:
```python
from gatel0rd import GateL0RD
#...
model = GateL0RD(input_size=input_dim, hidden_size=args.latent_dim, reg_lambda=args.reg_lambda, output_size=output_dim)
optimizer = optim.Adam(model.parameters(), lr=args.lr) # optimizer of your choice
# ...
for X, Y in training_data:
    Y_hat, H, Theta = model(X)
    optimizer.zero_grad()
    loss_task = F.mse_loss(Y_hat, Y) # loss of your choice
    loss = model.loss(loss_task, Theta)
    loss.backward()
    optimizer.step()
#...
```
A repository containing all experiments of the paper, including examples on how to use a `GateL0RDCell`, can be found [here](https://github.com/martius-lab/GateL0RD-paper).
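
For a quick impression, here is a minimal sketch of stepping a `GateL0RDCell` manually, analogous to a PyTorch `RNNCell`. It is only illustrative: it reuses `input_dim`, `output_dim`, `args`, `training_data`, and the elided `torch`/`F` imports from the example above, and assumes each `X` has shape `(seq, batch, input_dim)`; the optimizer steps are the same as before.
```python
from gatel0rd import GateL0RDCell
#...
cell = GateL0RDCell(input_size=input_dim, hidden_size=args.latent_dim, reg_lambda=args.reg_lambda, output_size=output_dim)
#...
for X, Y in training_data:
    h_t = None                    # None -> the cell starts from a zero latent state
    ys, Thetas = [], []
    for x_t in X:                 # x_t has shape (batch, input_dim)
        y_t, h_t, Theta_t = cell(x_t, h_t)
        ys.append(y_t)
        Thetas.append(Theta_t)
    Y_hat = torch.stack(ys)       # (seq, batch, output_dim)
    loss = cell.loss(F.mse_loss(Y_hat, Y), torch.stack(Thetas))
    #... backward and optimizer step as above
```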

--------------------------------------------------------------------------------
/gatel0rd.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def ReTanh(x):
    """
    ReTanh function applied on tensor x
    """
    return x.tanh().clamp(min=0, max=1)


class HeavisideST(torch.autograd.Function):
    """
    Heaviside activation function with straight through estimator
    """

    @staticmethod
    def forward(ctx, input):
        return torch.ceil(input).clamp(min=0, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input


class GateL0RDCell(nn.Module):
    """

    RNNCell of GateL0RD

    One GateL0RD cell uses three subnetworks:
    1. a recommendation network r, which proposes a new candidate latent state
    2. a gating network g, which determines how the latent state is updated
    3. output functions (p & o), which compute the output based on the updated latent state and the input.

    The forward pass computes the following from input x_t and previous latent state h_{t-1}:
    - s_t \sim \mathcal{N}(g(x_t, h_{t-1}), \Sigma)
    - \Lambda(s_t) = max(0, \tanh(s_t))
    - h_t = \Lambda(s_t) \odot r(x_t, h_{t-1}) + (1 - \Lambda(s_t)) \odot h_{t-1}
    - y_t = p(x_t, h_t) \odot o(x_t, h_t)

    """

    def __init__(self, input_size, hidden_size, reg_lambda, output_size=-1, num_layers_internal=1, gate_noise_level=0.1,
                 device=None):
        """
        GateL0RD cell
        :param input_size: The number of expected features in the cell input x
        :param hidden_size: The number of features in the latent state h
        :param reg_lambda: Hyperparameter controlling the sparsity of latent state changes
        :param output_size: The number of expected features for the cell output y (Default: same as hidden size)
        :param num_layers_internal: Number of layers used in the g- and r-subnetworks
        :param gate_noise_level: Standard deviation of normal distributed gate noise for stochastic gates (\Sigma)
        :param device: torch.device to use for creating tensors.
56 | """ 57 | 58 | super(GateL0RDCell, self).__init__() 59 | 60 | self.input_size = input_size 61 | self.hidden_size = hidden_size 62 | 63 | if output_size == -1: 64 | output_size = hidden_size 65 | self.output_size = output_size 66 | 67 | input_size_gates = input_size + hidden_size 68 | 69 | # Create g-network: 70 | temp_gating = nn.ModuleList([]) 71 | in_dim_g = input_size_gates 72 | for gl in range(num_layers_internal): 73 | gl_factor = pow(2, (num_layers_internal - gl - 1)) 74 | out_dim_g = gl_factor * hidden_size 75 | temp_gating.append(nn.Linear(in_dim_g, out_dim_g)) 76 | 77 | if gl < (num_layers_internal - 1): 78 | temp_gating.append(nn.Tanh()) 79 | in_dim_g = out_dim_g 80 | self.input_gates = nn.Sequential(*temp_gating) 81 | 82 | # Create r-network: 83 | temp_r_function = nn.ModuleList([]) 84 | in_dim_r = input_size_gates 85 | for rl in range(num_layers_internal): 86 | rl_factor = pow(2, (num_layers_internal - rl - 1)) 87 | out_dim_r = rl_factor * hidden_size 88 | temp_r_function.append(nn.Linear(in_dim_r, out_dim_r)) 89 | temp_r_function.append(nn.Tanh()) 90 | in_dim_r = out_dim_r 91 | self.r_function = nn.Sequential(*temp_r_function) 92 | 93 | # Create output function p * o: 94 | 95 | # Create p-network: 96 | temp_output_function = nn.ModuleList([]) 97 | temp_output_function.append(nn.Linear(input_size_gates, output_size)) 98 | temp_output_function.append(nn.Tanh()) 99 | self.output_function = nn.Sequential(*temp_output_function) 100 | 101 | # Create o-network 102 | temp_outputgate = nn.ModuleList([]) 103 | temp_outputgate.append(nn.Linear(input_size_gates, output_size)) 104 | temp_outputgate.append(nn.Sigmoid()) 105 | self.output_gates = nn.Sequential(*temp_outputgate) 106 | 107 | assert gate_noise_level >= 0, "Need a positive standard deviation as the gate noise" 108 | self.gate_noise_level = gate_noise_level 109 | 110 | # Gate regularization 111 | self.reg_lambda = reg_lambda 112 | self.gate_reg = HeavisideST.apply 113 | 114 | if device is None: 115 | self.device = torch.device('cpu') 116 | else: 117 | self.device = device 118 | 119 | def forward(self, x_t, h_tminus1=None): 120 | """ 121 | Forward pass one step, i.e. 
        :param x_t: tensor of cell inputs
        :param h_tminus1: tensor of last latent state (Default: initialized by zeros)
        :return: rnn output y_t, hidden states h_t, tensor of regularized gatings \Theta(\Lambda_t)
        """
        assert len(x_t.shape) == 2, "Wrong input dimensionality of x_t in GateL0RDCell: " + str(x_t.shape)
        batch_size, layer_input_size = x_t.size()

        if h_tminus1 is None:
            h_tminus1 = torch.zeros((batch_size, self.hidden_size), device=self.device)
        else:
            assert len(h_tminus1.shape) == 2, "Wrong input dimensionality of h_tminus1 in GateL0RDCell: " + str(h_tminus1.shape)
            assert h_tminus1.shape[1] == self.hidden_size

        # Input to the g- and r-networks is the current input concatenated with the last latent state
        gr_input = torch.cat((x_t, h_tminus1), 1)

        '''
        G-network
        '''
        i_t = self.input_gates(gr_input)
        if self.training:
            gate_noise = torch.randn(size=(batch_size, self.hidden_size), device=self.device) * self.gate_noise_level
        else:
            # Gate noise is zero
            gate_noise = torch.zeros((batch_size, self.hidden_size), device=self.device)

        # Stochastic input gate activation
        Lambda_t = ReTanh(i_t - gate_noise)
        Theta_t = self.gate_reg(Lambda_t)

        '''
        R-network
        '''
        h_hat_t = self.r_function(gr_input)

        '''
        New latent state
        '''
        h_t = Lambda_t * h_hat_t + (1.0 - Lambda_t) * h_tminus1

        '''
        Output function:
        '''
        xh_t = torch.cat((x_t, h_t), 1)
        y_hat_t = self.output_function(xh_t)

        # Output is computed as p(x_t, h_t) * o(x_t, h_t)
        o_t = self.output_gates(xh_t)
        y_t = y_hat_t * o_t

        return y_t, h_t, Theta_t

    def loss(self, loss_task, Theta):
        """
        GateL0RD loss function
        :param loss_task: Computed task-based loss, e.g. MSE for regression or cross-entropy for classification
        :param Theta: Regularized gate activation
        :return: lambda-weighted sum of the two losses
        """
        assert Theta is not None, 'Provide tensor of regularized gates (Theta) for loss computation.'
        gate_loss = torch.mean(Theta)
        return loss_task + self.reg_lambda * gate_loss


class GateL0RD(torch.nn.Module):
    """

    RNN implementation of GateL0RD

    """

    def __init__(self, input_size, hidden_size, reg_lambda, output_size=-1, num_layers_internal=1,
                 h_init_net=False, h_init_net_layers=3, gate_noise_level=0.1, batch_first=False, device=None):
        """
        GateL0RD RNN
        :param input_size: The number of expected features in the cell input x
        :param hidden_size: The number of features in the latent state h
        :param reg_lambda: Hyperparameter controlling the sparsity of latent state changes
        :param output_size: The number of expected features for the cell output y (Default: same as hidden size)
        :param h_init_net: If True, use a feed-forward network to learn to initialize the hidden state based on the
                           first input (Default: False)
        :param h_init_net_layers: How many layers will be used to initialize the hidden state from the input. Layer
                                  number l has 2^l*hidden_size features. (Default: 3 layers)
        :param num_layers_internal: Number of layers used in the g- and r-subnetworks
        :param gate_noise_level: Standard deviation of normal distributed gate noise for stochastic gates (\Sigma)
        :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) instead of
                            (seq, batch, feature). Note that this does not apply to latent states or gates. Default: False
        :param device: torch.device to use for creating tensors.
        """

        super(GateL0RD, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        if output_size == -1:
            output_size = hidden_size
        self.output_size = output_size

        self.cell = GateL0RDCell(input_size=input_size, output_size=output_size, hidden_size=hidden_size,
                                 num_layers_internal=num_layers_internal, gate_noise_level=gate_noise_level,
                                 reg_lambda=reg_lambda, device=device)

        self.use_h_init_net = h_init_net
        if h_init_net:
            self.f_init = self.__create_f_init(f_init_layers=h_init_net_layers, input_dim=input_size,
                                               latent_dim=hidden_size)

        self.last_Thetas = None

        self.batch_first = batch_first

        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = device

    @staticmethod
    def __create_f_init(f_init_layers, input_dim, latent_dim):
        input_dim_warm_up = input_dim
        warm_up_net = nn.ModuleList([])
        for w in range(f_init_layers):
            w_factor = pow(2, (f_init_layers - w - 1))
            warm_up_net.append(nn.Linear(input_dim_warm_up, w_factor * latent_dim))
            warm_up_net.append(nn.Tanh())
            input_dim_warm_up = w_factor * latent_dim
        return nn.Sequential(*warm_up_net)

    def __forward_one_step(self, x_t, h_tminus1):
        return self.cell.forward(x_t, h_tminus1)

    def forward(self, x, h_init=None, h_sequence=False):
        """
        Forward pass for sequence data
        :param x: tensor of sequence of input batches with shape (seq, batch, feature)
                  (or (batch, seq, feature) for batch_first=True)
        :param h_init: tensor of initial latent state with shape (1, batch, feature). If None it is initialized by a
                       feed-forward network based on x_0 (h_init_net=True) or set to zero (h_init_net=False).
        :param h_sequence: If True, outputs the sequence of latent states, else only the last latent state (Default: False)
        :return: - rnn output y with shape (seq, batch, feature) (or (batch, seq, feature) for batch_first=True),
                 - latent state h of shape (1, batch, feature) (or (seq, batch, feature) for h_sequence=True),
                 - regularized gate activations (\Theta(\Lambda(s))) with shape (seq, batch, feature)
        """

        assert len(x.shape) == 3, "Input must have 3 dimensions, got " + str(len(x.shape))

        if self.batch_first:
            x = x.permute(1, 0, 2)

        S, B, D = x.shape

        assert D == self.input_size, "Expected input of shape (*, *, " + str(self.input_size) + "), got " + str(x.shape)

        if h_init is None:
            if self.use_h_init_net:
                h_init = self.f_init(x[0, :, :]).unsqueeze(0)
            else:
                h_init = torch.zeros((1, B, self.hidden_size), device=self.device)
        else:
            h_shape = h_init.shape
            assert len(h_shape) == 3 and h_shape[0] == 1 and h_shape[1] == B and h_shape[2] == self.hidden_size, \
                "Expected latent state of shape (1, " + str(B) + ", " + str(self.hidden_size) + "), got " + str(h_shape)

        h_t = h_init[0, :, :]
        list_ys = []
        list_hs = []
        list_Thetas = []
        for t in range(S):
            x_t = x[t, :, :]
            y_t, h_t, Theta_t = self.__forward_one_step(x_t=x_t, h_tminus1=h_t)
            list_ys.append(y_t)
            list_hs.append(h_t)
            list_Thetas.append(Theta_t)

        ys = torch.stack(list_ys)
        hs = torch.stack(list_hs)
        Thetas = torch.stack(list_Thetas)
        h_output = h_t.unsqueeze(0)

        if self.batch_first:
            ys = ys.permute(1, 0, 2)

        self.last_Thetas = Thetas

        if h_sequence:
            h_output = hs

        return ys, h_output, Thetas

    def loss(self, loss_task, Theta=None):
        """
        GateL0RD loss function
        :param loss_task: Computed task-based loss, e.g. MSE for regression or cross-entropy for classification
        :param Theta: Regularized gate activation (Default: gate activations from the last forward call)
        :return: lambda-weighted sum of the two losses
        """
        if Theta is None:
            assert self.last_Thetas is not None, "forward() needs to be called before loss computation."
            return self.cell.loss(loss_task, self.last_Thetas)
        return self.cell.loss(loss_task, Theta)
--------------------------------------------------------------------------------