├── __init__.py ├── .gitignore ├── requirements.txt ├── optim ├── __init__.py ├── cocob.py └── adam_hd.py ├── activations ├── __init__.py ├── ternary.py ├── isrlu.py ├── softexp.py └── bipolar.py ├── nn ├── __init__.py ├── rnn_cell_base.py ├── trnn.py ├── qrnn.py ├── ran.py ├── cfn.py ├── gru.py ├── lstm.py ├── causal.py ├── mgu.py └── sru.py ├── LICENSE ├── test_sru_optimisation.py ├── test_activation_speeds.py ├── models.py ├── experiment.py ├── data_generation.py └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | synthetic_data.png 3 | *_WIP -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | signalz 2 | numpy 3 | matplotlib 4 | torch 5 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .cocob import COCOB 3 | from .adam_hd import Adam_HD, Adam_HD_lr_per_param 4 | -------------------------------------------------------------------------------- /activations/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .isrlu import ISRLU, ISRU_sigmoid, ISRU_tanh 3 | from .bipolar import Bipolar 4 | from .softexp import SoftExp 5 | from .ternary import TernaryTanh -------------------------------------------------------------------------------- /nn/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .lstm import LSTM 3 | from .gru import GRU 4 | from .mgu import MGU, MGU2 5 | from .ran import RAN 6 | from .sru import SRUf, SRU2, SRU 7 | from .qrnn import fakeQRNN 8 | from .trnn import TRNN 9 | from .causal import CausalConv1d, Wave, ShortWave 10 | from .cfn import CFN -------------------------------------------------------------------------------- /activations/ternary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Function 4 | 5 | # f(x) = 1.5 * tanh(x) + 0.5 * tanh(−3 * x) 6 | 7 | class TernaryTanh(Function): 8 | @staticmethod 9 | def forward(ctx, tensor): 10 | return 1.5 * F.tanh(tensor) + 0.5 * F.tanh(-3 * tensor) 11 | 12 | TernaryTanh = TernaryTanh.apply -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 jpeg729 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /test_sru_optimisation.py: -------------------------------------------------------------------------------- 1 | 2 | if __name__ == '__main__' and __package__ is None: 3 | import os 4 | os.sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | __package__ = "pytorch_bits" 6 | 7 | import timeit 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | 14 | from pytorch_bits.nn import SRU 15 | 16 | 17 | def test_cpu(): 18 | a.detach_() 19 | b = cpu_sru(a) 20 | b.backward(loss) 21 | 22 | def test_gpu(): 23 | a.detach_() 24 | b = gpu_sru(a) 25 | b.backward(loss) 26 | 27 | tests = ("test_cpu", "test_gpu",) 28 | 29 | if __name__ == "__main__": 30 | for size in (50, 100, 200, 500, 1000): 31 | print("Size", size) 32 | try: 33 | a = Variable(torch.rand(size,size,size), requires_grad=True) 34 | loss = torch.ones_like(a) 35 | cpu_sru = SRU(size, 100, gpu=False) 36 | gpu_sru = SRU(size, 100, gpu=True) 37 | for test in tests: 38 | timer = timeit.Timer(test, globals=globals()) 39 | print(test, np.mean(timer.repeat(number=1000000, repeat=10))) 40 | except RuntimeError: 41 | print("Not enough RAM. Aborting.") -------------------------------------------------------------------------------- /nn/rnn_cell_base.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.modules.module import Module 3 | import torch.nn.functional as F 4 | 5 | from pytorch_bits import activations 6 | 7 | class RNNCellBase(Module): 8 | 9 | def __repr__(self): 10 | s = '{name}({input_size}, {hidden_size}' 11 | if 'bias' in self.__dict__ and self.bias is not True: 12 | s += ', bias={bias}' 13 | if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh": 14 | s += ', nonlinearity={nonlinearity}' 15 | s += ')' 16 | return s.format(name=self.__class__.__name__, **self.__dict__) 17 | 18 | def get_activation(self, activation): 19 | if hasattr(activations, activation): 20 | return getattr(activations, activation) 21 | return getattr(F, activation) 22 | 23 | def check_forward_input(self, input_data): 24 | if input_data.size(1) != self.input_size: 25 | raise RuntimeError( 26 | "input has inconsistent input_size: got {}, expected {}".format( 27 | input_data.size(1), self.input_size)) 28 | 29 | def check_forward_hidden(self, input_data, hx, hidden_label=''): 30 | if input_data.size(0) != hx.size(0): 31 | raise RuntimeError( 32 | "Input batch size {} doesn't match hidden{} batch size {}".format( 33 | input_data.size(0), hidden_label, hx.size(0))) 34 | 35 | if hx.size(1) != self.hidden_size: 36 | raise RuntimeError( 37 | "hidden{} has inconsistent hidden_size: got {}, expected {}".format( 38 | hidden_label, input_data.size(1), self.input_size)) 39 | -------------------------------------------------------------------------------- /activations/isrlu.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | from torch.autograd import Function 3 | 4 | class ISRLU(Function): 5 | @staticmethod 6 | def forward(ctx, tensor, alpha=1): 7 | negatives = torch.min(tensor, torch.Tensor([0])) 8 | nisr = torch.rsqrt(1. + alpha * (negatives ** 2)) 9 | # nisr == 1 where tensor elements are positive 10 | return tensor * nisr 11 | 12 | class ISRU_tanh(Function): 13 | @staticmethod 14 | def forward(ctx, tensor, alpha=1): 15 | return tensor * torch.rsqrt(1. + alpha * (tensor ** 2)) 16 | 17 | class ISRU_sigmoid(Function): 18 | @staticmethod 19 | def forward(ctx, tensor, alpha=1): 20 | return .5 + .5 * tensor * torch.rsqrt(1. + alpha * (tensor ** 2)) 21 | 22 | 23 | if __name__ == "__main__": 24 | import numpy as np 25 | 26 | data = np.random.rand(10) 27 | alpha = 1. 28 | 29 | print("Testing ISRLU", end=" ") 30 | pos = data > 0 31 | calc = np.empty_like(data) 32 | calc[pos] = data[pos] 33 | calc[~pos] = 1. / np.sqrt(1. + alpha * (data[~pos] ** 2)) 34 | out = ISRLU.forward(None, torch.Tensor(data), alpha).numpy() 35 | print("--", "passed" if np.allclose(calc, out) else "failed") 36 | 37 | print("Testing ISRU_tanh", end=" ") 38 | calc = data / np.sqrt(1. + alpha * (data ** 2)) 39 | out = ISRU_tanh.forward(None, torch.Tensor(data), alpha).numpy() 40 | print("--", "passed" if np.allclose(calc, out) else "failed") 41 | 42 | print("Testing ISRU_sigmoid", end=" ") 43 | calc = .5 + .5 * data / np.sqrt(1. + alpha * (data ** 2)) 44 | out = ISRU_sigmoid.forward(None, torch.Tensor(data), alpha).numpy() 45 | print("--", "passed" if np.allclose(calc, out) else "failed") 46 | 47 | ISRLU = ISRLU.apply 48 | ISRU_tanh = ISRU_tanh.apply 49 | ISRU_sigmoid = ISRU_sigmoid.apply 50 | -------------------------------------------------------------------------------- /test_activation_speeds.py: -------------------------------------------------------------------------------- 1 | 2 | if __name__ == '__main__' and __package__ is None: 3 | import os 4 | os.sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | __package__ = "pytorch_bits" 6 | 7 | import timeit 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | 14 | import pytorch_bits.activations as activations 15 | 16 | 17 | def test_tanh(): 18 | a.detach_() 19 | b = nn.Tanh(a) 20 | b.backward(loss) 21 | 22 | def test_hardtanh(): 23 | a.detach_() 24 | b = nn.HardTanh(a) 25 | b.backward(loss) 26 | 27 | def test_sigmoid(): 28 | a.detach_() 29 | b = nn.Sigmoid(a) 30 | b.backward(loss) 31 | 32 | def test_elu(): 33 | a.detach_() 34 | b = nn.ELU(a) 35 | b.backward(loss) 36 | 37 | def test_isrlu(): 38 | a.detach_() 39 | b = activations.ISRLU(a) 40 | b.backward(loss) 41 | 42 | def test_isru_tanh(): 43 | a.detach_() 44 | b = activations.ISRU_tanh(a) 45 | b.backward(loss) 46 | 47 | def test_isru_sigmoid(): 48 | a.detach_() 49 | b = activations.ISRU_sigmoid(a) 50 | b.backward(loss) 51 | 52 | def test_isru_softsign(): 53 | a.detach_() 54 | b = nn.SoftSign(a) 55 | b.backward(loss) 56 | 57 | tests = ("test_tanh", "test_hardtanh", "test_sigmoid", 58 | "test_isru_tanh", "test_isru_sigmoid", "test_isru_softsign", 59 | "test_elu", "test_isrlu", 60 | ) 61 | 62 | if __name__ == "__main__": 63 | for size in (50, 100, 200, 500, 1000): 64 | print("Size", size) 65 | try: 66 | a = Variable(torch.rand(size,size,size), requires_grad=True) 67 | loss = torch.ones_like(a) 68 | for test in tests: 69 | timer = timeit.Timer(test, globals=globals()) 70 | print(test, np.mean(timer.repeat(number=1000000, repeat=5))) 71 | 
except RuntimeError: 72 | print("Not enough RAM. Aborting.") -------------------------------------------------------------------------------- /optim/cocob.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | class COCOB(Optimizer): 5 | 6 | def __init__(self, params, alpha=100, weight_decay=False): 7 | defaults = dict(alpha=alpha, weight_decay=weight_decay) 8 | super(COCOB, self).__init__(params, defaults) 9 | 10 | def __setstate__(self, state): 11 | super(COCOB, self).__setstate__(state) 12 | 13 | def step(self, closure=None): 14 | loss = None 15 | if closure is not None: 16 | loss = closure() 17 | 18 | for group in self.param_groups: 19 | alpha = group['alpha'] 20 | 21 | for p in group['params']: 22 | if p.grad is None: 23 | continue 24 | 25 | d_p = p.grad.data 26 | 27 | if group['weight_decay'] != 0: 28 | d_p.add_(group['weight_decay'], p.data) 29 | 30 | state = self.state[p] 31 | if len(state) == 0: 32 | state['L'] = torch.zeros_like(p.data) 33 | state['gradients_sum'] = torch.zeros_like(p.data) 34 | state['grad_norm_sum'] = torch.zeros_like(p.data) 35 | state['reward'] = torch.zeros_like(p.data) 36 | state['w'] = torch.zeros_like(p.data) 37 | 38 | L = state['L'] 39 | reward = state['reward'] 40 | gradients_sum = state['gradients_sum'] 41 | grad_norm_sum = state['grad_norm_sum'] 42 | old_w = state['w'] 43 | 44 | torch.max(L, torch.abs(d_p), out=L) 45 | torch.max(reward - old_w * d_p, torch.Tensor([0]), out=reward) 46 | gradients_sum.add_(d_p) 47 | grad_norm_sum.add_(torch.abs(d_p)) 48 | 49 | # the paper sets weights_t = weights_1 + new_w 50 | # we use the equivalent formula: weights_t = weights_tm1 - old_w + new_w 51 | new_w = state['w'] = -gradients_sum / (L * torch.max(grad_norm_sum + L, alpha * L)) * (L + reward) 52 | p.data.add_(-1, old_w) 53 | p.data.add_(new_w) 54 | 55 | return loss -------------------------------------------------------------------------------- /nn/trnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | 7 | from .rnn_cell_base import RNNCellBase 8 | 9 | class TRNN(RNNCellBase): 10 | 11 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 12 | super(TRNN, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.bias = bias 16 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 17 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 18 | self.weight_ih = Parameter(torch.Tensor(2 * hidden_size, input_size)) 19 | if bias: 20 | self.bias_ih = Parameter(torch.Tensor(2 * hidden_size)) 21 | else: 22 | self.register_parameter('bias_ih', None) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | stdv = 1.0 / math.sqrt(self.hidden_size) 27 | for weight in self.parameters(): 28 | weight.data.uniform_(-stdv, stdv) 29 | self.hidden = None 30 | 31 | def reset_hidden(self): 32 | self.hidden = None 33 | 34 | def detach_hidden(self): 35 | self.hidden.detach_() 36 | 37 | def forward(self, input_data, future=0): 38 | timesteps, batch_size, features = input_data.size() 39 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 40 | 41 | if self.hidden is None: 42 | self.hidden = 
Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 43 | 44 | self.check_forward_input(input_data[0]) 45 | self.check_forward_hidden(input_data[0], self.hidden) 46 | 47 | for i, input_t in enumerate(input_data.split(1)): 48 | 49 | gi = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) 50 | i_n, i_f = gi.chunk(2, 1) 51 | 52 | forgetgate = self.sigmoid(i_f) 53 | newgate = i_n 54 | self.hidden = newgate + forgetgate * (self.hidden - newgate) 55 | outputs[i] = self.hidden 56 | 57 | return outputs 58 | -------------------------------------------------------------------------------- /nn/qrnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | 7 | from .rnn_cell_base import RNNCellBase 8 | 9 | class fakeQRNN(RNNCellBase): 10 | 11 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 12 | super(fakeQRNN, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.bias = bias 16 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 17 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 18 | self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size)) 19 | if bias: 20 | self.bias_ih = Parameter(torch.Tensor(3 * hidden_size)) 21 | else: 22 | self.register_parameter('bias_ih', None) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | stdv = 1.0 / math.sqrt(self.hidden_size) 27 | for weight in self.parameters(): 28 | weight.data.uniform_(-stdv, stdv) 29 | self.hidden = None 30 | 31 | def reset_hidden(self): 32 | self.hidden = None 33 | 34 | def detach_hidden(self): 35 | self.hidden.detach_() 36 | 37 | def forward(self, input_data, future=0): 38 | timesteps, batch_size, features = input_data.size() 39 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 40 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 41 | 42 | if self.hidden is None: 43 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 44 | 45 | self.check_forward_input(input_data[0]) 46 | self.check_forward_hidden(input_data[0], self.hidden) 47 | 48 | for i, input_t in enumerate(input_data.split(1)): 49 | x = input_t.view(batch_size, features) 50 | gi = F.linear(x, self.weight_ih, self.bias_ih) 51 | i_r, i_f, i_n = gi.chunk(3, 1) 52 | 53 | readgate = self.sigmoid(i_r) 54 | forgetgate = self.sigmoid(i_f) 55 | newgate = self.tanh(i_n) 56 | self.hidden = newgate + forgetgate * (self.hidden - newgate) 57 | outputs[i] = readgate * self.hidden 58 | 59 | return outputs 60 | -------------------------------------------------------------------------------- /nn/ran.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | 7 | from .rnn_cell_base import RNNCellBase 8 | 9 | class RAN(RNNCellBase): 10 | 11 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 12 | super(RAN, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.bias = bias 16 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 17 | self.tanh = 
F.tanh if tanh is None else self.get_activation(tanh) 18 | self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size)) 19 | self.weight_hh = Parameter(torch.Tensor(2 * hidden_size, hidden_size)) 20 | if bias: 21 | self.bias_ih = Parameter(torch.Tensor(3 * hidden_size)) 22 | self.bias_hh = Parameter(torch.Tensor(2 * hidden_size)) 23 | else: 24 | self.register_parameter('bias_ih', None) 25 | self.register_parameter('bias_hh', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1.0 / math.sqrt(self.hidden_size) 30 | for weight in self.parameters(): 31 | weight.data.uniform_(-stdv, stdv) 32 | self.hidden = None 33 | 34 | def reset_hidden(self): 35 | self.hidden = None 36 | 37 | def detach_hidden(self): 38 | self.hidden.detach_() 39 | 40 | def forward(self, input_data, future=0): 41 | timesteps, batch_size, features = input_data.size() 42 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 43 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 44 | 45 | if self.hidden is None: 46 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 47 | # else: 48 | # self.hidden.detach_() 49 | 50 | self.check_forward_input(input_data[0]) 51 | self.check_forward_hidden(input_data[0], self.hidden) 52 | 53 | for i, input_t in enumerate(input_data.split(1)): 54 | 55 | gi = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) 56 | gh = F.linear(self.hidden, self.weight_hh, self.bias_hh) 57 | i_i, i_f, i_n = gi.chunk(3, 1) 58 | h_i, h_f = gh.chunk(2, 1) 59 | 60 | inputgate = self.sigmoid(i_i + h_i) 61 | forgetgate = self.sigmoid(i_f + h_f) 62 | newgate = i_n 63 | self.hidden = inputgate * newgate + forgetgate * self.hidden 64 | outputs[i] = self.tanh(self.hidden) 65 | 66 | return outputs 67 | -------------------------------------------------------------------------------- /activations/softexp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | testing = False 7 | 8 | class SoftExp(nn.Module): 9 | def __init__(self, input_size): 10 | super(SoftExp, self).__init__() 11 | self.alpha = nn.Parameter(torch.Tensor(input_size)) 12 | 13 | def forward(self, data): 14 | self.alpha.data.clamp_(-1, 1) 15 | 16 | positives = torch.gt(F.threshold(self.alpha, 0, 0), 0) 17 | negatives = torch.gt(F.threshold(-self.alpha, 0, 0), 0) 18 | 19 | output = data.clone() 20 | pos_out = (torch.exp(self.alpha * data) - 1) / self.alpha + self.alpha 21 | neg_out = -(torch.log(1 - self.alpha * (data + self.alpha))) / self.alpha 22 | 23 | output.masked_scatter_(positives, pos_out.masked_select(positives)) 24 | output.masked_scatter_(negatives, neg_out.masked_select(negatives)) 25 | return output 26 | 27 | if __name__ == "__main__": 28 | import numpy as np 29 | 30 | testing = True 31 | test_size = 50 32 | data = Variable(torch.abs(torch.randn(1, 1, test_size))) 33 | softexp = SoftExp(test_size) 34 | softexp.alpha.data.copy_(torch.randn(test_size)) 35 | softexp.alpha.data[0] = 0 36 | 37 | 38 | print("Testing SoftExp", end=" ") 39 | out = softexp(data) 40 | loss = nn.MSELoss()(out, torch.ones_like(data)) 41 | loss.backward() 42 | 43 | failed = False 44 | for i in range(test_size): 45 | if softexp.alpha.data[i] == 0: 46 | activation = data 47 | if not np.allclose(out[:, :, i].data.numpy(), activation[:, :, i].data.numpy()): 48 | 
print("\nfailed at index", i, "for zero alpha", softexp.alpha.data[i], "output", out[:, :, i].data.numpy(), "for data", data[:, :, i].data.numpy(), end="") 49 | failed = True 50 | elif softexp.alpha.data[i] > 0: 51 | activation = ((torch.exp(softexp.alpha * data) - 1) / softexp.alpha + softexp.alpha) 52 | if not np.allclose(out[:, :, i].data.numpy(), activation[:, :, i].data.numpy()): 53 | print("\nfailed at index", i, "for positive alpha", softexp.alpha.data[i], "output", out[:, :, i].data.numpy(), "for data", data[:, :, i].data.numpy(), end="") 54 | failed = True 55 | elif softexp.alpha.data[i] < 0: 56 | activation = (-(torch.log(1 - softexp.alpha * (data + softexp.alpha))) / softexp.alpha) 57 | if not np.allclose(out[:, :, i].data.numpy(), activation[:, :, i].data.numpy()): 58 | print("\nfailed at index", i, "for negative alpha", softexp.alpha.data[i], "output", out[:, :, i].data.numpy(), "for data", data[:, :, i].data.numpy(), end="") 59 | failed = True 60 | print("-- passed" if not failed else "") 61 | -------------------------------------------------------------------------------- /nn/cfn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | 7 | from .rnn_cell_base import RNNCellBase 8 | 9 | class CFN(RNNCellBase): 10 | 11 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 12 | super(CFN, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.bias = bias 16 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 17 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 18 | self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size)) 19 | self.weight_hh = Parameter(torch.Tensor(2 * hidden_size, hidden_size)) 20 | if bias: 21 | self.bias_ih = Parameter(torch.Tensor(3 * hidden_size)) 22 | self.bias_hh = Parameter(torch.Tensor(2 * hidden_size)) 23 | else: 24 | self.register_parameter('bias_ih', None) 25 | self.register_parameter('bias_hh', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1.0 / math.sqrt(self.hidden_size) 30 | for weight in self.parameters(): 31 | weight.data.uniform_(-stdv, stdv) 32 | self.hidden = None 33 | 34 | def reset_hidden(self): 35 | self.hidden = None 36 | 37 | def detach_hidden(self): 38 | self.hidden.detach_() 39 | 40 | def forward(self, input_data, future=0): 41 | timesteps, batch_size, features = input_data.size() 42 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 43 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 44 | 45 | if self.hidden is None: 46 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 47 | 48 | self.check_forward_input(input_data[0]) 49 | self.check_forward_hidden(input_data[0], self.hidden) 50 | 51 | for i, input_t in enumerate(input_data.split(1)): 52 | 53 | gi = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) 54 | gh = F.linear(self.hidden, self.weight_hh, self.bias_hh) 55 | i_i, i_f, i_n = gi.chunk(3, 1) 56 | h_i, h_f = gh.chunk(2, 1) 57 | 58 | # f, i = sigmoid(Wx + Vh_tm1 + b) 59 | inputgate = self.sigmoid(i_i + h_i) 60 | forgetgate = self.sigmoid(i_f + h_f) 61 | newgate = i_n 62 | 63 | # h_t = f * tanh(h_tm1) + i * tanh(Wx) 64 | self.hidden = inputgate * 
self.tanh(newgate) + forgetgate * self.tanh(self.hidden) 65 | outputs[i] = self.hidden 66 | 67 | return outputs 68 | -------------------------------------------------------------------------------- /activations/bipolar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class Bipolar(nn.Module): 7 | def __init__(self, activation, input_size): 8 | super(Bipolar, self).__init__() 9 | self.activation = activation 10 | self.positive_indices = Variable(torch.Tensor([0]*input_size), requires_grad=False) 11 | for i in range(0, input_size, 2): 12 | self.positive_indices[i] = 1 13 | 14 | def forward(self, data): 15 | pos_output = self.activation(data) * self.positive_indices 16 | neg_output = -self.activation(-data) * (1 - self.positive_indices) 17 | return pos_output + neg_output 18 | 19 | 20 | if __name__ == "__main__": 21 | import numpy as np 22 | 23 | test_size = 10 24 | data = Variable(torch.randn(5,5,test_size)) 25 | 26 | print("Testing Bipolar(relu)", end=" ") 27 | activation = F.relu 28 | bipolar = Bipolar(activation, test_size) 29 | out = bipolar(data) 30 | failed = False 31 | for i in range(test_size): 32 | if i % 2 == 0: 33 | if not np.allclose(out[:, :, i].data.numpy(), activation(data[:, :, i]).data.numpy()): 34 | print("\nfailed at index ", i) 35 | failed = True 36 | else: 37 | if not np.allclose(out[:, :, i].data.numpy(), -activation(-data[:, :, i]).data.numpy()): 38 | print("\nfailed at index ", i) 39 | failed = True 40 | if not failed: 41 | print("-- passed") 42 | 43 | print("Testing Bipolar(elu)", end=" ") 44 | activation = F.elu 45 | bipolar = Bipolar(activation, test_size) 46 | out = bipolar(data) 47 | failed = False 48 | for i in range(test_size): 49 | if i % 2 == 0: 50 | if not np.allclose(out[:, :, i].data.numpy(), activation(data[:, :, i]).data.numpy()): 51 | print("\nfailed at index ", i) 52 | failed = True 53 | else: 54 | if not np.allclose(out[:, :, i].data.numpy(), -activation(-data[:, :, i]).data.numpy()): 55 | print("\nfailed at index ", i) 56 | failed = True 57 | if not failed: 58 | print("-- passed") 59 | 60 | print("Testing Bipolar(sigmoid)", end=" ") 61 | activation = F.sigmoid 62 | bipolar = Bipolar(activation, test_size) 63 | out = bipolar(data) 64 | failed = False 65 | for i in range(test_size): 66 | if i % 2 == 0: 67 | if not np.allclose(out[:, :, i].data.numpy(), activation(data[:, :, i]).data.numpy()): 68 | print("\nfailed at index ", i) 69 | failed = True 70 | else: 71 | if not np.allclose(out[:, :, i].data.numpy(), -activation(-data[:, :, i]).data.numpy()): 72 | print("\nfailed at index ", i) 73 | failed = True 74 | if not failed: 75 | print("-- passed") 76 | -------------------------------------------------------------------------------- /nn/gru.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | 7 | from .rnn_cell_base import RNNCellBase 8 | 9 | class GRU(RNNCellBase): 10 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 11 | super(GRU, self).__init__() 12 | self.input_size = input_size 13 | self.hidden_size = hidden_size 14 | self.bias = bias 15 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 16 | self.tanh = F.tanh if tanh is None else 
self.get_activation(tanh) 17 | self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size)) 18 | self.weight_hh = Parameter(torch.Tensor(3 * hidden_size, hidden_size)) 19 | if bias: 20 | self.bias_ih = Parameter(torch.Tensor(3 * hidden_size)) 21 | self.bias_hh = Parameter(torch.Tensor(3 * hidden_size)) 22 | else: 23 | self.register_parameter('bias_ih', None) 24 | self.register_parameter('bias_hh', None) 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | stdv = 1.0 / math.sqrt(self.hidden_size) 29 | for weight in self.parameters(): 30 | weight.data.uniform_(-stdv, stdv) 31 | self.hidden = None 32 | 33 | def reset_hidden(self): 34 | self.hidden = None 35 | 36 | def detach_hidden(self): 37 | self.hidden.detach_() 38 | 39 | def forward(self, input_data, future=0): 40 | timesteps, batch_size, features = input_data.size() 41 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 42 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 43 | 44 | if self.hidden is None: 45 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 46 | 47 | self.check_forward_input(input_data[0]) 48 | self.check_forward_hidden(input_data[0], self.hidden) 49 | 50 | for i, input_t in enumerate(input_data.split(1)): 51 | 52 | gi = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) 53 | gh = F.linear(self.hidden, self.weight_hh, self.bias_hh) 54 | i_r, i_i, i_n = gi.chunk(3, 1) 55 | h_r, h_i, h_n = gh.chunk(3, 1) 56 | 57 | resetgate = self.sigmoid(i_r + h_r) 58 | inputgate = self.sigmoid(i_i + h_i) 59 | newgate = self.tanh(i_n + resetgate * h_n) 60 | self.hidden = newgate + inputgate * (self.hidden - newgate) 61 | outputs[i] = self.hidden 62 | """ 63 | outputs[i] = self._backend.GRUCell( 64 | input_t.view(batch_size, features), self.hidden, 65 | self.weight_ih, self.weight_hh, 66 | self.bias_ih, self.bias_hh, 67 | ) 68 | #""" 69 | 70 | return outputs 71 | -------------------------------------------------------------------------------- /nn/lstm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | 7 | from .rnn_cell_base import RNNCellBase 8 | 9 | class LSTM(RNNCellBase): 10 | 11 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 12 | super(LSTM, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.bias = bias 16 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 17 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 18 | self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size)) 19 | self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size)) 20 | if bias: 21 | self.bias_ih = Parameter(torch.Tensor(4 * hidden_size)) 22 | self.bias_hh = Parameter(torch.Tensor(4 * hidden_size)) 23 | else: 24 | self.register_parameter('bias_ih', None) 25 | self.register_parameter('bias_hh', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1.0 / math.sqrt(self.hidden_size) 30 | for weight in self.parameters(): 31 | weight.data.uniform_(-stdv, stdv) 32 | self.hidden = None 33 | 34 | def reset_hidden(self): 35 | self.hidden = None 36 | 37 | def detach_hidden(self): 38 | self.hidden[0].detach_() 39 | self.hidden[1].detach_() 40 | 41 | 
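    # The forward() below implements the standard LSTM update, where sigmoid/tanh
    # are the activations chosen in __init__ and the gates come from one fused
    # linear map of the input and the previous hidden state:
    #   i, f, g, o = chunk(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh, 4)
    #   c_t = sigmoid(f) * c_{t-1} + sigmoid(i) * tanh(g)
    #   h_t = sigmoid(o) * tanh(c_t)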
def forward(self, input_data, future=0): 42 | timesteps, batch_size, features = input_data.size() 43 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 44 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 45 | 46 | if self.hidden is None: 47 | self.hidden = (Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False), # h 48 | Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False)) # c 49 | 50 | self.check_forward_input(input_data[0]) 51 | self.check_forward_hidden(input_data[0], self.hidden[0], '[0]') 52 | self.check_forward_hidden(input_data[0], self.hidden[1], '[1]') 53 | 54 | for i, input_t in enumerate(input_data.split(1)): 55 | 56 | hx, cx = self.hidden 57 | gates = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) + \ 58 | F.linear(hx, self.weight_hh, self.bias_hh) 59 | 60 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 61 | 62 | ingate = self.sigmoid(ingate) 63 | forgetgate = self.sigmoid(forgetgate) 64 | cellgate = self.tanh(cellgate) 65 | outgate = self.sigmoid(outgate) 66 | 67 | cy = (forgetgate * cx) + (ingate * cellgate) 68 | hy = outgate * self.tanh(cy) 69 | 70 | self.hidden = hy, cy 71 | outputs[i] = self.hidden[0] 72 | 73 | return outputs -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | 2 | from torch import nn 3 | import nn as custom 4 | 5 | 6 | class Model(nn.Module): 7 | def __init__(self, input_size=1, layers=["LSTM_51"], output_size=1, sigmoid=None, tanh=None, biases=True): 8 | super(Model, self).__init__() 9 | self.input_size = input_size 10 | self.output_size = output_size 11 | self.layers = [] 12 | prev_size = input_size 13 | for l, spec in enumerate(layers): 14 | bits = spec.split("_") 15 | cell_type = bits.pop(0) 16 | print(spec, cell_type, bits) 17 | 18 | if hasattr(custom, cell_type): 19 | layer = getattr(custom, cell_type) 20 | elif hasattr(nn, cell_type): 21 | layer = getattr(nn, cell_type) 22 | else: 23 | raise Exception("Unrecognised layer type " + cell_type) 24 | 25 | layer_args = {} 26 | if "input_size" in layer.__init__.__code__.co_varnames: 27 | layer_args["input_size"] = prev_size 28 | if "hidden_size" in layer.__init__.__code__.co_varnames: 29 | layer_args["hidden_size"] = int(bits.pop(0)) 30 | prev_size = layer_args["hidden_size"] 31 | 32 | for a in bits: 33 | print(a) 34 | k, v = a.split("=") 35 | k = k.replace("-", "_") 36 | if k not in layer.__init__.__code__.co_varnames: 37 | print("kwarg", k, "for", cell_type, "not recognised") 38 | continue 39 | for t in (int, float): 40 | try: 41 | v = t(v) 42 | break 43 | except ValueError: 44 | pass 45 | layer_args[k] = v 46 | 47 | if "tanh" in layer.__init__.__code__.co_varnames: 48 | layer_args["tanh"] = tanh 49 | if "sigmoid" in layer.__init__.__code__.co_varnames: 50 | layer_args["sigmoid"] = sigmoid 51 | if "bias" in layer.__init__.__code__.co_varnames: 52 | layer_args["bias"] = biases 53 | 54 | print("Adding layer of type", spec, ":", layer_args) 55 | layer = layer(**layer_args,) 56 | self.layers.append(layer) 57 | self.add_module("layer"+str(l), layer) 58 | 59 | if prev_size != output_size: 60 | print("Adding linear layer :", prev_size, "->", output_size) 61 | layer = nn.Linear(prev_size, output_size) 62 | self.layers.append(layer) 63 | self.add_module("layer"+str(l+1), layer) 64 | 65 | def reset_hidden(self): 66 | for layer in self.layers: 
67 | if hasattr(layer, "reset_hidden"): 68 | layer.reset_hidden() 69 | # for module in self.modules(): 70 | # if module is not self and hasattr(module, "reset_hidden"): 71 | # module.reset_hidden() 72 | 73 | def detach_hidden(self): 74 | for layer in self.layers: 75 | if hasattr(layer, "detach_hidden"): 76 | layer.detach_hidden() 77 | 78 | def forward(self, data, future=0): 79 | for layer in self.layers: 80 | data = layer(data) 81 | return data 82 | -------------------------------------------------------------------------------- /nn/causal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CausalConv1d(nn.Conv1d): 6 | def __init__(self, 7 | input_size, 8 | hidden_size, 9 | kernel_size, 10 | stride=1, 11 | dilation=1, 12 | groups=1, 13 | bias=True, 14 | sigmoid=None, 15 | tanh=None): 16 | self.left_padding = (kernel_size - 1) * dilation 17 | super(CausalConv1d, self).__init__( 18 | input_size, 19 | hidden_size, 20 | kernel_size, 21 | stride=stride, 22 | padding=0, 23 | dilation=dilation, 24 | groups=groups, 25 | bias=bias) 26 | 27 | def forward(self, input): 28 | # data is in shape (timesteps, batches, features) 29 | # conv needs shape (batches, features, timesteps) 30 | x = F.pad(input.permute(1, 2, 0), (self.left_padding,0)) 31 | conv_out = super(CausalConv1d, self).forward(x) 32 | # must return shape (timesteps, batches, features) 33 | return conv_out.permute(2, 0, 1) 34 | 35 | class Wave(nn.Module): 36 | def __init__(self, input_size, hidden_size, layers=3, activation="tanh"): 37 | super(Wave, self).__init__() 38 | self.layers = [] 39 | prev_size = input_size 40 | for layer in range(layers): 41 | conv = CausalConv1d(prev_size, hidden_size, kernel_size=2, dilation=2**layer) 42 | self.layers.append(conv) 43 | self.add_module("layer"+str(layer), conv) 44 | prev_size = hidden_size 45 | 46 | def forward(self, data): 47 | for layer in self.layers: 48 | data = layer(data) 49 | return data 50 | 51 | class ShortWave(nn.Module): 52 | def __init__(self, input_size, hidden_size, layers=3): 53 | super(ShortWave, self).__init__() 54 | self.layers = [] 55 | prev_size = input_size 56 | for layer in range(layers): 57 | conv = CausalConv1d(prev_size, hidden_size, kernel_size=2, dilation=1) 58 | self.layers.append(conv) 59 | self.add_module("layer"+str(layer), conv) 60 | prev_size = hidden_size 61 | 62 | def forward(self, data): 63 | for layer in self.layers: 64 | data = layer(data) 65 | return data 66 | 67 | def test_CausalConv1d(timesteps, input_size, hidden_size, batch_size, kernel_size, dilation, bias): 68 | m = CausalConv1d(input_size, hidden_size, kernel_size=kernel_size, dilation=dilation, bias=bias!=0) 69 | m.weight.data.fill_(1) 70 | if bias: 71 | m.bias.data.fill_(bias) 72 | x = torch.autograd.Variable(torch.zeros(timesteps, batch_size, input_size), requires_grad=False) 73 | 74 | for batch in range(batch_size): 75 | for t in range(timesteps): 76 | for ci in range(input_size): 77 | x.data.fill_(0) 78 | x[t, batch, ci] = 1 79 | out = m(x) 80 | for b in range(batch_size): 81 | for co in range(hidden_size): 82 | if b == batch: 83 | target = [1+bias if j in range(t, t+k*d, d) else bias for j in range(timesteps)] 84 | else: 85 | target = [bias for j in range(timesteps)] 86 | if list(out[:, b, co].data) != target: 87 | print("\nCausalConv1d wrong output for kernel_size", k, 88 | "and dilation", d, "i", input_size, "out", hidden_size, 89 | "batch_size", batch_size, 90 | "bias", 
bias) 91 | print("input ", " ".join(str(int(el)) for el in x[:, b, co].data)) 92 | print("output", " ".join(str(el) for el in out[:, b, co].data)) 93 | print("target", " ".join(str(el) for el in target)) 94 | assert list(out[:, b, co].data) == target, "Test failed" 95 | 96 | if __name__ == "__main__": 97 | import numpy as np 98 | timesteps, batch_size = 20, 3 99 | print("Running tests", end="") 100 | for ci in range(1, 3): 101 | for co in range(1, 3): 102 | for k in range(1, 4): 103 | for d in range(1, 3): 104 | print(".", end="", flush=True) 105 | test_CausalConv1d(timesteps, ci, co, batch_size, k, d, 0.5) 106 | test_CausalConv1d(timesteps, ci, co, batch_size, k, d, 0) 107 | print("\nCausalConv1d tests passed") 108 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | 2 | if __name__ == '__main__' and __package__ is None: 3 | import os 4 | os.sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | __package__ = "pytorch_bits" 6 | 7 | import sys 8 | import time 9 | from argparse import ArgumentParser 10 | import numpy as np 11 | import matplotlib 12 | # matplotlib.use('WebAgg') 13 | import matplotlib.pyplot as plt 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch.autograd import Variable 18 | import torch.optim as optim 19 | 20 | import data_generation 21 | import models 22 | import optim as custom_optim 23 | # from pytorch_custom.yellowfin import YFOptimizer 24 | 25 | parser = ArgumentParser(description='PyTorch example') 26 | parser.add_argument('--data_fn', type=str, default="sine_3") 27 | parser.add_argument('--epochs', type=int, default=50) 28 | parser.add_argument('--batch_size', type=int, default=50) 29 | parser.add_argument('--length', type=int, default=1000) 30 | parser.add_argument('--add_noise', action='store_true') 31 | parser.add_argument('--lr', type=float, default=.0001) 32 | parser.add_argument('--seq_len', type=int, default=100) 33 | parser.add_argument('--layers', type=str, nargs="+", default=["LSTM_51"]) 34 | parser.add_argument('--sigmoid', type=str, default=None) 35 | parser.add_argument('--tanh', type=str, default=None) 36 | parser.add_argument('--warmup', type=int, default=10) 37 | parser.add_argument('--optim', type=str, default='Adam_HD') 38 | parser.add_argument('--seed', type=int, default=None) 39 | parser.add_argument('--verbose', action='store_true') 40 | args = parser.parse_args() 41 | print(args) 42 | 43 | if args.seed is not None: 44 | np.random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | 47 | X_train, X_val, X_test, y_train, y_val, y_test = data_generation.generate_data( 48 | data_fn=args.data_fn, batch_size=args.batch_size, 49 | length=args.length, add_noise=args.add_noise) 50 | 51 | rnn = models.Model(input_size=X_train.size(-1), layers=args.layers, output_size=y_train.size(-1), 52 | sigmoid=args.sigmoid, tanh=args.tanh) 53 | print(rnn) 54 | print(sum([p.numel() for p in rnn.parameters() if p.requires_grad]), "trainable parameters") 55 | 56 | loss_fn = nn.MSELoss() 57 | if hasattr(custom_optim, args.optim): 58 | optimizer = getattr(custom_optim, args.optim)(rnn.parameters()) 59 | else: 60 | optimizer = getattr(optim, args.optim)(rnn.parameters()) 61 | """ 62 | Training with ground truth -- The input is the ground truth 63 | """ 64 | try: 65 | val_loss_list = [] 66 | start = time.time() 67 | for epoch in range(args.epochs): 68 | training_loss = 0 69 | rnn.train(True) 70 | 
rnn.reset_hidden() 71 | for batch, data, target in data_generation.get_batches(X_train, y_train, seq_len=args.seq_len, reason="training"): 72 | output = rnn(data) 73 | optimizer.zero_grad() 74 | if batch == 0: 75 | loss = loss_fn(output[args.warmup:], target[args.warmup:]) 76 | else: 77 | loss = loss_fn(output, target) 78 | if args.verbose: 79 | print("Input: mean", data.mean().data[0], "std", data.std().data[0]) 80 | print("Output: mean", output.mean().data[0], "std", output.std().data[0]) 81 | print("Target: mean", target.mean().data[0], "std", target.std().data[0]) 82 | loss.backward(retain_graph=True) 83 | if batch > 0: 84 | optimizer.step() 85 | training_loss += loss.data[0] 86 | rnn.detach_hidden() 87 | training_loss /= batch + 1 88 | 89 | val_loss = 0 90 | rnn.train(False) 91 | rnn.reset_hidden() 92 | for batch, data, targets in data_generation.get_batches(X_val, y_val, seq_len=args.seq_len, reason="validation"): 93 | output = rnn(data) 94 | loss = loss_fn(output, targets) 95 | val_loss += loss.data[0] 96 | val_loss /= batch + 1 97 | val_loss_list.append(val_loss) 98 | print("Ground truth - Epoch " + str(epoch) + " -- train loss = " + str(training_loss) + " -- val loss = " + str(val_loss) 99 | + " -- time %.1fs" % ((time.time() - start) / (epoch + 1))) 100 | 101 | except KeyboardInterrupt: 102 | print("\nTraining interrupted") 103 | 104 | """ 105 | Measuring the test score -> running the test data on the model 106 | """ 107 | rnn.train(False) 108 | rnn.reset_hidden() 109 | test_loss = 0 110 | list1 = [] 111 | list2 = [] 112 | for batch, data, targets in data_generation.get_batches(X_test, y_test, seq_len=args.seq_len, reason="testing"): 113 | output = rnn(data) 114 | loss = loss_fn(output, targets) 115 | test_loss += loss.data[0] 116 | target_last_point = torch.squeeze(targets[:, -1]).data.cpu().numpy().tolist() 117 | pred_last_point = torch.squeeze(output[:, -1]).data.cpu().numpy().tolist() 118 | list1 += target_last_point 119 | list2 += pred_last_point 120 | if len(list1) > 400: 121 | break 122 | plt.figure(1) 123 | plt.plot(list1, "b") 124 | plt.plot(list2, "r") 125 | plt.legend(["Original data", "Generated data"]) 126 | test_loss /= batch + 1 127 | print("Test loss = ", test_loss) 128 | 129 | if epoch > 0: 130 | plt.show() 131 | -------------------------------------------------------------------------------- /data_generation.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import torch 5 | from torch.autograd import Variable 6 | import signalz 7 | 8 | import platform 9 | print("Detected platform is", platform.system()) 10 | # in bash on Ubuntu printing "\x1b[0K" will clear to the end of the line 11 | CLR = "\x1b[0K" if platform.system() != "Windows" else "" 12 | 13 | def get_batches(X, y, seq_len=100, reason=None): 14 | if seq_len > len(X): seq_len = len(X) 15 | batches = len(X) // seq_len 16 | leftover = len(X) % seq_len 17 | if reason == "training" and leftover > 0: 18 | offset = np.random.randint(leftover) 19 | else: 20 | offset = 0 21 | 22 | message = str(batches) 23 | if reason is not None: 24 | message += " " + reason 25 | message += " batches" 26 | if offset > 0: 27 | message += " @" + str(offset) 28 | 29 | start_time = time.time() 30 | for batch in range(batches): 31 | start = batch * seq_len 32 | end = start + seq_len 33 | yield batch, X[start:end], y[start:end] 34 | if batch + 1 < batches: 35 | print("\r%s -- %.1f%% done -- time %.5f ms/sample" % (message, 
36 | 100. * (batch + 1) / batches, 37 | 1000 * (time.time()-start_time) / (batch + 1) / (seq_len * X.size(1))), 38 | end=CLR, flush=True) 39 | print("\r", end=CLR, flush=True) 40 | 41 | def sine_1(length, pattern_length=60., add_noise=False, noise_range=(-0.1, 0.1)): 42 | X = np.arange(length) + np.random.randint(pattern_length) 43 | signal = np.sin(2 * np.pi * (X) / pattern_length) 44 | if add_noise: 45 | signal += np.random.uniform(noise_range[0], noise_range[1], size=signal.shape) 46 | return signal 47 | 48 | def sine_2(length, pattern_length=60., add_noise=False, noise_range=(-0.1, 0.1)): 49 | X = np.arange(length) + np.random.randint(pattern_length) 50 | signal = (np.sin(2 * np.pi * (X) / pattern_length) + np.sin(2 * 2 * np.pi * (X) / pattern_length)) / 2.0 51 | if add_noise: 52 | signal += np.random.uniform(noise_range[0], noise_range[1], size=signal.shape) 53 | return signal 54 | 55 | def sine_3(length, pattern_length=60., add_noise=False, noise_range=(-0.1, 0.1)): 56 | X = np.arange(length) + np.random.randint(pattern_length) 57 | signal = (np.sin(2 * np.pi * (X) / pattern_length) + np.sin(2 * 2 * np.pi * (X) / pattern_length) + np.sin(2 * 3 * np.pi * (X) / pattern_length)) / 3.0 58 | if add_noise: 59 | signal += np.random.uniform(noise_range[0], noise_range[1], size=signal.shape) 60 | return signal 61 | 62 | def mackey_glass(length, add_noise=False, noise_range=(-0.01, 0.01)): 63 | initial = .25 + .5 * np.random.rand() 64 | signal = signalz.mackey_glass(length, a=0.2, b=0.8, c=0.9, d=23, e=10, initial=initial) 65 | if add_noise: 66 | signal += np.random.uniform(noise_range[0], noise_range[1], size=signal.shape) 67 | return signal - 1. 68 | 69 | def levy_flight(length, add_noise=False, noise_range=(-0.01, 0.01)): 70 | offset = np.random.randint(length // 2) 71 | signal = signalz.levy_flight(length + offset, alpha=1.8, beta=0., sigma=.01, position=0) 72 | return signal[offset:] - 1. 
73 | 74 | def brownian(length, add_noise=False, noise_range=(-0.01, 0.01)): 75 | return signalz.brownian_noise(length, leak=0.1, start=0, std=.1, source="gaussian") 76 | 77 | generators = { 78 | "sine_1": sine_1, 79 | "sine_2": sine_2, 80 | "sine_3": sine_3, 81 | "mackey_glass": mackey_glass, 82 | "levy_flight": levy_flight, 83 | "brownian": brownian, 84 | } 85 | 86 | 87 | def generate_data(data_fn="sine_1", length=10000, pattern_length=60., batch_size=32, add_noise=False): 88 | X = np.empty((length, batch_size, 1)) 89 | y = np.empty((length, batch_size, 1)) 90 | 91 | for b in range(batch_size): 92 | x_data = generators[data_fn](length + 1, add_noise=add_noise) 93 | 94 | if b == 0: 95 | plt.figure("Synthetic data", figsize=(15, 10)) 96 | plt.title("Synthetic data") 97 | plt.plot(range(min(1000, length)), x_data[:min(1000, length)]) 98 | 99 | X[:,b,0] = x_data[:-1] 100 | y[:,b,0] = x_data[1:] 101 | 102 | plt.savefig("synthetic_data.png") 103 | plt.close() 104 | 105 | # 70% training, 10% validation, 20% testing 106 | train_sep = int(length * 0.7) 107 | val_sep = train_sep + int(length * 0.1) 108 | 109 | X_train = Variable(torch.from_numpy(X[:train_sep, :]).float(), requires_grad=False) 110 | y_train = Variable(torch.from_numpy(y[:train_sep, :]).float(), requires_grad=False) 111 | 112 | X_val = Variable(torch.from_numpy(X[train_sep:val_sep, :]).float(), requires_grad=False) 113 | y_val = Variable(torch.from_numpy(y[train_sep:val_sep, :]).float(), requires_grad=False) 114 | 115 | X_test = Variable(torch.from_numpy(X[val_sep:, :]).float(), requires_grad=False) 116 | y_test = Variable(torch.from_numpy(y[val_sep:, :]).float(), requires_grad=False) 117 | 118 | print(("X_train size = {}, X_val size = {}, X_test size = {}".format(X_train.size(), X_val.size(), X_test.size()))) 119 | 120 | return X_train, X_val, X_test, y_train, y_val, y_test 121 | -------------------------------------------------------------------------------- /nn/mgu.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | 7 | from .rnn_cell_base import RNNCellBase 8 | 9 | class MGU(RNNCellBase): 10 | 11 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 12 | super(MGU, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.bias = bias 16 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 17 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 18 | self.weight_ih = Parameter(torch.Tensor(2 * hidden_size, input_size)) 19 | self.weight_hh = Parameter(torch.Tensor(2 * hidden_size, hidden_size)) 20 | if bias: 21 | self.bias_ih = Parameter(torch.Tensor(2 * hidden_size)) 22 | self.bias_hh = Parameter(torch.Tensor(2 * hidden_size)) 23 | else: 24 | self.register_parameter('bias_ih', None) 25 | self.register_parameter('bias_hh', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1.0 / math.sqrt(self.hidden_size) 30 | for weight in self.parameters(): 31 | weight.data.uniform_(-stdv, stdv) 32 | self.hidden = None 33 | 34 | def reset_hidden(self): 35 | self.hidden = None 36 | self.timesteps = 0 37 | 38 | def detach_hidden(self): 39 | self.hidden.detach_() 40 | 41 | def forward(self, input_data, future=0): 42 | timesteps, batch_size, features = input_data.size() 43 | # print("t %d, b %d, f %d" % (timesteps, 
batch_size, features)) 44 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 45 | 46 | if self.hidden is None: 47 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 48 | 49 | self.check_forward_input(input_data[0]) 50 | self.check_forward_hidden(input_data[0], self.hidden) 51 | 52 | for i, input_t in enumerate(input_data.split(1)): 53 | 54 | gi = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) 55 | gh = F.linear(self.hidden, self.weight_hh, self.bias_hh) 56 | i_f, i_n = gi.chunk(2, 1) 57 | h_f, h_n = gh.chunk(2, 1) 58 | 59 | forgetgate = self.sigmoid(i_f + h_f) 60 | newgate = self.tanh(i_n + forgetgate * h_n) 61 | self.hidden = newgate + (1 - forgetgate) * (self.hidden - newgate) 62 | outputs[i] = self.hidden 63 | 64 | return outputs 65 | 66 | class MGU2(RNNCellBase): 67 | 68 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 69 | super(MGU2, self).__init__() 70 | self.input_size = input_size 71 | self.hidden_size = hidden_size 72 | self.bias = bias 73 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 74 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 75 | self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size)) 76 | self.weight_hh = Parameter(torch.Tensor(2 * hidden_size, hidden_size)) 77 | if bias: 78 | self.bias_ih = Parameter(torch.Tensor(hidden_size)) 79 | self.bias_hh = Parameter(torch.Tensor(2 * hidden_size)) 80 | else: 81 | self.register_parameter('bias_ih', None) 82 | self.register_parameter('bias_hh', None) 83 | self.reset_parameters() 84 | 85 | def reset_parameters(self): 86 | stdv = 1.0 / math.sqrt(self.hidden_size) 87 | for weight in self.parameters(): 88 | weight.data.uniform_(-stdv, stdv) 89 | self.hidden = None 90 | 91 | def reset_hidden(self): 92 | self.hidden = None 93 | self.timesteps = 0 94 | 95 | def detach_hidden(self): 96 | self.hidden.detach_() 97 | 98 | def forward(self, input_data, future=0): 99 | timesteps, batch_size, features = input_data.size() 100 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 101 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 102 | 103 | if self.hidden is None: 104 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 105 | 106 | self.check_forward_input(input_data[0]) 107 | self.check_forward_hidden(input_data[0], self.hidden) 108 | 109 | for i, input_t in enumerate(input_data.split(1)): 110 | 111 | i_n = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) 112 | gh = F.linear(self.hidden, self.weight_hh, self.bias_hh) 113 | h_f, h_n = gh.chunk(2, 1) 114 | 115 | forgetgate = self.sigmoid(h_f) 116 | newgate = self.tanh(i_n + forgetgate * h_n) 117 | self.hidden = newgate + (1 - forgetgate) * (self.hidden - newgate) 118 | outputs[i] = self.hidden 119 | 120 | return outputs 121 | -------------------------------------------------------------------------------- /optim/adam_hd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | class Adam_HD_lr_per_param(Optimizer): 5 | 6 | def __init__(self, params, lr=1e-3, lr_lr=.1, betas=(0.9, 0.999), eps=1e-8, 7 | weight_decay=0): 8 | defaults = dict(lr=lr, lr_lr=lr_lr, betas=betas, eps=eps, 9 | weight_decay=weight_decay) 10 | super(Adam_HD_lr_per_param, 
self).__init__(params, defaults) 11 | 12 | def step(self, closure=None): 13 | loss = None 14 | if closure is not None: 15 | loss = closure() 16 | 17 | for group in self.param_groups: 18 | for p in group['params']: 19 | if p.grad is None: 20 | continue 21 | grad = p.grad.data 22 | if grad.is_sparse: 23 | raise RuntimeError('Adam_HD does not support sparse gradients, please consider SparseAdam instead') 24 | 25 | state = self.state[p] 26 | 27 | # State initialization 28 | if len(state) == 0: 29 | state['step'] = 0 30 | state['lr'] = torch.zeros_like(p.data).fill_(group['lr']) 31 | # Exponential moving average of gradient values 32 | state['m'] = torch.zeros_like(p.data) 33 | # Exponential moving average of squared gradient values 34 | state['v'] = torch.zeros_like(p.data) 35 | # For calculating df/dlr 36 | state['m_debiased_tm1'] = torch.zeros_like(p.data) 37 | state['v_debiased_tm1'] = torch.zeros_like(p.data) 38 | 39 | m, m_debiased_tm1 = state['m'], state['m_debiased_tm1'] 40 | v, v_debiased_tm1 = state['v'], state['v_debiased_tm1'] 41 | beta1, beta2 = group['betas'] 42 | 43 | state['step'] += 1 44 | 45 | if group['weight_decay'] != 0: 46 | grad = grad.add(group['weight_decay'], p.data) 47 | 48 | # Decay the first and second moment running average coefficient 49 | m.mul_(beta1).add_(1 - beta1, grad) 50 | v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 51 | 52 | # Bias corrections 53 | m_debiased = m.div(1 - beta1 ** state['step']) 54 | v_debiased = v.div(1 - beta2 ** state['step']) 55 | 56 | # Update learning rate 57 | h = grad * (-m_debiased_tm1 / (torch.sqrt(v_debiased_tm1) + group['eps'])) 58 | state['lr'].add_(-group['lr_lr'], h) 59 | 60 | p.data.addcdiv_(-state['lr'] * m_debiased, (torch.sqrt(v_debiased) + group['eps'])) 61 | 62 | m_debiased_tm1.copy_(m_debiased) 63 | v_debiased_tm1.copy_(v_debiased) 64 | 65 | return loss 66 | 67 | class Adam_HD(Optimizer): 68 | 69 | def __init__(self, params, lr=1e-3, lr_lr=.1, betas=(0.9, 0.999), eps=1e-8, 70 | weight_decay=0): 71 | defaults = dict(lr=lr, lr_lr=lr_lr, betas=betas, eps=eps, 72 | weight_decay=weight_decay) 73 | super(Adam_HD, self).__init__(params, defaults) 74 | 75 | def step(self, closure=None): 76 | loss = None 77 | if closure is not None: 78 | loss = closure() 79 | 80 | for group in self.param_groups: 81 | for p in group['params']: 82 | if p.grad is None: 83 | continue 84 | grad = p.grad.data 85 | if grad.is_sparse: 86 | raise RuntimeError('Adam_HD does not support sparse gradients, please consider SparseAdam instead') 87 | 88 | state = self.state[p] 89 | 90 | # State initialization 91 | if len(state) == 0: 92 | state['step'] = 0 93 | state['lr'] = group['lr'] 94 | # Exponential moving average of gradient values 95 | state['m'] = torch.zeros_like(p.data) 96 | # Exponential moving average of squared gradient values 97 | state['v'] = torch.zeros_like(p.data) 98 | # For calculating df/dlr 99 | state['m_debiased_tm1'] = torch.zeros_like(p.data) 100 | state['v_debiased_tm1'] = torch.zeros_like(p.data) 101 | 102 | m, m_debiased_tm1 = state['m'], state['m_debiased_tm1'] 103 | v, v_debiased_tm1 = state['v'], state['v_debiased_tm1'] 104 | beta1, beta2 = group['betas'] 105 | 106 | state['step'] += 1 107 | 108 | if group['weight_decay'] != 0: 109 | grad = grad.add(group['weight_decay'], p.data) 110 | 111 | # Decay the first and second moment running average coefficient 112 | m.mul_(beta1).add_(1 - beta1, grad) 113 | v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 114 | 115 | # Bias corrections 116 | m_debiased = m.div(1 - beta1 ** 
state['step']) 117 | v_debiased = v.div(1 - beta2 ** state['step']) 118 | 119 | # Update learning rate 120 | h = grad * (-m_debiased_tm1 / (torch.sqrt(v_debiased_tm1) + group['eps'])) 121 | state['lr'] -= group['lr_lr'] * h.mean() 122 | 123 | p.data.addcdiv_(-state['lr'] * m_debiased, (torch.sqrt(v_debiased) + group['eps'])) 124 | 125 | m_debiased_tm1.copy_(m_debiased) 126 | v_debiased_tm1.copy_(v_debiased) 127 | 128 | return loss -------------------------------------------------------------------------------- /nn/sru.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | 7 | from .rnn_cell_base import RNNCellBase 8 | 9 | class SRUf(RNNCellBase): 10 | """The simplest SRU mentioned in the paper.""" 11 | 12 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 13 | super(SRUf, self).__init__() 14 | self.input_size = input_size 15 | self.hidden_size = hidden_size 16 | self.bias = bias 17 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 18 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 19 | self.weight_ih = Parameter(torch.Tensor(2 * hidden_size, input_size)) 20 | if bias: 21 | self.bias_ih = Parameter(torch.Tensor(2 * hidden_size)) 22 | else: 23 | self.register_parameter('bias_ih', None) 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | stdv = 1.0 / math.sqrt(self.hidden_size) 28 | for weight in self.parameters(): 29 | weight.data.uniform_(-stdv, stdv) 30 | self.hidden = None 31 | 32 | def reset_hidden(self): 33 | self.hidden = None 34 | 35 | def detach_hidden(self): 36 | self.hidden.detach_() 37 | 38 | def forward(self, input_data, future=0): 39 | timesteps, batch_size, features = input_data.size() 40 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 41 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 42 | 43 | if self.hidden is None: 44 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 45 | 46 | self.check_forward_input(input_data[0]) 47 | self.check_forward_hidden(input_data[0], self.hidden) 48 | 49 | for i, input_t in enumerate(input_data.split(1)): 50 | 51 | gi = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) 52 | i_f, i_n = gi.chunk(2, 1) 53 | 54 | forgetgate = self.sigmoid(i_f) 55 | newgate = i_n 56 | self.hidden = newgate + forgetgate * (self.hidden - newgate) 57 | outputs[i] = self.tanh(self.hidden) 58 | 59 | return outputs 60 | 61 | class SRU2(RNNCellBase): 62 | 63 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None): 64 | super(SRU2, self).__init__() 65 | self.input_size = input_size 66 | self.hidden_size = hidden_size 67 | self.bias = bias 68 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 69 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 70 | self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size)) 71 | if bias: 72 | self.bias_ih = Parameter(torch.Tensor(3 * hidden_size)) 73 | else: 74 | self.register_parameter('bias_ih', None) 75 | self.reset_parameters() 76 | 77 | def reset_parameters(self): 78 | stdv = 1.0 / math.sqrt(self.hidden_size) 79 | for weight in self.parameters(): 80 | weight.data.uniform_(-stdv, stdv) 81 | self.hidden = None 82 | 83 | def 
reset_hidden(self): 84 | self.hidden = None 85 | 86 | def detach_hidden(self): 87 | self.hidden.detach_() 88 | 89 | def forward(self, input_data, future=0): 90 | timesteps, batch_size, features = input_data.size() 91 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 92 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 93 | 94 | if self.hidden is None: 95 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 96 | 97 | self.check_forward_input(input_data[0]) 98 | self.check_forward_hidden(input_data[0], self.hidden) 99 | 100 | for i, input_t in enumerate(input_data.split(1)): 101 | 102 | gi = F.linear(input_t.view(batch_size, features), self.weight_ih, self.bias_ih) 103 | i_i, i_f, i_n = gi.chunk(3, 1) 104 | 105 | inputgate = self.sigmoid(i_i) 106 | forgetgate = self.sigmoid(i_f) 107 | newgate = i_n 108 | self.hidden = inputgate * newgate + forgetgate * self.hidden 109 | outputs[i] = self.tanh(self.hidden) 110 | 111 | return outputs 112 | 113 | class SRU(RNNCellBase): 114 | 115 | def __init__(self, input_size, hidden_size, bias=True, sigmoid=None, tanh=None, gpu=False): 116 | super(SRU, self).__init__() 117 | self.input_size = input_size 118 | self.hidden_size = hidden_size 119 | self.bias = bias 120 | self.gpu = gpu 121 | self.sigmoid = F.sigmoid if sigmoid is None else self.get_activation(sigmoid) 122 | self.tanh = F.tanh if tanh is None else self.get_activation(tanh) 123 | self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size)) 124 | if bias: 125 | self.bias_ih = Parameter(torch.Tensor(3 * hidden_size)) 126 | else: 127 | self.register_parameter('bias_ih', None) 128 | self.reset_parameters() 129 | 130 | def reset_parameters(self): 131 | stdv = 1.0 / math.sqrt(self.hidden_size) 132 | for weight in self.parameters(): 133 | weight.data.uniform_(-stdv, stdv) 134 | self.hidden = None 135 | 136 | def reset_hidden(self): 137 | self.hidden = None 138 | 139 | def detach_hidden(self): 140 | self.hidden.detach_() 141 | 142 | def forward(self, input_data, future=0): 143 | timesteps, batch_size, features = input_data.size() 144 | # print("t %d, b %d, f %d" % (timesteps, batch_size, features)) 145 | outputs = Variable(torch.zeros(timesteps + future, batch_size, self.hidden_size), requires_grad=False) 146 | 147 | if self.hidden is None: 148 | self.hidden = Variable(torch.zeros(batch_size, self.hidden_size), requires_grad=False) 149 | 150 | self.check_forward_input(input_data[0]) 151 | self.check_forward_hidden(input_data[0], self.hidden) 152 | 153 | if self.gpu: 154 | gis = F.linear(input_data, self.weight_ih, self.bias_ih) 155 | 156 | for i, gi in enumerate(gis.split(1)): 157 | i_r, i_f, i_n = gi.squeeze().chunk(3, 1) 158 | 159 | readgate = self.sigmoid(i_r) 160 | forgetgate = self.sigmoid(i_f) 161 | newgate = i_n 162 | self.hidden = newgate + forgetgate * (self.hidden - newgate) 163 | outputs[i] = newgate + readgate * (self.tanh(self.hidden) - newgate) 164 | else: 165 | for i, input_t in enumerate(input_data.split(1)): 166 | x = input_t.view(batch_size, features) 167 | gi = F.linear(x, self.weight_ih, self.bias_ih) 168 | i_r, i_f, i_n = gi.chunk(3, 1) 169 | 170 | readgate = self.sigmoid(i_r) 171 | forgetgate = self.sigmoid(i_f) 172 | newgate = i_n 173 | self.hidden = newgate + forgetgate * (self.hidden - newgate) 174 | outputs[i] = newgate + readgate * (self.tanh(self.hidden) - newgate) 175 | 176 | return outputs 177 | 
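# --- Usage sketch -------------------------------------------------------------
# An illustrative example of the three variants defined above, assuming the
# 0.3-era torch/Variable API used throughout this file. Each cell takes a
# (timesteps, batch_size, features) tensor, keeps its own hidden state, and
# returns a (timesteps, batch_size, hidden_size) tensor.
if __name__ == "__main__":
    x = Variable(torch.randn(20, 4, 8))   # 20 timesteps, batch of 4, 8 features
    # gpu=True just pre-computes the input projections for all timesteps at once
    for cell in (SRUf(8, 16), SRU2(8, 16), SRU(8, 16, gpu=True)):
        cell.reset_hidden()               # start from a fresh hidden state
        y = cell(x)
        print(type(cell).__name__, tuple(y.size()))   # -> (20, 4, 16)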
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-bits 2 | 3 | Experiments for fun and education. Mostly concerning time-series prediction. 4 | 5 | I started my experiments with [@osm3000](https://github.com/osm3000)'s [sequence_generation_pytorch](https://github.com/osm3000/sequence_generation_pytorch/) repo and some of that code still subsists in these files. 6 | 7 | ## How to run these experiments 8 | 9 | 1. clone/download this repo 10 | 1. `pip install -r requirements.txt` 11 | 1. `python experiment.py [ARGS]` 12 | 13 | Possible arguments include... 14 | * `--data_fn FN` where FN is one of the data generation functions listed below 15 | * `--add_noise` to add noise to the generated waveform 16 | * `--length TIMESERIES_LENGTH` 17 | * `--batch_size BATCH_SIZE` 18 | * `--seq_len SEQ_LEN` the subsequence length used in training 19 | * `--epochs MAX_EPOCHS` 20 | * `--lr LR` 21 | * `--layers LAYERTYPE_SIZE [LAYERTYPE_SIZE ...]` see the section on Model generation 22 | * `--sigmoid REPLACEMENT` to use an alternative to sigmoid; this can be one of the activations mentioned below, e.g. `ISRU_sigmoid`, or any function from torch.nn.functional 23 | * `--tanh REPLACEMENT` to use an alternative to tanh; this must be one of the activations mentioned below, e.g. `ISRU_tanh`, or any function from torch.nn.functional 24 | * `--warmup WARMUP` ignore the loss on the first WARMUP elements of the series in order to let the hidden state warm up 25 | * `--verbose` 26 | 27 | ## Data generation 28 | 29 | * `sine_1` generates a sine wave of wavelength 60 steps 30 | * `sine_2` overlays `sine_1` with a sine wave of wavelength 120 steps 31 | * `sine_3` overlays `sine_2` with a sine wave of wavelength 180 steps 32 | * `mackey_glass` generates a Mackey-Glass chaotic timeseries using the [signalz](https://matousc89.github.io/signalz/) library 33 | * `levy_flight` generates a Lévy flight process using the [signalz](https://matousc89.github.io/signalz/) library 34 | * `brownian` generates a Brownian random walk using the [signalz](https://matousc89.github.io/signalz/) library 35 | 36 | The generator produces a tensor of shape `(length, batches, 1)` containing `batches` independently generated series of the required length. 37 | 38 | ## Model generation 39 | 40 | The `--layers` argument takes a simplistic model specification. First you specify the layer type; if the layer type needs a size, you add a "\_" followed by a number. Then you can follow up with "\_keyword=value" for any keyword arguments; if a keyword contains "\_", replace it with "-". 41 | 42 | For example: `--layers LSTM_50 Dropout_p=.5 CausalConv1d_70_kernel-size=3` specifies a three-layer network with 50 LSTM units in the first layer, Dropout with p=.5 as the second layer, and 70 CausalConv1d units with kernel_size=3 in the third layer. 43 | 44 | If the output of the last requested layer doesn't match the number of target values (for these experiments the target size is 1), the script adds a Linear layer to produce the required number of output values. 45 | 46 | ## Layers 47 | 48 | All of these recurrent layers keep track of their own hidden state (if needed, the hidden state is accessible via the `hidden` attribute). They all have methods to `reset_hidden()` and to `detach_hidden()`.
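For example, a minimal truncated-BPTT sketch (this assumes the repo is importable as `pytorch_bits` and uses the 0.3-era `Variable` API found throughout this code; the loss is just a placeholder):

```python
import torch
from torch.autograd import Variable
from pytorch_bits.nn import SRU

layer = SRU(input_size=1, hidden_size=50)
layer.reset_hidden()                           # new sequence -> fresh hidden state

series = Variable(torch.randn(1000, 16, 1))    # (seq_len, batch_size, features)
for i, chunk in enumerate(series.split(100)):  # feed the sequence 100 steps at a time
    if i > 0:
        layer.detach_hidden()                  # truncate backprop at the chunk boundary
    out = layer(chunk)                         # (100, 16, 50)
    loss = out.pow(2).mean()                   # placeholder loss
    loss.backward()                            # backpropagates through at most 100 steps
```

A real training loop would also step an optimiser and zero the gradients between chunks, and call `reset_hidden()` again whenever a new sequence begins.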
49 | 50 | `reset_hidden()` should be used before feeding the model the start of a new sequence, and `detach_hidden()` can be called in between batches of the same set of sequences in order to truncate backpropagation through time and avoid the slowdown of backpropagating all the way to the beginning of the entire sequence. 51 | 52 | Moreover, they all take input of shape `(seq_len, batch_size, features)`. This allows vectorising any calculations that don't depend on the hidden state. 53 | 54 | * LSTM - the typical LSTM adapted from the PyTorch source code for LSTMCell. 55 | * GRU - the typical GRU adapted from the PyTorch source code for GRUCell. 56 | * MGU and variants - from [arxiv:Minimal Gated Unit for Recurrent Neural Networks](https://arxiv.org/abs/1603.09420) and simplified in [arxiv:Simplified Minimal Gated Unit Variations for Recurrent Neural Networks](http://arxiv.org/abs/1701.03452). I have only coded the original MGU and the MGU2 variant, because the authors say MGU2 is the best of the three. 57 | * RAN - the [arxiv:Recurrent Additive Network](http://arxiv.org/abs/1705.07393). 58 | * SRU - the Simple Recurrent Unit from [arxiv:Training RNNs as fast as CNNs](http://arxiv.org/abs/1709.02755v3). The authors provide a CUDA-optimised implementation; this is a simplistic implementation that optionally vectorises the calculation of the gates. The test_sru_optimisation.py script tests the speed of that optimisation on the cpu; the difference is negligible. 59 | * CausalConv1d - a wrapper for Conv1d that permutes the input shape to that required by Conv1d, and adds the padding that ensures that each timestep sees no future inputs. 60 | * QRNN - an unoptimised implementation of [arxiv:Quasi-recurrent neural networks](http://arxiv.org/abs/1611.01576v2). The paper makes the QRNN seem rather complex, but when you write out the step equations you see that it is not very different from most other RNNs. 61 | * TRNN - a strongly typed simple RNN from [arxiv:Strongly-Typed Recurrent Neural Networks](https://arxiv.org/abs/1602.02218). 62 | * CFN - the Chaos Free Network from [arxiv:A recurrent neural network without chaos](https://arxiv.org/abs/1612.06212). The update is h_t = f * tanh(h_tm1) + i * tanh(W x_t), where the gates f and i are each computed as sigmoid(W x_t + V h_tm1 + b) with their own weights. 63 | 64 | ### Planned 65 | 66 | * Strongly typed LSTM and GRU from [arxiv:Strongly-Typed Recurrent Neural Networks](https://arxiv.org/abs/1602.02218). 67 | * Recurrent Identity Network (RIN) from [arxiv:Overcoming the vanishing gradient problem in plain recurrent networks](https://arxiv.org/abs/1801.06105). The authors report faster convergence and better accuracy. The basic idea is to initialise the hidden-hidden weights to be ~= 1 rather than ~= 0. 68 | * Phased LSTM from [arxiv:Phased LSTM: Accelerating Recurrent Network Training for Long or Event-based Sequences](https://arxiv.org/abs/1610.09513). The hidden-hidden update is gated by a periodic function which lets updates through only a small percentage of the time. This opens the possibility of irregular updates. 69 | 70 | ### Ideas/research 71 | 72 | * I plan to study [arxiv:Unbiased Online Recurrent Optimization](http://arxiv.org/abs/1702.05043), but for the moment it is not clear to me how best to implement it. 73 | * Optional noisy initial hidden states. Otherwise the model learns to cope with a zero initial hidden state, which may hinder learning the hidden-state dynamics later in the sequences.
This probably isn't very important if I have only a few very long sequences that are normalised to zero mean. 74 | * The LSTM class in PyTorch builds a configurable number of identically sized LSTM layers. This architecture allows us to calculate W x h_tm1 for all layers in a single operation. I may try adapting the above layers to take advantage of this. 75 | * Sigmoid activation is typically used for gates, but it is not symmetric. [Let's try using tanh instead](https://github.com/jpeg729/pytorch_bits/wiki/Tanh-instead-of-sigmoid). 76 | 77 | ## Optimisers 78 | 79 | COCOB is great for quick experiments: it reaches a near-optimal learning rate with no hyperparameter tuning, so you can quickly tell which experiments are going nowhere. However, I suspect that it relies too heavily on assumptions about the convexity of the loss; other optimisers often reach a lower loss after many epochs. 80 | Adam_HD tunes the learning rate of Adam by taking gradient steps on the learning rate itself (hypergradient descent). It learns pretty fast too. 81 | 82 | * COCOB - COntinuous COin Betting from [arxiv:Training Deep Networks without Learning Rates Through Coin Betting](https://arxiv.org/abs/1705.07795). 83 | * Adam_HD - Adam with Hypergradient descent from [arxiv:Online Learning Rate Adaptation with Hypergradient Descent](https://arxiv.org/abs/1703.04782). I have set the learning rate's learning rate (`lr_lr`) to 0.1, which is much higher than the authors recommend, but it works well for the experiments I have run. 84 | 85 | ### Planned 86 | 87 | * SGD_HD - SGD with Hypergradient descent from [arxiv:Online Learning Rate Adaptation with Hypergradient Descent](https://arxiv.org/abs/1703.04782). 88 | * ADINE - ADaptive INErtia from [arxiv:ADINE: An Adaptive Momentum Method for Stochastic Gradient Descent](https://arxiv.org/abs/1712.07424). 89 | 90 | ### Ideas/research 91 | 92 | * Use [arxiv:Safe Mutations for Deep and Recurrent Neural Networks through Output Gradients](https://arxiv.org/abs/1712.06563) to produce random perturbations when training gets stuck. The basic idea involves scaling weight changes according to the gradient of the model's *output* w.r.t. its weights. 93 | 94 | ## Activations 95 | 96 | Note that PyTorch's Tanh, Sigmoid and ELU are already very well optimised when run on cpu. My tests show that my simplistic implementations make little difference when running on cpu. 97 | 98 | * ISRLU - [arxiv:Improving Deep Learning by Inverse Square Root Linear Units (ISRLUs)](https://arxiv.org/abs/1710.09967). An alternative to ELU that is supposed to be faster to calculate. 99 | * ISRU_tanh - from the same paper. A proposed alternative to tanh. 100 | * ISRU_sigmoid - from the same paper. A proposed alternative to sigmoid. 101 | * Bipolar activation wrapper from [arxiv:Shifting Mean Activation Towards Zero with Bipolar Activation Functions](https://arxiv.org/abs/1709.04054). Bipolar(f, x_i) = f(x_i) if i is even else -f(-x_i); the resulting activation shifts the mean activation towards zero (a minimal sketch of such a wrapper appears at the end of this README). Note: initialise the weights so that the layer has variance ~= 1. 102 | * The soft exponential activation from [arxiv:A continuum among logarithmic, linear, and exponential functions, and its potential to improve generalization in neural networks](https://arxiv.org/abs/1602.01321). A learnable parameter allows the model to choose between f(x)=log(x), f(x)=x and f(x)=exp(x). Neural networks can easily add values, but can't easily multiply them; with this activation multiplication becomes easy, e.g. exp(log(x_0) + log(x_1)).
103 | * TernaryTanh from [the R2RT blog](https://r2rt.com/beyond-binary-ternary-and-one-hot-neurons.html). Tanh is essentially a binary function: it has two stable output zones, 1 and -1. TernaryTanh also has a flat area around zero, enabling a model to choose between 1, -1 and 0 outputs with greater ease. 104 | 105 | ### Planned 106 | 107 | * [arxiv:Noisy activation functions](https://arxiv.org/abs/1603.00391) are versions of saturating activation functions that add noise when the output is in the saturation zones. 108 | * Learned affine combinations of activations, from [Learning Combinations of Activation Functions](https://arxiv.org/abs/1801.09403). 109 | 110 | ## Regularisers 111 | 112 | ### Planned 113 | 114 | * DARC1 regulariser from https://arxiv.org/abs/1710.05468 115 | * Spectral norm regulariser from https://arxiv.org/abs/1705.10941 --------------------------------------------------------------------------------
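A minimal sketch of the bipolar activation wrapper mentioned in the Activations section above. The helper name `bipolar` is made up for illustration; the code targets a recent PyTorch rather than the `Variable`-era API used elsewhere in this repo, and it is not the repo's actual Bipolar implementation:

```python
import torch
import torch.nn.functional as F

def bipolar(f, x):
    """Apply f 'bipolarly' along the last (feature) dimension of x:
    f(x_i) for even-indexed units, -f(-x_i) for odd-indexed units."""
    sign = torch.ones(x.size(-1))
    sign[1::2] = -1                        # +1 on even indices, -1 on odd indices
    return f(x * sign) * sign              # gives f(x_i) on even units, -f(-x_i) on odd units

y = bipolar(F.elu, torch.randn(8, 10))     # a bipolar ELU over 10 features
print(y.mean())                            # the pairing pushes the mean activation towards zero
```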