├── figures
│   └── model_config.png
├── experiments
│   ├── MackeyGlass
│   │   ├── attract.dill
│   │   ├── cornn.py
│   │   ├── lmu.py
│   │   ├── MackeyGlass-LSTM-increasing_tau.ipynb
│   │   ├── MackeyGlass-LMU-increasing_tau.ipynb
│   │   ├── MackeyGlass-DeepSITH-increasing_tau.ipynb
│   │   └── MackeyGlass-coRNN-increasing_tau.ipynb
│   ├── AddingProblem
│   │   ├── cornn.py
│   │   ├── lmu.py
│   │   ├── addingproblem-LSTM.ipynb
│   │   ├── addingproblem-LMU.ipynb
│   │   └── addingproblem-coRNN.ipynb
│   ├── Hateful8
│   │   ├── cornn.py
│   │   ├── lmu.py
│   │   ├── hateful8-LSTM.ipynb
│   │   └── hateful8-LMU.ipynb
│   └── psMNIST
│       ├── visualization.ipynb
│       ├── sMNIST.ipynb
│       └── psMNIST.ipynb
├── deepsith
│   ├── __init__.py
│   ├── deepsith.py
│   └── isith.py
├── setup.py
├── .gitignore
└── README.md
/figures/model_config.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/compmem/DeepSITH/HEAD/figures/model_config.png
--------------------------------------------------------------------------------
/experiments/MackeyGlass/attract.dill:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/compmem/DeepSITH/HEAD/experiments/MackeyGlass/attract.dill
--------------------------------------------------------------------------------
/deepsith/__init__.py:
--------------------------------------------------------------------------------
1 | from .isith import iSITH
2 | from .deepsith import DeepSITH
3 | 
4 | __copyright__ = "2020, Computational Memory Lab"
5 | __version__ = "0.1"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(name='deepsith',
4 |       version='0.1',
5 |       description='DeepSITH: Scale-Invariant Temporal History Across Hidden Layers',
6 |       url='https://github.com/beegica/SITH-Con',
7 |       license='Free for non-commercial use',
8 |       author='Computational Memory Lab',
9 |       author_email='bgj5hk@virginia.edu',
10 |       packages=['deepsith'],
11 |       install_requires=[
12 |           'torch>=1.1.0',
13 | 
14 |       ],
15 |       zip_safe=False)
--------------------------------------------------------------------------------
/experiments/AddingProblem/cornn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 | 
6 | class coRNNCell(nn.Module):
7 |     def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
8 |         super(coRNNCell, self).__init__()
9 |         self.dt = dt
10 |         self.gamma = gamma
11 |         self.epsilon = epsilon
12 |         self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)
13 | 
14 |     def forward(self, x, hy, hz):
15 |         hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy), 1)))
16 |                              - self.gamma * hy - self.epsilon * hz)
17 |         hy = hy + self.dt * hz
18 | 
19 |         return hy, hz
20 | 
21 | class coRNN(nn.Module):
22 |     def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
23 |         super(coRNN, self).__init__()
24 |         self.n_hid = n_hid
25 |         self.cell = coRNNCell(n_inp, n_hid, dt, gamma, epsilon)
26 |         self.readout = nn.Linear(n_hid, n_out)
27 | 
28 |     def forward(self, x):
29 |         ## initialize hidden states
30 |         hy = torch.zeros(x.size(1), self.n_hid, device=device)
31 |         hz = torch.zeros(x.size(1), self.n_hid, device=device)
32 | 
33 |         for t in range(x.size(0)):
34 |             hy, hz = self.cell(x[t], hy, hz)
35 |         output = self.readout(hy)
36 | 
37 |         return output
38 | 
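# --- Usage sketch (editorial addition, not part of the original file) ---
# For the adding problem, each time step carries a (value, marker) pair, so
# n_inp=2 and n_out=1. The dt/gamma/epsilon values below are arbitrary
# placeholders; the coRNN paper tunes these hyperparameters per task.
if __name__ == "__main__":
    model = coRNN(n_inp=2, n_hid=128, n_out=1, dt=0.01, gamma=1.0, epsilon=1.0).to(device)
    x = torch.rand(500, 32, 2, device=device)  # (seq_len, batch, n_inp)
    print(model(x).shape)                      # torch.Size([32, 1])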
--------------------------------------------------------------------------------
/experiments/Hateful8/cornn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 | 
6 | class coRNNCell(nn.Module):
7 |     def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
8 |         super(coRNNCell, self).__init__()
9 |         self.dt = dt
10 |         self.gamma = gamma
11 |         self.epsilon = epsilon
12 |         self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)
13 | 
14 |     def forward(self, x, hy, hz):
15 |         hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy), 1)))
16 |                              - self.gamma * hy - self.epsilon * hz)
17 |         hy = hy + self.dt * hz
18 | 
19 |         return hy, hz
20 | 
21 | class coRNN(nn.Module):
22 |     def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
23 |         super(coRNN, self).__init__()
24 |         self.n_hid = n_hid
25 |         self.cell = coRNNCell(n_inp, n_hid, dt, gamma, epsilon)
26 |         self.readout = nn.Linear(n_hid, n_out)
27 | 
28 |     def forward(self, x):
29 |         ## initialize hidden states
30 |         hy = torch.zeros(x.size(1), self.n_hid, device=device)
31 |         hz = torch.zeros(x.size(1), self.n_hid, device=device)
32 | 
33 |         for t in range(x.size(0)):
34 |             hy, hz = self.cell(x[t], hy, hz)
35 | 
36 |         output = self.readout(hy)
37 | 
38 |         return output
39 | 
--------------------------------------------------------------------------------
/experiments/MackeyGlass/cornn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 | 
6 | class coRNNCell(nn.Module):
7 |     def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
8 |         super(coRNNCell, self).__init__()
9 |         self.dt = dt
10 |         self.gamma = gamma
11 |         self.epsilon = epsilon
12 |         self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)
13 | 
14 |     def forward(self, x, hy, hz):
15 |         hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy), 1)))
16 |                              - self.gamma * hy - self.epsilon * hz)
17 |         hy = hy + self.dt * hz
18 | 
19 |         return hy, hz
20 | 
21 | class coRNN(nn.Module):
22 |     def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
23 |         super(coRNN, self).__init__()
24 |         self.n_hid = n_hid
25 |         self.cell = coRNNCell(n_inp, n_hid, dt, gamma, epsilon)
26 |         self.readout = nn.Linear(n_hid, n_out)
27 | 
28 |     def forward(self, x):
29 |         ## initialize hidden states
30 |         hy = torch.zeros(x.size(1), self.n_hid, device=device)
31 |         hz = torch.zeros(x.size(1), self.n_hid, device=device)
32 | 
33 |         for t in range(x.size(0)):
34 |             hy, hz = self.cell(x[t], hy, hz)
35 | 
36 |         output = self.readout(hy)
37 | 
38 |         return output
39 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.png
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a
template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | data/ 132 | perf/ 133 | -------------------------------------------------------------------------------- /deepsith/deepsith.py: -------------------------------------------------------------------------------- 1 | # Deep SITH 2 | # PyTorch version 0.1.0 3 | # Authors: Brandon G. Jacques and Per B. 
Sederberg
4 | 
5 | import torch
6 | from torch import nn
7 | from .isith import iSITH
8 | 
9 | 
10 | 
11 | class _DeepSITH_core(nn.Module):
12 |     def __init__(self, layer_params):
13 |         super(_DeepSITH_core, self).__init__()
14 | 
15 |         hidden_size = layer_params.pop('hidden_size', layer_params['in_features'])
16 |         in_features = layer_params.pop('in_features', None)
17 |         batch_norm = layer_params.pop('batch_norm', True)
18 |         act_func = layer_params.pop('act_func', None)
19 |         self.batch_norm = batch_norm
20 |         self.act_func = act_func
21 |         self.sith = iSITH(**layer_params)
22 | 
23 |         self.linear = nn.Linear(layer_params['ntau']*in_features,
24 |                                 hidden_size)
25 |         nn.init.kaiming_normal_(self.linear.weight.data)
26 | 
27 | 
28 | 
29 |         if batch_norm:
30 |             self.dense_bn = nn.BatchNorm1d(hidden_size)
31 | 
32 |     def forward(self, inp):
33 |         # iSITH output: [Batch, taustar, features, sequence]
34 |         x = self.sith(inp)
35 |         # flatten taustar and features per time step: -> [Batch, sequence, taustar*features]
36 |         x = x.transpose(3,2).transpose(2,1)
37 |         x = x.view(x.shape[0], x.shape[1], -1)
38 |         x = self.linear(x)
39 |         if self.act_func is not None:
40 |             x = self.act_func(x)
41 |         if self.batch_norm:
42 |             x = x.transpose(2,1)
43 |             x = self.dense_bn(x).transpose(2,1)
44 |         return x
45 | 
46 | class DeepSITH(nn.Module):
47 |     """A stack of SITH layers with dense connections between them, used much like an LSTM.
48 | 
49 |     Parameters
50 |     ----------
51 |     layer_params: list
52 |         A list of dictionaries, one for each layer in the desired DeepSITH. All
53 |         of the parameters needed for the SITH part of the layers, as well as
54 |         a hidden_size and an optional act_func, are required to be present.
55 | 
56 |     layer_params keys
57 |     -----------------
58 |     hidden_size: int (default in_features)
59 |         The size of the output of the hidden layer. Please note that the
60 |         in_features parameter for the next layer's SITH representation should be
61 |         equal to the previous layer's hidden_size. This parameter will default
62 |         to the in_features of the current SITH layer if not specified.
63 |     act_func: torch.nn.Module (default None)
64 |         The torch layer of the desired activation function, or None if there is
65 |         no desired activation function between layers.
66 | 
67 |     In addition to these keys, you must include all of the non-optional SITH
68 |     layer keys in each dictionary. Please see the SITH docstring for
69 |     suggestions.
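    Example
    -------
    A minimal two-layer configuration (illustrative values only, mirroring the
    README example; tune them per task)::

        lp = [dict(in_features=1, tau_min=1, tau_max=25., buff_max=40,
                   k=84, ntau=15, dt=1, g=0.0, ttype=torch.FloatTensor,
                   batch_norm=True, hidden_size=35, act_func=nn.ReLU()),
              dict(in_features=35, tau_min=1, tau_max=100., buff_max=175,
                   k=40, ntau=15, dt=1, g=0.0, ttype=torch.FloatTensor,
                   batch_norm=True, hidden_size=35, act_func=nn.ReLU())]
        net = DeepSITH(layer_params=lp, dropout=0.2)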
70 | 
71 |     """
72 |     def __init__(self, layer_params, dropout=.5):
73 |         super(DeepSITH, self).__init__()
74 |         self.layers = nn.ModuleList([_DeepSITH_core(layer_params[i])
75 |                                      for i in range(len(layer_params))])
76 |         self.dropouts = nn.ModuleList([nn.Dropout(dropout) for i in range(len(layer_params) - 1)])
77 | 
78 |     def forward(self, inp):
79 |         x = inp
80 |         for i, l in enumerate(self.layers[:-1]):
81 |             x = l(x)
82 |             x = self.dropouts[i](x)
83 |             # reshape (Batch, sequence, hidden) -> (Batch, 1, hidden, sequence) for the next SITH layer
84 |             x = x.unsqueeze(1).transpose(3,2)
85 |         x = self.layers[-1](x)
86 |         return x
--------------------------------------------------------------------------------
/deepsith/isith.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import factorial, log
3 | # Impulse-based SITH class
4 | class iSITH(torch.nn.Module):
5 |     def __init__(self, tau_min=.1, tau_max=100., buff_max=None, k=50, ntau=50, dt=1, g=0.0,
6 |                  ttype=torch.FloatTensor):
7 |         super(iSITH, self).__init__()
8 |         """A SITH module using the exact closed-form equation for the resulting ftilde
9 | 
10 |         Parameters
11 |         ----------
12 | 
13 |         - tau_min: float
14 |             The center of the temporal receptive field for the first taustar produced.
15 |         - tau_max: float
16 |             The center of the temporal receptive field for the last taustar produced.
17 |         - buff_max: int
18 |             The maximum time in which the filters go into the past. NOTE: In order to
19 |             achieve as few edge effects as possible, buff_max needs to be bigger than
20 |             tau_max, and dependent on k, such that the filters have enough time to reach
21 |             very close to 0.0. Plot the filters and you will see them go to 0.
22 |         - k: int
23 |             Temporal specificity of the taustars: the higher k, the narrower each
24 |             temporal receptive field.
25 |         - ntau: int
26 |             Number of taustars produced, spread out logarithmically.
27 |         - dt: float
28 |             The time delta of the model; each filter will have int(buff_max/dt) time
29 |             points. Essentially this is the base rate at which information is presented to the model.
30 |         - g: float
31 |             Typically between 0 and 1. This parameter is the scaling factor of the output
32 |             of the module. If set to 1, the output amplitude for a delta function will be
33 |             identical through time. If set to 0, the amplitude will decay into the past,
34 |             getting smaller and smaller. This value should be picked on an application to
35 |             application basis.
36 |         - ttype: Torch Tensor
37 |             This is the type we set the internal mechanism of the model to before running.
38 |             In order to calculate the filters, we must use a DoubleTensor, but this is no
39 |             longer necessary after they are calculated. By default we set the filters to
40 |             be FloatTensors. NOTE: If you plan to use CUDA, you need to pass in a
41 |             cuda.FloatTensor as the ttype, as using .cuda() will not put these filters on
43 | 44 | 45 | """ 46 | self.k = k 47 | self.tau_min = tau_min 48 | self.tau_max = tau_max 49 | if buff_max is None: 50 | buff_max = 3*tau_max 51 | self.buff_max = buff_max 52 | self.ntau = ntau 53 | self.dt = dt 54 | self.g = g 55 | 56 | self.c = (tau_max/tau_min)**(1./(ntau-1))-1 57 | 58 | self.tau_star = tau_min*(1+self.c)**torch.arange(ntau).type(torch.DoubleTensor) 59 | 60 | self.times = torch.arange(dt, buff_max+dt, dt).type(torch.DoubleTensor) 61 | 62 | a = log(k)*k 63 | b = torch.log(torch.arange(2,k).type(torch.DoubleTensor)).sum() 64 | 65 | #A = ((1/self.tau_star)*(k**(k+1)/factorial(k))*(self.tau_star**self.g)).unsqueeze(1) 66 | A = ((1/self.tau_star)*(torch.exp(a-b))*(self.tau_star**self.g)).unsqueeze(1) 67 | 68 | self.filters = A*torch.exp((torch.log(self.times.unsqueeze(0)/self.tau_star.unsqueeze(1))*(k+1)) + \ 69 | (k*(-self.times.unsqueeze(0)/self.tau_star.unsqueeze(1)))) 70 | 71 | self.filters = torch.flip(self.filters, [-1]).unsqueeze(1).unsqueeze(1) 72 | self.filters = self.filters.type(ttype) 73 | 74 | def extra_repr(self): 75 | s = "ntau={ntau}, tau_min={tau_min}, tau_max={tau_max}, buff_max={buff_max}, dt={dt}, k={k}, g={g}" 76 | s = s.format(**self.__dict__) 77 | return s 78 | 79 | def forward(self, inp): 80 | """Takes in (Batch, 1, features, sequence) and returns (Batch, Taustar, features, sequence)""" 81 | assert(len(inp.shape) >= 4) 82 | out = torch.conv2d(inp, self.filters[:, :, :, -inp.shape[-1]:], 83 | padding=[0, self.filters[:, :, :, -inp.shape[-1]:].shape[-1]]) 84 | #padding=[0, self.filters.shape[-1]]) 85 | # note we're scaling the output by both dt and the k/(k+1) 86 | # Off by 1 introduced by the conv2d 87 | return out[:, :, :, 1:inp.shape[-1]+1]*self.dt*self.k/(self.k+1) -------------------------------------------------------------------------------- /experiments/Hateful8/lmu.py: -------------------------------------------------------------------------------- 1 | """https://github.com/AbdouJaouhar/LMU-Legendre-Memory-Unit""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from sympy.matrices import Matrix, eye, zeros, ones, diag, GramSchmidt 6 | import numpy as np 7 | from functools import partial 8 | import torch.nn.functional as F 9 | import math 10 | 11 | from nengolib.signal import Identity, cont2discrete 12 | from nengolib.synapses import LegendreDelay 13 | from functools import partial 14 | 15 | ''' 16 | Initialisation LECUN_UNIFOR 17 | - tensor to fill 18 | - fan_in is the input dimension size 19 | ''' 20 | def lecun_uniform(tensor): 21 | fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in') 22 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 23 | 24 | 25 | class LMUCell(nn.Module): 26 | 27 | def __init__(self, input_size, hidden_size, 28 | order, 29 | theta=100, # relative to dt=1 30 | method='zoh', 31 | trainable_input_encoders=True, 32 | trainable_hidden_encoders=True, 33 | trainable_memory_encoders=True, 34 | trainable_input_kernel=True, 35 | trainable_hidden_kernel=True, 36 | trainable_memory_kernel=True, 37 | trainable_A=False, 38 | trainable_B=False, 39 | input_encoders_initializer=lecun_uniform, 40 | hidden_encoders_initializer=lecun_uniform, 41 | memory_encoders_initializer=partial(torch.nn.init.constant_, val=0), 42 | input_kernel_initializer=torch.nn.init.xavier_normal_, 43 | hidden_kernel_initializer=torch.nn.init.xavier_normal_, 44 | memory_kernel_initializer=torch.nn.init.xavier_normal_, 45 | 46 | hidden_activation='tanh', 47 | ): 48 | super(LMUCell, self).__init__() 49 | 50 | self.input_size = 
51 |         self.hidden_size = hidden_size
52 |         self.order = order
53 | 
54 |         if hidden_activation == 'tanh':
55 |             self.hidden_activation = torch.tanh
56 |         elif hidden_activation == 'relu':
57 |             self.hidden_activation = torch.relu
58 |         else:
59 |             raise NotImplementedError("hidden activation '{}' is not implemented".format(hidden_activation))
60 | 
61 |         realizer = Identity()
62 |         self._realizer_result = realizer(
63 |             LegendreDelay(theta=theta, order=self.order))
64 |         self._ss = cont2discrete(
65 |             self._realizer_result.realization, dt=1., method=method)
66 |         self._A = self._ss.A - np.eye(order)  # puts into form: x += Ax
67 |         self._B = self._ss.B
68 |         self._C = self._ss.C
69 |         assert np.allclose(self._ss.D, 0)  # proper LTI
70 | 
71 |         self.input_encoders = nn.Parameter(torch.Tensor(1, input_size), requires_grad=trainable_input_encoders)
72 |         self.hidden_encoders = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=trainable_hidden_encoders)
73 |         self.memory_encoders = nn.Parameter(torch.Tensor(1, order), requires_grad=trainable_memory_encoders)
74 |         self.input_kernel = nn.Parameter(torch.Tensor(hidden_size, input_size), requires_grad=trainable_input_kernel)
75 |         self.hidden_kernel = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=trainable_hidden_kernel)
76 |         self.memory_kernel = nn.Parameter(torch.Tensor(hidden_size, order), requires_grad=trainable_memory_kernel)
77 |         self.AT = nn.Parameter(torch.Tensor(self._A), requires_grad=trainable_A)
78 |         self.BT = nn.Parameter(torch.Tensor(self._B), requires_grad=trainable_B)
79 | 
80 |         # Initialize parameters
81 |         input_encoders_initializer(self.input_encoders)
82 |         hidden_encoders_initializer(self.hidden_encoders)
83 |         memory_encoders_initializer(self.memory_encoders)
84 |         input_kernel_initializer(self.input_kernel)
85 |         hidden_kernel_initializer(self.hidden_kernel)
86 |         memory_kernel_initializer(self.memory_kernel)
87 | 
88 |     def forward(self, input, hx):
89 | 
90 |         h, m = hx
91 | 
92 |         u = (F.linear(input, self.input_encoders) +
93 |              F.linear(h, self.hidden_encoders) +
94 |              F.linear(m, self.memory_encoders))
95 | 
96 |         m = m + F.linear(m, self.AT) + F.linear(u, self.BT)
97 | 
98 |         h = self.hidden_activation(
99 |             F.linear(input, self.input_kernel) +
100 |             F.linear(h, self.hidden_kernel) +
101 |             F.linear(m, self.memory_kernel))
102 | 
103 |         return h, [h, m]
104 | 
105 | 
106 | class LegendreMemoryUnit(nn.Module):
107 |     """
108 |     Wraps LMUCell so the LMU can be used like an LSTM or GRU in PyTorch (a plain Python loop over time; no fused GPU kernel).
109 |     """
110 |     def __init__(self, input_dim, hidden_size, order, theta):
111 |         super(LegendreMemoryUnit, self).__init__()
112 | 
113 |         self.hidden_size = hidden_size
114 |         self.order = order
115 | 
116 |         self.lmucell = LMUCell(input_dim, hidden_size, order, theta)
117 | 
118 |     def forward(self, xt):
119 |         outputs = []
120 |         # xt: (batch, seq_len, input_dim); initialize the states on the input's device
121 |         h0 = torch.zeros(xt.size(0), self.hidden_size, device=xt.device)
122 |         m0 = torch.zeros(xt.size(0), self.order, device=xt.device)
123 |         states = (h0, m0)
124 |         for i in range(xt.size(1)):
125 |             out, states = self.lmucell(xt[:,i,:], states)
126 |             outputs.append(out)
127 |         return torch.stack(outputs).permute(1,0,2), states  # (batch, seq_len, hidden_size)
128 | 
129 | 
--------------------------------------------------------------------------------
/experiments/MackeyGlass/lmu.py:
--------------------------------------------------------------------------------
1 | """https://github.com/AbdouJaouhar/LMU-Legendre-Memory-Unit"""
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from sympy.matrices import Matrix, eye, zeros, ones, diag, GramSchmidt
6 | import numpy as np
7 | from functools import partial
8 | import torch.nn.functional as F
9 | import math
10 | 
11 | from nengolib.signal import Identity, cont2discrete
12 | from nengolib.synapses import LegendreDelay
13 | 
14 | 
15 | '''
16 | LECUN_UNIFORM initialisation
17 | - tensor: the tensor to fill
18 | - fan_in is the input dimension size
19 | '''
20 | def lecun_uniform(tensor):
21 |     fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
22 |     nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
23 | 
24 | 
25 | class LMUCell(nn.Module):
26 | 
27 |     def __init__(self, input_size, hidden_size,
28 |                  order,
29 |                  theta=100,  # relative to dt=1
30 |                  method='zoh',
31 |                  trainable_input_encoders=True,
32 |                  trainable_hidden_encoders=True,
33 |                  trainable_memory_encoders=True,
34 |                  trainable_input_kernel=True,
35 |                  trainable_hidden_kernel=True,
36 |                  trainable_memory_kernel=True,
37 |                  trainable_A=False,
38 |                  trainable_B=False,
39 |                  input_encoders_initializer=lecun_uniform,
40 |                  hidden_encoders_initializer=lecun_uniform,
41 |                  memory_encoders_initializer=partial(torch.nn.init.constant_, val=0),
42 |                  input_kernel_initializer=torch.nn.init.xavier_normal_,
43 |                  hidden_kernel_initializer=torch.nn.init.xavier_normal_,
44 |                  memory_kernel_initializer=torch.nn.init.xavier_normal_,
45 | 
46 |                  hidden_activation='tanh',
47 |                  ):
48 |         super(LMUCell, self).__init__()
49 | 
50 |         self.input_size = input_size
51 |         self.hidden_size = hidden_size
52 |         self.order = order
53 | 
54 |         if hidden_activation == 'tanh':
55 |             self.hidden_activation = torch.tanh
56 |         elif hidden_activation == 'relu':
57 |             self.hidden_activation = torch.relu
58 |         else:
59 |             raise NotImplementedError("hidden activation '{}' is not implemented".format(hidden_activation))
60 | 
61 |         realizer = Identity()
62 |         self._realizer_result = realizer(
63 |             LegendreDelay(theta=theta, order=self.order))
64 |         self._ss = cont2discrete(
65 |             self._realizer_result.realization, dt=1., method=method)
66 |         self._A = self._ss.A - np.eye(order)  # puts into form: x += Ax
67 |         self._B = self._ss.B
68 |         self._C = self._ss.C
69 |         assert np.allclose(self._ss.D, 0)  # proper LTI
70 | 
71 |         self.input_encoders = nn.Parameter(torch.Tensor(1, input_size), requires_grad=trainable_input_encoders)
72 |         self.hidden_encoders = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=trainable_hidden_encoders)
73 |         self.memory_encoders = nn.Parameter(torch.Tensor(1, order), requires_grad=trainable_memory_encoders)
74 |         self.input_kernel = nn.Parameter(torch.Tensor(hidden_size, input_size), requires_grad=trainable_input_kernel)
75 |         self.hidden_kernel = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=trainable_hidden_kernel)
76 |         self.memory_kernel = nn.Parameter(torch.Tensor(hidden_size, order), requires_grad=trainable_memory_kernel)
77 |         self.AT = nn.Parameter(torch.Tensor(self._A), requires_grad=trainable_A)
78 |         self.BT = nn.Parameter(torch.Tensor(self._B), requires_grad=trainable_B)
79 | 
80 |         # Initialize parameters
81 |         input_encoders_initializer(self.input_encoders)
82 |         hidden_encoders_initializer(self.hidden_encoders)
83 |         memory_encoders_initializer(self.memory_encoders)
84 |         input_kernel_initializer(self.input_kernel)
85 |         hidden_kernel_initializer(self.hidden_kernel)
86 |         memory_kernel_initializer(self.memory_kernel)
87 | 
88 |     def forward(self, input, hx):
89 | 
90 |         h, m = hx
91 | 
92 |         u = (F.linear(input, self.input_encoders) +
93 |              F.linear(h, self.hidden_encoders) +
94 |              F.linear(m, self.memory_encoders))
95 | 
96 |         m = m + F.linear(m, self.AT) + F.linear(u, self.BT)
97 | 
98 |         h = self.hidden_activation(
99 |             F.linear(input, self.input_kernel) +
100 |             F.linear(h, self.hidden_kernel) +
101 |             F.linear(m, self.memory_kernel))
102 | 
103 |         return h, [h, m]
104 | 
105 | 
106 | class LegendreMemoryUnit(nn.Module):
107 |     """
108 |     Wraps LMUCell so the LMU can be used like an LSTM or GRU in PyTorch (a plain Python loop over time; no fused GPU kernel).
109 |     """
110 |     def __init__(self, input_dim, hidden_size, order, theta):
111 |         super(LegendreMemoryUnit, self).__init__()
112 | 
113 |         self.hidden_size = hidden_size
114 |         self.order = order
115 | 
116 |         self.lmucell = LMUCell(input_dim, hidden_size, order, theta)
117 | 
118 |     def forward(self, xt):
119 |         outputs = []
120 |         # xt: (batch, seq_len, input_dim); initialize the states on the input's device
121 |         h0 = torch.zeros(xt.size(0), self.hidden_size, device=xt.device)
122 |         m0 = torch.zeros(xt.size(0), self.order, device=xt.device)
123 |         states = (h0, m0)
124 |         for i in range(xt.size(1)):
125 |             out, states = self.lmucell(xt[:,i,:], states)
126 |             outputs.append(out)
127 |         return torch.stack(outputs).permute(1,0,2), states  # (batch, seq_len, hidden_size)
128 | 
129 | 
--------------------------------------------------------------------------------
/experiments/AddingProblem/lmu.py:
--------------------------------------------------------------------------------
1 | """https://github.com/AbdouJaouhar/LMU-Legendre-Memory-Unit"""
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from sympy.matrices import Matrix, eye, zeros, ones, diag, GramSchmidt
6 | import numpy as np
7 | from functools import partial
8 | import torch.nn.functional as F
9 | import math
10 | 
11 | from nengolib.signal import Identity, cont2discrete
12 | from nengolib.synapses import LegendreDelay
13 | 
14 | 
15 | '''
16 | LECUN_UNIFORM initialisation
17 | - tensor: the tensor to fill
18 | - fan_in is the input dimension size
19 | '''
20 | def lecun_uniform(tensor):
21 |     fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
22 |     nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
23 | 
24 | 
25 | class LMUCell(nn.Module):
26 | 
27 |     def __init__(self, input_size, hidden_size,
28 |                  order,
29 |                  theta=100,  # relative to dt=1
30 |                  method='zoh',
31 |                  trainable_input_encoders=True,
32 |                  trainable_hidden_encoders=True,
33 |                  trainable_memory_encoders=True,
34 |                  trainable_input_kernel=True,
35 |                  trainable_hidden_kernel=True,
36 |                  trainable_memory_kernel=True,
37 |                  trainable_A=False,
38 |                  trainable_B=False,
39 |                  input_encoders_initializer=lecun_uniform,
40 |                  hidden_encoders_initializer=lecun_uniform,
41 |                  memory_encoders_initializer=partial(torch.nn.init.constant_, val=0),
42 |                  input_kernel_initializer=torch.nn.init.xavier_normal_,
43 |                  hidden_kernel_initializer=torch.nn.init.xavier_normal_,
44 |                  memory_kernel_initializer=torch.nn.init.xavier_normal_,
45 | 
46 |                  hidden_activation='tanh',
47 |                  ):
48 |         super(LMUCell, self).__init__()
49 | 
50 |         self.input_size = input_size
51 |         self.hidden_size = hidden_size
52 |         self.order = order
53 | 
54 |         if hidden_activation == 'tanh':
55 |             self.hidden_activation = torch.tanh
56 |         elif hidden_activation == 'relu':
57 |             self.hidden_activation = torch.relu
58 |         else:
59 |             raise NotImplementedError("hidden activation '{}' is not implemented".format(hidden_activation))
60 | 
61 |         realizer = Identity()
62 |         self._realizer_result = realizer(
63 |             LegendreDelay(theta=theta, order=self.order))
64 |         self._ss = cont2discrete(
65 |             self._realizer_result.realization, dt=1., method=method)
66 |         self._A = self._ss.A - np.eye(order)  # puts into form: x += Ax
67 |         self._B = self._ss.B
68 |         self._C = self._ss.C
69 |         assert np.allclose(self._ss.D, 0)  # proper LTI
70 | 
71 |         self.input_encoders = nn.Parameter(torch.Tensor(1, input_size), requires_grad=trainable_input_encoders)
72 |         self.hidden_encoders = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=trainable_hidden_encoders)
73 |         self.memory_encoders = nn.Parameter(torch.Tensor(1, order), requires_grad=trainable_memory_encoders)
74 |         self.input_kernel = nn.Parameter(torch.Tensor(hidden_size, input_size), requires_grad=trainable_input_kernel)
75 |         self.hidden_kernel = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=trainable_hidden_kernel)
76 |         self.memory_kernel = nn.Parameter(torch.Tensor(hidden_size, order), requires_grad=trainable_memory_kernel)
77 |         self.AT = nn.Parameter(torch.Tensor(self._A), requires_grad=trainable_A)
78 |         self.BT = nn.Parameter(torch.Tensor(self._B), requires_grad=trainable_B)
79 | 
80 |         # Initialize parameters
81 |         input_encoders_initializer(self.input_encoders)
82 |         hidden_encoders_initializer(self.hidden_encoders)
83 |         memory_encoders_initializer(self.memory_encoders)
84 |         input_kernel_initializer(self.input_kernel)
85 |         hidden_kernel_initializer(self.hidden_kernel)
86 |         memory_kernel_initializer(self.memory_kernel)
87 | 
88 |     def forward(self, input, hx):
89 | 
90 |         h, m = hx
91 | 
92 |         u = (F.linear(input, self.input_encoders) +
93 |              F.linear(h, self.hidden_encoders) +
94 |              F.linear(m, self.memory_encoders))
95 | 
96 |         m = m + F.linear(m, self.AT) + F.linear(u, self.BT)
97 | 
98 |         h = self.hidden_activation(
99 |             F.linear(input, self.input_kernel) +
100 |             F.linear(h, self.hidden_kernel) +
101 |             F.linear(m, self.memory_kernel))
102 | 
103 |         return h, [h, m]
104 | 
105 | 
106 | class LegendreMemoryUnit(nn.Module):
107 |     """
108 |     Wraps LMUCell so the LMU can be used like an LSTM or GRU in PyTorch (a plain Python loop over time; no fused GPU kernel).
109 |     """
110 |     def __init__(self, input_dim, hidden_size, order, theta):
111 |         super(LegendreMemoryUnit, self).__init__()
112 | 
113 |         self.hidden_size = hidden_size
114 |         self.order = order
115 | 
116 |         self.lmucell = LMUCell(input_dim, hidden_size, order, theta)
117 | 
118 |     def forward(self, xt):
119 |         outputs = []
120 |         # xt: (batch, seq_len, input_dim); initialize the states on the input's device
121 |         h0 = torch.zeros(xt.size(0), self.hidden_size, device=xt.device)
122 |         m0 = torch.zeros(xt.size(0), self.order, device=xt.device)
123 |         states = (h0, m0)
124 |         for i in range(xt.size(1)):
125 |             out, states = self.lmucell(xt[:,i,:], states)
126 |             outputs.append(out)
127 |         return torch.stack(outputs).permute(1,0,2), states  # (batch, seq_len, hidden_size)
128 | 
129 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 

2 | # DeepSITH: Efficient Learning via Decomposition of What and When Across Time Scales
3 | 

5 | 6 |

7 | Overview |
8 | Installation |
9 | DeepSITH |
10 | Examples

12 | 
13 | 
14 | ## Overview
15 | 
16 | ![DeepSITHLayout](/figures/model_config.png)
17 | 
18 | Here, we introduce DeepSITH, a network comprising biologically-inspired Scale-Invariant Temporal History (SITH) modules in series with dense connections between layers. SITH modules respond to their inputs with a geometrically-spaced set of time constants, enabling the DeepSITH network to learn problems along a continuum of time-scales.
19 | 
20 | ## Installation
21 | 
22 | The easiest way to install the DeepSITH module is with pip:
23 | 
24 |     pip install .
25 | 
26 | ### Requirements
27 | 
28 | DeepSITH requires at least PyTorch 1.8.1. It works with CUDA; please follow the official PyTorch installation instructions to set up PyTorch with CUDA support.
29 | 
30 | ## DeepSITH
31 | DeepSITH is a PyTorch module implementing the neurally inspired SITH representation of working memory for use in neural networks. The paper outlining the work detailed in this repository was published at NeurIPS 2021:
32 | 
33 | Jacques, B., Tiganj, Z., Howard, M., & Sederberg, P. (2021, December 6). DeepSITH: Efficient learning via decomposition of what and when across time scales. Advances in Neural Information Processing Systems.
34 | 
35 | Primarily, this module utilizes the Scale-Invariant Temporal History (SITH) representation. With SITH, we are able to compress the history of a time series in the same way that human working memory might. For more information, please refer to the paper.
36 | 
37 | ### DeepSITH use
38 | 
39 | The DeepSITH module in PyTorch initializes as a series of DeepSITH layers, parameterized by the argument `layer_params`, which is a list of dictionaries. Below is an example initializing a two-layer DeepSITH module, where the input time series has a single feature.
40 | 
41 |     from deepsith import DeepSITH
42 |     from torch import nn
43 |     import torch
44 |     # Tensor type. Use torch.cuda.FloatTensor to put all SITH math
45 |     # on the GPU.
46 |     ttype = torch.FloatTensor
47 | 
48 |     sith_params1 = {"in_features":1, 
49 |                     "tau_min":1, "tau_max":25.0, 'buff_max':40, 
50 |                     "k":84, 'dt':1, "ntau":15, 'g':.0, 
51 |                     "ttype":ttype, 'batch_norm':True, 
52 |                     "hidden_size":35, "act_func":nn.ReLU()}
53 |     sith_params2 = {"in_features":sith_params1['hidden_size'], 
54 |                     "tau_min":1, "tau_max":100.0, 'buff_max':175, 
55 |                     "k":40, 'dt':1, "ntau":15, 'g':.0, 
56 |                     "ttype":ttype, 'batch_norm':True, 
57 |                     "hidden_size":35, "act_func":nn.ReLU()}
58 |     lp = [sith_params1, sith_params2]
59 |     deepsith_layers = DeepSITH(layer_params=lp, dropout=0.2)
60 | 
61 | Here, the first layer has only 15 taustars, spanning `tau_min=1.0` to `tau_max=25.0`. The second layer is set up to go from `1.0` to `100.0`, which gives it four times the temporal range. We found this kind of logarithmic increase in temporal range across layers to work well for the experiments in this repository.
62 | 
63 | The DeepSITH module expects an input signal of size (batch_size, 1, sith_params1["in_features"], Time). (A quick shape sanity check appears at the end of this README.)
64 | 
65 | If you want to use **only** the SITH module, which is a part of any DeepSITH layer, you can initialize a SITH using the following parameters. Note that these parameters are also used in the dictionaries above.
66 | 
67 | #### SITH Parameters
68 | - tau_min: float
69 |     The center of the temporal receptive field for the first taustar produced.
70 | - tau_max: float
71 |     The center of the temporal receptive field for the last taustar produced.
72 | - buff_max: int
73 |     The maximum time in which the filters go into the past.
NOTE: In order to
74 |     achieve as few edge effects as possible, buff_max needs to be bigger than
75 |     tau_max, and dependent on k, such that the filters have enough time to reach
76 |     very close to 0.0. Plot the filters and you will see them go to 0.
77 | - k: int
78 |     Temporal specificity of the taustars: the higher k, the narrower each
79 |     temporal receptive field.
80 | - ntau: int
81 |     Number of taustars produced, spread out logarithmically.
82 | - dt: float
83 |     The time delta of the model; each filter will have int(buff_max/dt) time
84 |     points. Essentially this is the base rate at which information is presented to the model.
85 | - g: float
86 |     Typically between 0 and 1. This parameter is the scaling factor of the output
87 |     of the module. If set to 1, the output amplitude for a delta function will be
88 |     identical through time. If set to 0, the amplitude will decay into the past,
89 |     getting smaller and smaller. This value should be picked on an application to
90 |     application basis.
91 | - ttype: Torch Tensor
92 |     This is the type we set the internal mechanism of the model to before running.
93 |     In order to calculate the filters, we must use a DoubleTensor, but this is no
94 |     longer necessary after they are calculated. By default we set the filters to
95 |     be FloatTensors. NOTE: If you plan to use CUDA, you need to pass in a
96 |     cuda.FloatTensor as the ttype, as using .cuda() will not put these filters on
97 |     the GPU.
98 | 
99 | Initializing SITH will generate several attributes that depend heavily on the values of the parameters.
100 | 
101 | - c: float
102 |     `c = (tau_max/tau_min)**(1./(ntau-1))-1`. This describes how the spacing between
103 |     taustars evolves: each taustar is a factor of (1+c) larger than the previous one.
104 | - tau_star: DoubleTensor
105 |     `tau_star = tau_min*(1+c)**torch.arange(ntau)`. This is the array filled with the
106 |     centers of all the taustar receptive fields.
107 | - filters: ttype
108 |     The generated convolutional filters used to produce the SITH output. They are applied as a convolution
109 |     to the input time-series.
110 | 
111 | Importantly, this module should be socketed into a larger PyTorch model, where the final layers transform the output of the last **SITH->Dense** layer into the shape required for a particular task. We, for instance, use a DeepSITH_Classifier model for most of the tasks within this repository.
112 | 
113 |     class DeepSITH_Classifier(nn.Module):
114 |         def __init__(self, out_features, layer_params, dropout=.5):
115 |             super(DeepSITH_Classifier, self).__init__()
116 |             last_hidden = layer_params[-1]['hidden_size']
117 |             self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)
118 |             self.to_out = nn.Linear(last_hidden, out_features)
119 |         def forward(self, inp):
120 |             x = self.hs(inp)
121 |             x = self.to_out(x)
122 |             return x
123 | 
124 | ## Examples
125 | 
126 | The `experiments` folder contains the experiments included in the paper, all as Jupyter notebooks, along with everything needed to recreate the results therein. We have also included everything needed to recreate the figures from the paper, but with your own results, if you change the file names around.
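As a quick sanity check of the expected shapes, here is a minimal smoke test. This is an editorial sketch rather than code from the repository: it assumes freshly constructed `lp` dictionaries from the "DeepSITH use" section and the `DeepSITH_Classifier` above, with an arbitrary 10-class output and 200 time steps. Note that DeepSITH pops entries out of the `layer_params` dictionaries when it is constructed, so rebuild `lp` rather than reusing dictionaries already passed to another model.

    import torch

    # lp and DeepSITH_Classifier as defined earlier in this README
    model = DeepSITH_Classifier(out_features=10, layer_params=lp, dropout=.2)

    # DeepSITH expects input of shape (batch_size, 1, in_features, time)
    inp = torch.rand(4, 1, 1, 200)
    out = model(inp)
    print(out.shape)  # torch.Size([4, 200, 10]): one score vector per time step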
127 | 128 | 129 | -------------------------------------------------------------------------------- /experiments/psMNIST/visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-05-12T18:22:42.830086Z", 9 | "start_time": "2021-05-12T18:22:42.658411Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-05-12T18:22:43.466784Z", 23 | "start_time": "2021-05-12T18:22:42.878529Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pylab as plt\n", 29 | "from matplotlib import gridspec\n", 30 | "import csv\n", 31 | "import numpy as np\n", 32 | "import pandas as pd\n", 33 | "import os\n", 34 | "import torch\n", 35 | "from torchvision import transforms\n", 36 | "from torchvision import datasets\n", 37 | "import seaborn as sn\n", 38 | "sn.set_context('poster')" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "ExecuteTime": { 46 | "end_time": "2021-05-12T18:22:43.492461Z", 47 | "start_time": "2021-05-12T18:22:43.469951Z" 48 | } 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "dats = pd.read_csv(os.path.join('perf', 'smnist_deepsith_11.csv'))\n", 53 | "dats.columns = ['loss', 'test_perf', 'epoch', 'presnum', 'perf']\n", 54 | "maxpres = 60000\n", 55 | "dats['presnum_epoch'] = ((dats.presnum*64) + maxpres*dats.epoch)/maxpres\n", 56 | "test_dats = pd.read_csv(\"perf/smnist_deepsith_test_11.csv\")\n", 57 | "test_dats['epoch'] = np.arange(test_dats.shape[0]) + 1\n", 58 | "test_dats['test'] = test_dats['0']" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "ExecuteTime": { 66 | "end_time": "2021-05-12T18:22:43.509426Z", 67 | "start_time": "2021-05-12T18:22:43.493696Z" 68 | } 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "dato = pd.read_csv(os.path.join('perf', 'pmnist_deepsith_78.csv'))\n", 73 | "dato.columns = ['loss', 'epoch', 'presnum', 'perf']\n", 74 | "maxpres = 60000\n", 75 | "dato['presnum_epoch'] = ((dato.presnum*64) + maxpres*dato.epoch)/maxpres\n", 76 | "test_dato = pd.read_csv(\"perf/pmnist_deesith_test_78.csv\")\n", 77 | "test_dato.epoch = test_dato.epoch+1\n", 78 | "test_dato.head()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "ExecuteTime": { 86 | "end_time": "2021-05-12T18:23:02.475471Z", 87 | "start_time": "2021-05-12T18:23:02.391851Z" 88 | } 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "norm = transforms.Normalize((.1307,), (.3081,), )\n", 93 | "batch_size = 400\n", 94 | "transform = transforms.Compose([transforms.ToTensor(),\n", 95 | " transforms.Normalize((.1307,), (.3081,))\n", 96 | " ])\n", 97 | "ds1 = datasets.MNIST('../data', train=True, download=True, transform=transform)\n", 98 | "train_loader=torch.utils.data.DataLoader(ds1,batch_size=batch_size, \n", 99 | " num_workers=1, pin_memory=True, shuffle=True)\n" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "ExecuteTime": { 107 | "end_time": "2021-05-12T18:22:49.041702Z", 108 | "start_time": "2021-05-12T18:22:49.032510Z" 109 | } 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "# Same seed and supposed Permutation as the coRNN 
paper\n", 114 | "torch.manual_seed(12008)\n", 115 | "permute = torch.randperm(784)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "ExecuteTime": { 123 | "end_time": "2021-05-12T18:22:49.856527Z", 124 | "start_time": "2021-05-12T18:22:49.847567Z" 125 | } 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "dat = next(enumerate(train_loader))[1]\n", 130 | "dat[0].shape, dat[1].shape" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": { 137 | "ExecuteTime": { 138 | "end_time": "2021-05-12T18:22:54.876834Z", 139 | "start_time": "2021-05-12T18:22:54.867463Z" 140 | } 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "fig_dat = []\n", 145 | "for i in range(10):\n", 146 | " fig_dat.append(dat[0][dat[1] == i][:10])\n", 147 | "fig_dat = torch.cat(fig_dat, dim=0)\n", 148 | "fig_dat = fig_dat.view(100,-1)#[:, permute]\n" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "ExecuteTime": { 156 | "end_time": "2021-05-12T18:23:07.689070Z", 157 | "start_time": "2021-05-12T18:23:06.492357Z" 158 | } 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "linew = 4\n", 163 | "with sn.plotting_context(\"notebook\", font_scale=2.8):\n", 164 | " fig2 = plt.figure(figsize=(20,18), )\n", 165 | " spec2 = gridspec.GridSpec(nrows=4, ncols=2, wspace=0.05, figure=fig2)\n", 166 | "\n", 167 | "\n", 168 | " ax = fig2.add_subplot(spec2[-2, 1])\n", 169 | " #fig, ax= plt.subplots(2,2,sharex='col', figsize=(12,10), sharey='row', )\n", 170 | " sn.lineplot(data=dato, x=dato.presnum_epoch, y='perf', ax=ax, linewidth=linew,\n", 171 | " color='darkblue', )\n", 172 | " sn.lineplot(data=test_dato, x='epoch', y='test', ax=ax, linewidth=linew,\n", 173 | " )\n", 174 | " ax.grid(True)\n", 175 | " ax.legend([\"Training\", \"Test\", \n", 176 | " ],loc='lower right')\n", 177 | " ax.set_ylabel('Accuracy')\n", 178 | " ax.set_xlabel('')\n", 179 | " ax.yaxis.tick_right()\n", 180 | " ax.yaxis.set_label_position(\"right\")\n", 181 | " plt.setp(ax.get_xticklabels(), visible=False)\n", 182 | " ax.set_xlim(0,90)#)\n", 183 | " ax.set_ylim(.99, 1.0005)\n", 184 | "\n", 185 | " ax = fig2.add_subplot(spec2[-1, 1], sharex=ax)\n", 186 | " sn.lineplot(data=dato, x=dato.presnum_epoch, y='loss', ax=ax, linewidth=linew,\n", 187 | " color='darkblue', legend=False)\n", 188 | " ax.set_ylabel('Loss')\n", 189 | "\n", 190 | " ax.set_xlabel('Epoch')\n", 191 | " ax.yaxis.tick_right()\n", 192 | " \n", 193 | " ax.yaxis.set_label_position(\"right\")\n", 194 | " ax.set_ylim(0,.01)#)\n", 195 | " ax.grid(True)\n", 196 | "\n", 197 | "\n", 198 | " \n", 199 | " ax = fig2.add_subplot(spec2[-2, 0])\n", 200 | " sn.lineplot(data=dats, x=dats.presnum_epoch, y='perf', ax=ax, linewidth=linew,\n", 201 | " color='darkblue', )\n", 202 | " sn.lineplot(data=test_dats, x='epoch', y='test', ax=ax, linewidth=linew,\n", 203 | " )\n", 204 | " ax.legend([\"Training\", \"Test\", \n", 205 | " ],loc='lower right')\n", 206 | " plt.setp(ax.get_xticklabels(), visible=False)\n", 207 | " ax.grid(True)\n", 208 | " ax.set_ylabel('Accuracy')\n", 209 | " ax.set_xlabel('')\n", 210 | " ax.set_xlim(0,60)#)\n", 211 | " ax.set_ylim(.99, 1.0005)\n", 212 | "\n", 213 | " ax = fig2.add_subplot(spec2[-1, 0], sharex=ax)\n", 214 | " sn.lineplot(data=dats, x=dats.presnum_epoch, y='loss', ax=ax, linewidth=linew,\n", 215 | " color='darkblue', legend=False)\n", 216 | " ax.set_ylabel('Loss')\n", 217 | " ax.set_xlabel('Epoch')\n", 218 | " 
ax.set_ylim(0,.01)#)\n", 219 | " ax.grid(True)\n", 220 | "\n", 221 | "\n", 222 | " \n", 223 | " ax = fig2.add_subplot(spec2[:-2, 0])\n", 224 | " ax.imshow(fig_dat.detach().cpu().numpy(), aspect='auto')\n", 225 | " ax.tick_params(axis=u'both', which=u'both',length=0)\n", 226 | " ax.set_yticks(np.arange(10,100,10))\n", 227 | " ax.set_xticks([])\n", 228 | " plt.setp(ax.get_yticklabels(), visible=False)\n", 229 | " ax.grid(True)\n", 230 | " ax.set_title('sMNIST')\n", 231 | "\n", 232 | " ax = fig2.add_subplot(spec2[:-2, 1])\n", 233 | " ax.imshow(fig_dat[:, permute].detach().cpu().numpy(), aspect='auto')\n", 234 | " ax.tick_params(axis=u'both', which=u'both',length=0)\n", 235 | " ax.set_yticks(np.arange(10,100,10))\n", 236 | " ax.set_xticks([])\n", 237 | " ax.yaxis.tick_right()\n", 238 | " plt.setp(ax.get_yticklabels(), visible=False)\n", 239 | " ax.grid(True)\n", 240 | " ax.set_title('psMNIST')\n", 241 | "\n", 242 | " plt.savefig('MNIST.pdf',\n", 243 | " bbox='tight',\n", 244 | " edgecolor=fig2.get_edgecolor(),\n", 245 | " facecolor=fig2.get_facecolor(),\n", 246 | " dpi=150\n", 247 | " )\n", 248 | " plt.savefig('MNIST.svg',\n", 249 | " bbox='tight',\n", 250 | " edgecolor=fig2.get_edgecolor(),\n", 251 | " facecolor=fig2.get_facecolor(),\n", 252 | " dpi=150\n", 253 | " )\n", 254 | " \n", 255 | "\n", 256 | "\n", 257 | "\n" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [] 266 | } 267 | ], 268 | "metadata": { 269 | "kernelspec": { 270 | "display_name": "Python 3", 271 | "language": "python", 272 | "name": "python3" 273 | }, 274 | "language_info": { 275 | "codemirror_mode": { 276 | "name": "ipython", 277 | "version": 3 278 | }, 279 | "file_extension": ".py", 280 | "mimetype": "text/x-python", 281 | "name": "python", 282 | "nbconvert_exporter": "python", 283 | "pygments_lexer": "ipython3", 284 | "version": "3.6.10" 285 | }, 286 | "toc": { 287 | "nav_menu": {}, 288 | "number_sections": true, 289 | "sideBar": true, 290 | "skip_h1_title": false, 291 | "title_cell": "Table of Contents", 292 | "title_sidebar": "Contents", 293 | "toc_cell": false, 294 | "toc_position": {}, 295 | "toc_section_display": true, 296 | "toc_window_display": false 297 | } 298 | }, 299 | "nbformat": 4, 300 | "nbformat_minor": 4 301 | } 302 | -------------------------------------------------------------------------------- /experiments/MackeyGlass/MackeyGlass-LSTM-increasing_tau.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-01-31T16:16:49.475627Z", 9 | "start_time": "2021-01-31T16:16:49.297073Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-01-31T16:16:50.214114Z", 23 | "start_time": "2021-01-31T16:16:49.515165Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import torch\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import torchvision\n", 31 | "import nengolib\n", 32 | "import numpy as np\n", 33 | "import torch.nn as nn\n", 34 | "import torch.nn.functional as F\n", 35 | "from tqdm import tqdm_notebook\n", 36 | "import PIL\n", 37 | "from torch.nn.utils import weight_norm\n", 38 | "\n", 39 | "from os.path import join\n", 40 | "import scipy.special\n", 41 | "import pandas 
as pd\n", 42 | "import seaborn as sn\n", 43 | "import scipy\n", 44 | "from scipy.spatial.distance import euclidean\n", 45 | "from scipy.interpolate import interp1d\n", 46 | "from tqdm.notebook import tqdm\n", 47 | "import random\n", 48 | "from csv import DictWriter\n", 49 | "# if gpu is to be used\n", 50 | "use_cuda = torch.cuda.is_available()\n", 51 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 52 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n", 53 | "\n", 54 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n", 55 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 56 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n", 57 | "ttype = FloatTensor\n", 58 | "\n", 59 | "import seaborn as sns\n", 60 | "print(use_cuda)\n", 61 | "import pickle\n" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "# Load Stimuli" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "ExecuteTime": { 76 | "end_time": "2021-01-31T16:16:52.360839Z", 77 | "start_time": "2021-01-31T16:16:52.350178Z" 78 | } 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "import collections\n", 83 | "\n", 84 | "def mackey_glass(sample_len=1000, tau=17, delta_t=10, seed=None, n_samples=1):\n", 85 | " # Adapted from https://github.com/mila-iqia/summerschool2015/blob/master/rnn_tutorial/synthetic.py\n", 86 | " '''\n", 87 | " mackey_glass(sample_len=1000, tau=17, seed = None, n_samples = 1) -> input\n", 88 | " Generate the Mackey Glass time-series. Parameters are:\n", 89 | " - sample_len: length of the time-series in timesteps. Default is 1000.\n", 90 | " - tau: delay of the MG - system. Commonly used values are tau=17 (mild \n", 91 | " chaos) and tau=30 (moderate chaos). 
Default is 17.\n", 92 | " - seed: to seed the random generator, can be used to generate the same\n", 93 | " timeseries at each invocation.\n", 94 | " - n_samples : number of samples to generate\n", 95 | " '''\n", 96 | " history_len = tau * delta_t \n", 97 | " # Initial conditions for the history of the system\n", 98 | " timeseries = 1.2\n", 99 | " \n", 100 | " if seed is not None:\n", 101 | " np.random.seed(seed)\n", 102 | "\n", 103 | " samples = []\n", 104 | "\n", 105 | " for _ in range(n_samples):\n", 106 | " history = collections.deque(1.2 * np.ones(history_len) + 0.2 * \\\n", 107 | " (np.random.rand(history_len) - 0.5))\n", 108 | " # Preallocate the array for the time-series\n", 109 | " inp = np.zeros((sample_len,1))\n", 110 | " \n", 111 | " for timestep in range(sample_len):\n", 112 | " for _ in range(delta_t):\n", 113 | " xtau = history.popleft()\n", 114 | " history.append(timeseries)\n", 115 | " timeseries = history[-1] + (0.2 * xtau / (1.0 + xtau ** 10) - \\\n", 116 | " 0.1 * history[-1]) / delta_t\n", 117 | " inp[timestep] = timeseries\n", 118 | " \n", 119 | " # Squash timeseries through tanh\n", 120 | " inp = np.tanh(inp - 1)\n", 121 | " samples.append(inp)\n", 122 | " return samples\n", 123 | "\n", 124 | "\n", 125 | "def generate_data(n_batches, length, split=0.5, seed=0,\n", 126 | " predict_length=15, tau=17, washout=100, delta_t=1,\n", 127 | " center=True):\n", 128 | " X = np.asarray(mackey_glass(\n", 129 | " sample_len=length+predict_length+washout, tau=tau,\n", 130 | " seed=seed, n_samples=n_batches))\n", 131 | " X = X[:, washout:, :]\n", 132 | " cutoff = int(split*n_batches)\n", 133 | " if center:\n", 134 | " X -= np.mean(X) # global mean over all batches, approx -0.066\n", 135 | " Y = X[:, predict_length:, :]\n", 136 | " X = X[:, :-predict_length, :]\n", 137 | " assert X.shape == Y.shape\n", 138 | " return ((X[:cutoff], Y[:cutoff]),\n", 139 | " (X[cutoff:], Y[cutoff:]))" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": { 146 | "ExecuteTime": { 147 | "end_time": "2020-12-28T14:39:38.387303Z", 148 | "start_time": "2020-12-28T14:39:36.507391Z" 149 | } 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "(train_X, train_Y), (test_X, test_Y) = generate_data(2, 5000)\n", 154 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n", 155 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 156 | "\n", 157 | "dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n", 158 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n", 159 | "\n", 160 | "print(train_X.shape, train_Y.shape, test_X.shape)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## Setup for Model" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": { 174 | "ExecuteTime": { 175 | "end_time": "2021-01-31T16:17:16.816497Z", 176 | "start_time": "2021-01-31T16:17:16.806348Z" 177 | } 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "\n", 182 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n", 183 | " loss_buffer_size=800, batch_size=4, device='cuda',\n", 184 | " prog_bar=None, tau=17, pred_len=15):\n", 185 | " \n", 186 | " assert(loss_buffer_size%batch_size==0)\n", 187 | " \n", 188 | " losses = []\n", 189 | " last_test_perf = 0\n", 190 | " 
best_test_perf = -1\n", 191 | " \n", 192 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 193 | " model.train()\n", 194 | " data = data.to(device).transpose(1,0)\n", 195 | " target = target.to(device)\n", 196 | " optimizer.zero_grad()\n", 197 | " out = model(data)\n", 198 | " loss = loss_func(out,\n", 199 | " target)\n", 200 | "\n", 201 | " loss.backward()\n", 202 | " optimizer.step()\n", 203 | "\n", 204 | " losses.append(loss.detach().cpu().numpy())\n", 205 | " losses = losses[int(-loss_buffer_size/batch_size):]\n", 206 | " \n", 207 | " if ((batch_idx*batch_size)%loss_buffer_size == 0):\n", 208 | " loss_track = {}\n", 209 | " last_test_perf = test_model(model, 'cuda', test_loader, \n", 210 | " )\n", 211 | " loss_track['avg_loss'] = np.mean(losses)\n", 212 | " loss_track['last_test'] = last_test_perf\n", 213 | " loss_track['epoch'] = epoch\n", 214 | " loss_track['batch_idx'] = batch_idx\n", 215 | " loss_track['tau'] = tau\n", 216 | " loss_track['pred_len'] = pred_len\n", 217 | "\n", 218 | " with open(perf_file, 'a+') as fp:\n", 219 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 220 | " if fp.tell() == 0:\n", 221 | " csv_writer.writeheader()\n", 222 | " csv_writer.writerow(loss_track)\n", 223 | " fp.flush()\n", 224 | " if best_test_perf < last_test_perf:\n", 225 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 226 | " best_test_perf = last_test_perf\n", 227 | " if not (prog_bar is None):\n", 228 | " # Update progress_bar\n", 229 | " s = \"{}:{} Loss: {:.4f},valid: {:.4f}\"\n", 230 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n", 231 | " last_test_perf] \n", 232 | " s = s.format(*format_list)\n", 233 | " prog_bar.set_description(s)\n", 234 | " \n", 235 | "def test_model(model, device, test_loader):\n", 236 | " # Test the Model\n", 237 | " nrmsd = []\n", 238 | " with torch.no_grad():\n", 239 | " for x, y in test_loader:\n", 240 | " data = x.to(device).transpose(1,0)\n", 241 | " target = y.to(device)\n", 242 | " optimizer.zero_grad()\n", 243 | " out = model(data)\n", 244 | " nrmsd.append(nengolib.signal.nrmse(out.detach().cpu().numpy().flatten(), target=target.detach().cpu().numpy().flatten()))\n", 245 | " perf = np.array(nrmsd).mean()\n", 246 | " return perf" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": { 253 | "ExecuteTime": { 254 | "end_time": "2021-01-31T16:18:41.596917Z", 255 | "start_time": "2021-01-31T16:18:41.593831Z" 256 | } 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "class LSTM_Predictor(nn.Module):\n", 261 | " def __init__(self, out_features, lstm_params):\n", 262 | " super(LSTM_Predictor, self).__init__()\n", 263 | " self.lstm = nn.LSTM(**lstm_params)\n", 264 | " self.to_out = nn.Linear(lstm_params['hidden_size'], \n", 265 | " out_features)\n", 266 | " def forward(self, inp):\n", 267 | " x = self.lstm(inp)[0].transpose(1,0)\n", 268 | " x = torch.tanh(self.to_out(x))\n", 269 | " return x" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": { 276 | "ExecuteTime": { 277 | "end_time": "2021-01-31T18:18:18.908819Z", 278 | "start_time": "2021-01-31T16:18:42.806246Z" 279 | }, 280 | "scrolled": true 281 | }, 282 | "outputs": [], 283 | "source": [ 284 | "start_tau = 17\n", 285 | "start_pd = 15\n", 286 | "diffs = [1, 2, 3, 4, 5]\n", 287 | "for diff in diffs:\n", 288 | " (train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000, tau=start_tau*diff, \n", 289 | " predict_length=start_pd*diff)\n", 290 | " 
dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n", 291 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 292 | "\n", 293 | " dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n", 294 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n", 295 | "\n", 296 | " print(train_X.shape, train_Y.shape, test_X.shape)\n", 297 | "\n", 298 | " lstm_params = dict(input_size=1,\n", 299 | " hidden_size=25, \n", 300 | " num_layers=4)\n", 301 | " model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n", 302 | "\n", 303 | "\n", 304 | " optimizer = torch.optim.Adam(model.parameters())\n", 305 | " loss_func = nn.MSELoss()\n", 306 | " \n", 307 | " epochs = 1000\n", 308 | " batch_size = 32\n", 309 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 310 | " last_perf = 1000\n", 311 | " for e in progress_bar:\n", 312 | " train(model, ttype, dataset, dataset_valid, \n", 313 | " optimizer, loss_func, batch_size=batch_size, loss_buffer_size=64,\n", 314 | " epoch=e, perf_file=join('perf','mackeyglass_lstm_ratio_1.csv'),\n", 315 | " prog_bar=progress_bar, tau=start_tau*diff, pred_len=start_pd*diff)" 316 | ] 317 | } 318 | ], 319 | "metadata": { 320 | "kernelspec": { 321 | "display_name": "Python 3", 322 | "language": "python", 323 | "name": "python3" 324 | }, 325 | "language_info": { 326 | "codemirror_mode": { 327 | "name": "ipython", 328 | "version": 3 329 | }, 330 | "file_extension": ".py", 331 | "mimetype": "text/x-python", 332 | "name": "python", 333 | "nbconvert_exporter": "python", 334 | "pygments_lexer": "ipython3", 335 | "version": "3.6.10" 336 | }, 337 | "toc": { 338 | "nav_menu": {}, 339 | "number_sections": true, 340 | "sideBar": true, 341 | "skip_h1_title": false, 342 | "title_cell": "Table of Contents", 343 | "title_sidebar": "Contents", 344 | "toc_cell": false, 345 | "toc_position": {}, 346 | "toc_section_display": true, 347 | "toc_window_display": false 348 | } 349 | }, 350 | "nbformat": 4, 351 | "nbformat_minor": 4 352 | } 353 | -------------------------------------------------------------------------------- /experiments/psMNIST/sMNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-03-15T13:46:07.365933Z", 9 | "start_time": "2021-03-15T13:46:07.144247Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-03-15T13:46:08.069258Z", 23 | "start_time": "2021-03-15T13:46:07.366877Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pyplot as plt\n", 29 | "\n", 30 | "import torch\n", 31 | "from torch.optim.lr_scheduler import StepLR\n", 32 | "from sklearn.metrics import confusion_matrix\n", 33 | "\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "import torch.nn as nn\n", 37 | "\n", 38 | "from tqdm import tqdm_notebook\n", 39 | "\n", 40 | "from torchvision import transforms\n", 41 | "from torchvision import datasets\n", 42 | "from os.path import join\n", 43 | "\n", 44 | "from deepsith import DeepSITH\n", 45 | "\n", 46 | "from tqdm.notebook import tqdm\n", 47 | "\n", 48 | "import random\n", 49 | "\n", 50 | "from csv 
import DictWriter\n", 51 | "# if gpu is to be used\n", 52 | "use_cuda = torch.cuda.is_available()\n", 53 | "\n", 54 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 55 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 56 | "\n", 57 | "ttype =FloatTensor\n", 58 | "\n", 59 | "import seaborn as sn\n", 60 | "print(use_cuda)\n", 61 | "import pickle\n", 62 | "\n", 63 | "sn.set_context(\"poster\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "ExecuteTime": { 71 | "end_time": "2021-03-15T13:46:08.073469Z", 72 | "start_time": "2021-03-15T13:46:08.070605Z" 73 | } 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "class DeepSITH_Classifier(nn.Module):\n", 78 | " def __init__(self, out_features, layer_params, dropout=.1):\n", 79 | " super(DeepSITH_Classifier, self).__init__()\n", 80 | " last_hidden = layer_params[-1]['hidden_size']\n", 81 | " self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)\n", 82 | " self.to_out = nn.Linear(last_hidden, out_features)\n", 83 | " def forward(self, inp):\n", 84 | " x = self.hs(inp)\n", 85 | " x = self.to_out(x)\n", 86 | " return x" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": { 93 | "ExecuteTime": { 94 | "end_time": "2021-03-15T13:46:08.083443Z", 95 | "start_time": "2021-03-15T13:46:08.074275Z" 96 | } 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "import scipy.optimize as opt\n", 101 | "from deepsith import iSITH\n", 102 | "def min_fun(x, *args):\n", 103 | " ntau = int(x[0])\n", 104 | " k = int(x[1])\n", 105 | " if k < 4 or k>125:\n", 106 | " return np.inf\n", 107 | " tau_min = args[0]\n", 108 | " tau_max = args[1] \n", 109 | " ev = iSITH(tau_min=tau_min, tau_max=tau_max, buff_max=tau_max*5, k=k, ntau=ntau, dt=1, g=1.0) \n", 110 | " std_0 = ev.filters[:, 0, 0, :].detach().cpu().T.numpy()[::-1].sum(1)[int(tau_min):int(tau_max)].std()\n", 111 | " std_1 = ev.filters[:, 0, 0, :].detach().cpu().T.numpy()[::-1, ::2].sum(1)[int(tau_min):int(tau_max)].std() \n", 112 | " to_min = std_0/std_1\n", 113 | " return to_min" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "# Load Stimuli" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": { 127 | "ExecuteTime": { 128 | "end_time": "2021-03-15T13:46:36.407379Z", 129 | "start_time": "2021-03-15T13:46:36.405404Z" 130 | } 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "norm = transforms.Normalize((.1307,), (.3081,), )" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": { 141 | "ExecuteTime": { 142 | "end_time": "2021-03-15T13:46:37.836490Z", 143 | "start_time": "2021-03-15T13:46:37.822290Z" 144 | } 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "batch_size = 64\n", 149 | "transform = transforms.Compose([transforms.ToTensor(),\n", 150 | " transforms.Normalize((.1307,), (.3081,))\n", 151 | " ])\n", 152 | "ds1 = datasets.MNIST('../data', train=True, download=True, transform=transform)\n", 153 | "ds2 = datasets.MNIST('../data', train=False, download=True, transform=transform)\n", 154 | "train_loader=torch.utils.data.DataLoader(ds1,batch_size=batch_size, \n", 155 | " num_workers=1, pin_memory=True, shuffle=True)\n", 156 | "test_loader=torch.utils.data.DataLoader(ds2, batch_size=batch_size, \n", 157 | " num_workers=1, pin_memory=True, shuffle=True)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 
162 | "execution_count": null, 163 | "metadata": { 164 | "ExecuteTime": { 165 | "end_time": "2021-03-15T13:44:18.485182Z", 166 | "start_time": "2021-03-15T13:44:14.672022Z" 167 | } 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "test = next(iter(test_loader))[0]\n", 172 | "\n", 173 | "plt.imshow(test[0].reshape(-1).reshape(28,28))\n", 174 | "\n", 175 | "plt.colorbar()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "# Define test and train" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": { 189 | "ExecuteTime": { 190 | "end_time": "2021-03-15T13:46:11.244757Z", 191 | "start_time": "2021-03-15T13:46:11.235002Z" 192 | } 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "\n", 197 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n", 198 | " loss_buffer_size=800, batch_size=4, device='cuda',\n", 199 | " prog_bar=None, last_test_perf=0):\n", 200 | " \n", 201 | " assert(loss_buffer_size%batch_size==0)\n", 202 | "\n", 203 | " \n", 204 | " losses = []\n", 205 | " perfs = []\n", 206 | " best_test_perf = -1\n", 207 | " \n", 208 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 209 | " model.train()\n", 210 | " data = data.to(device).view(data.shape[0],1,1,-1)\n", 211 | " target = target.to(device)\n", 212 | " optimizer.zero_grad()\n", 213 | " out = model(data)\n", 214 | " loss = loss_func(out[:, -1, :],\n", 215 | " target)\n", 216 | "\n", 217 | " loss.backward()\n", 218 | " optimizer.step()\n", 219 | "\n", 220 | " perfs.append((torch.argmax(out[:, -1, :], dim=-1) == \n", 221 | " target).sum().item())\n", 222 | " perfs = perfs[int(-loss_buffer_size/batch_size):]\n", 223 | " losses.append(loss.detach().cpu().numpy())\n", 224 | " losses = losses[int(-loss_buffer_size/batch_size):]\n", 225 | " if not (prog_bar is None):\n", 226 | " # Update progress_bar\n", 227 | " s = \"{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}\"\n", 228 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n", 229 | " np.sum(perfs)/((len(perfs))*batch_size), last_test_perf] \n", 230 | " s = s.format(*format_list)\n", 231 | " prog_bar.set_description(s)\n", 232 | " \n", 233 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n", 234 | " loss_track = {}\n", 235 | " #last_test_perf = test(model, 'cuda', test_loader, \n", 236 | " # batch_size=batch_size, \n", 237 | " # )\n", 238 | " loss_track['avg_loss'] = np.mean(losses)\n", 239 | " loss_track['last_test'] = last_test_perf\n", 240 | " loss_track['epoch'] = epoch\n", 241 | " loss_track['batch_idx'] = batch_idx\n", 242 | " loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)\n", 243 | " with open(perf_file, 'a+') as fp:\n", 244 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 245 | " if fp.tell() == 0:\n", 246 | " csv_writer.writeheader()\n", 247 | " csv_writer.writerow(loss_track)\n", 248 | " fp.flush()\n", 249 | " #if best_test_perf < last_test_perf:\n", 250 | " # torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 251 | " # best_test_perf = last_test_perf\n", 252 | "\n", 253 | " \n", 254 | "def test(model, device, test_loader, batch_size=4):\n", 255 | " model.eval()\n", 256 | " correct = 0\n", 257 | " count = 0\n", 258 | " with torch.no_grad():\n", 259 | " for data, target in test_loader:\n", 260 | " data = data.to(device).view(data.shape[0],1,1,-1)\n", 261 | " target = target.to(device)\n", 262 | " \n", 263 | " out = model(data)\n", 
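" # out has shape (batch, seq_len, n_classes); as in train(), only the\n", " # final timestep's logits are scored for sMNIST classification.\n",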
264 | " pred = out[:, -1].argmax(dim=-1, keepdim=True)\n", 265 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 266 | " count += 1\n", 267 | " return correct / len(test_loader.dataset)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "# Setup the model" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "ExecuteTime": { 282 | "end_time": "2021-03-15T13:46:18.898666Z", 283 | "start_time": "2021-03-15T13:46:15.331245Z" 284 | }, 285 | "scrolled": true 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "g = 0.0\n", 290 | "sith_params1 = {\"in_features\":1, \n", 291 | " \"tau_min\":1, \"tau_max\":30.0, \"buff_max\":50,\n", 292 | " \"k\":125, 'dt':1,\n", 293 | " \"ntau\":20, 'g':g, \n", 294 | " \"ttype\":ttype, \"batch_norm\":True,\n", 295 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n", 296 | " }\n", 297 | "sith_params2 = {\"in_features\":sith_params1['hidden_size'], \n", 298 | " \"tau_min\":1, \"tau_max\":150.0, \"buff_max\":250,\n", 299 | " \"k\":61, 'dt':1,\n", 300 | " \"ntau\":20, 'g':g, \n", 301 | " \"ttype\":ttype, \"batch_norm\":True,\n", 302 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n", 303 | " }\n", 304 | "sith_params3 = {\"in_features\":sith_params2['hidden_size'], \n", 305 | " \"tau_min\":1, \"tau_max\":750.0, \"buff_max\":1500,\n", 306 | " \"k\":35, 'dt':1,\n", 307 | " \"ntau\":20, 'g':g, \n", 308 | " \"ttype\":ttype, \"batch_norm\":True,\n", 309 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n", 310 | " }\n", 311 | "\n", 312 | "layer_params = [sith_params1, sith_params2, sith_params3]\n", 313 | "\n", 314 | "\n", 315 | "\n", 316 | "model = DeepSITH_Classifier(10,\n", 317 | " layer_params=layer_params, \n", 318 | " dropout=0.2).cuda()\n", 319 | "\n", 320 | "tot_weights = 0\n", 321 | "for p in model.parameters():\n", 322 | " tot_weights += p.numel()\n", 323 | "print(\"Total Weights:\", tot_weights)\n", 324 | "print(model)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": { 331 | "ExecuteTime": { 332 | "end_time": "2021-03-16T07:38:12.591705Z", 333 | "start_time": "2021-03-15T13:46:45.607305Z" 334 | }, 335 | "scrolled": true 336 | }, 337 | "outputs": [], 338 | "source": [ 339 | "epochs = 40\n", 340 | "loss_func = nn.CrossEntropyLoss()\n", 341 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)\n", 342 | "sched = StepLR(optimizer, step_size=int(epochs / 4), gamma=0.1)\n", 343 | "#sched = None\n", 344 | "perf_file = join('perf','smnist_deepsith_4layer_01.csv')\n", 345 | "test_perf = []\n", 346 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 347 | "best_test_perf = 0\n", 348 | "t_p = .0000\n", 349 | "for e in progress_bar:\n", 350 | " train(model, ttype, train_loader, test_loader, optimizer, loss_func, batch_size=batch_size,\n", 351 | " epoch=e, perf_file=perf_file,loss_buffer_size=64*32, \n", 352 | " prog_bar=progress_bar, last_test_perf=t_p)\n", 353 | " \n", 354 | " t_p = test(model, 'cuda', test_loader, \n", 355 | " batch_size=batch_size, \n", 356 | " )\n", 357 | " if t_p > best_test_perf:\n", 358 | " best_test_perf = t_p\n", 359 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 360 | " \n", 361 | " test_perf.append({\"epoch\":e,\n", 362 | " 'test':t_p})\n", 363 | " \n", 364 | " sched.step()" 365 | ] 366 | } 367 | ], 368 | "metadata": { 369 | "kernelspec": { 370 | "display_name": "Python 3", 371 | "language": "python", 372 | "name": 
"python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.6.10" 385 | }, 386 | "toc": { 387 | "nav_menu": { 388 | "height": "163px", 389 | "width": "250px" 390 | }, 391 | "number_sections": true, 392 | "sideBar": true, 393 | "skip_h1_title": false, 394 | "title_cell": "Table of Contents", 395 | "title_sidebar": "Contents", 396 | "toc_cell": false, 397 | "toc_position": {}, 398 | "toc_section_display": true, 399 | "toc_window_display": false 400 | } 401 | }, 402 | "nbformat": 4, 403 | "nbformat_minor": 4 404 | } 405 | -------------------------------------------------------------------------------- /experiments/MackeyGlass/MackeyGlass-LMU-increasing_tau.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-01-31T16:15:36.355232Z", 9 | "start_time": "2021-01-31T16:15:36.174356Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-01-31T16:15:37.469291Z", 23 | "start_time": "2021-01-31T16:15:36.530387Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import torch\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import torchvision\n", 31 | "import numpy as np\n", 32 | "import torch.nn as nn\n", 33 | "import torch.nn.functional as F\n", 34 | "from tqdm import tqdm_notebook\n", 35 | "import PIL\n", 36 | "from torch.nn.utils import weight_norm\n", 37 | "from torch.autograd import Variable\n", 38 | "import nengolib\n", 39 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 40 | "\n", 41 | "from os.path import join\n", 42 | "import scipy.special\n", 43 | "import pandas as pd\n", 44 | "import seaborn as sn\n", 45 | "import scipy\n", 46 | "from scipy.spatial.distance import euclidean\n", 47 | "from scipy.interpolate import interp1d\n", 48 | "from tqdm.notebook import tqdm\n", 49 | "import random\n", 50 | "from csv import DictWriter\n", 51 | "from lmu import LegendreMemoryUnit\n", 52 | "\n", 53 | "# if gpu is to be used\n", 54 | "use_cuda = torch.cuda.is_available()\n", 55 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 56 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n", 57 | "\n", 58 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n", 59 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 60 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n", 61 | "ttype = FloatTensor\n", 62 | "\n", 63 | "import seaborn as sns\n", 64 | "print(use_cuda)\n", 65 | "import pickle\n" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "ExecuteTime": { 73 | "end_time": "2021-01-31T16:15:41.303102Z", 74 | "start_time": "2021-01-31T16:15:41.299727Z" 75 | } 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "class LMUModel(nn.Module):\n", 80 | " def __init__(self, n_out, layer_params):\n", 81 | " super(LMUModel, self).__init__()\n", 82 | " self.layers = nn.ModuleList([LegendreMemoryUnit(**layer_params[i])\n", 83 | " for i in 
range(len(layer_params))])\n", 84 | " self.dense = nn.Linear(layer_params[-1]['hidden_size'], n_out)\n", 85 | "\n", 86 | " \n", 87 | " def forward(self, x):\n", 88 | " for l in self.layers:\n", 89 | " x, _ = l(x) \n", 90 | " x = self.dense(x)\n", 91 | " return x" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "ExecuteTime": { 99 | "end_time": "2021-01-31T16:15:43.334751Z", 100 | "start_time": "2021-01-31T16:15:41.545777Z" 101 | } 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "\n", 106 | "\n", 107 | "lmu_params = [dict(input_dim=1, hidden_size=49, order=4, theta=4),\n", 108 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n", 109 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n", 110 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n", 111 | " ]\n", 112 | "model = LMUModel(1, lmu_params).cuda()\n", 113 | "\n", 114 | "tot_weights = 0\n", 115 | "for p in model.parameters():\n", 116 | " tot_weights += p.numel()\n", 117 | "print(\"Total Weights:\", tot_weights)\n", 118 | "print(model)\n" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "# Load Stimuli" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "ExecuteTime": { 133 | "end_time": "2021-01-31T16:15:45.466420Z", 134 | "start_time": "2021-01-31T16:15:45.459513Z" 135 | } 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "import collections\n", 140 | "\n", 141 | "def mackey_glass(sample_len=1000, tau=17, delta_t=10, seed=None, n_samples=1):\n", 142 | " # Adapted from https://github.com/mila-iqia/summerschool2015/blob/master/rnn_tutorial/synthetic.py\n", 143 | " '''\n", 144 | " mackey_glass(sample_len=1000, tau=17, seed = None, n_samples = 1) -> input\n", 145 | " Generate the Mackey Glass time-series. Parameters are:\n", 146 | " - sample_len: length of the time-series in timesteps. Default is 1000.\n", 147 | " - tau: delay of the MG - system. Commonly used values are tau=17 (mild \n", 148 | " chaos) and tau=30 (moderate chaos). 
Default is 17.\n", 149 | " - seed: to seed the random generator, can be used to generate the same\n", 150 | " timeseries at each invocation.\n", 151 | " - n_samples : number of samples to generate\n", 152 | " '''\n", 153 | " history_len = tau * delta_t \n", 154 | " # Initial conditions for the history of the system\n", 155 | " timeseries = 1.2\n", 156 | " \n", 157 | " if seed is not None:\n", 158 | " np.random.seed(seed)\n", 159 | "\n", 160 | " samples = []\n", 161 | "\n", 162 | " for _ in range(n_samples):\n", 163 | " history = collections.deque(1.2 * np.ones(history_len) + 0.2 * \\\n", 164 | " (np.random.rand(history_len) - 0.5))\n", 165 | " # Preallocate the array for the time-series\n", 166 | " inp = np.zeros((sample_len,1))\n", 167 | " \n", 168 | " for timestep in range(sample_len):\n", 169 | " for _ in range(delta_t):\n", 170 | " xtau = history.popleft()\n", 171 | " history.append(timeseries)\n", 172 | " timeseries = history[-1] + (0.2 * xtau / (1.0 + xtau ** 10) - \\\n", 173 | " 0.1 * history[-1]) / delta_t\n", 174 | " inp[timestep] = timeseries\n", 175 | " \n", 176 | " # Squash timeseries through tanh\n", 177 | " inp = np.tanh(inp - 1)\n", 178 | " samples.append(inp)\n", 179 | " return samples\n", 180 | "\n", 181 | "\n", 182 | "def generate_data(n_batches, length, split=0.5, seed=0,\n", 183 | " predict_length=15, tau=17, washout=100, delta_t=1,\n", 184 | " center=True):\n", 185 | " X = np.asarray(mackey_glass(\n", 186 | " sample_len=length+predict_length+washout, tau=tau,\n", 187 | " seed=seed, n_samples=n_batches))\n", 188 | " X = X[:, washout:, :]\n", 189 | " cutoff = int(split*n_batches)\n", 190 | " if center:\n", 191 | " X -= np.mean(X) # global mean over all batches, approx -0.066\n", 192 | " Y = X[:, predict_length:, :]\n", 193 | " X = X[:, :-predict_length, :]\n", 194 | " assert X.shape == Y.shape\n", 195 | " return ((X[:cutoff], Y[:cutoff]),\n", 196 | " (X[cutoff:], Y[cutoff:]))" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": { 203 | "ExecuteTime": { 204 | "end_time": "2021-01-31T16:15:48.056808Z", 205 | "start_time": "2021-01-31T16:15:45.670627Z" 206 | } 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "(train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000)\n", 211 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n", 212 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 213 | "\n", 214 | "dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n", 215 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n", 216 | "\n", 217 | "print(train_X.shape, train_Y.shape, test_X.shape)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "## Setup for Model" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": { 231 | "ExecuteTime": { 232 | "end_time": "2021-01-31T16:16:09.137416Z", 233 | "start_time": "2021-01-31T16:16:09.127621Z" 234 | } 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "\n", 239 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n", 240 | " loss_buffer_size=800, batch_size=4, device='cuda',\n", 241 | " prog_bar=None, tau=17, pred_len=15):\n", 242 | " \n", 243 | " assert(loss_buffer_size%batch_size==0)\n", 244 | " \n", 245 | " losses = []\n", 246 | " last_test_perf = 0\n", 247 | " 
best_test_perf = 1000\n", 248 | " \n", 249 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 250 | " model.train()\n", 251 | " data = data.to(device)\n", 252 | " target = target.to(device)\n", 253 | " optimizer.zero_grad()\n", 254 | " out = model(data)\n", 255 | " loss = loss_func(out,\n", 256 | " target)\n", 257 | "\n", 258 | " loss.backward()\n", 259 | " optimizer.step()\n", 260 | "\n", 261 | " losses.append(loss.detach().cpu().numpy())\n", 262 | " losses = losses[int(-loss_buffer_size/batch_size):]\n", 263 | " \n", 264 | " if ((batch_idx*batch_size)%loss_buffer_size == 0):\n", 265 | " loss_track = {}\n", 266 | " last_test_perf = test_model(model, 'cuda', test_loader, \n", 267 | " )\n", 268 | " loss_track['avg_loss'] = np.mean(losses)\n", 269 | " loss_track['last_test'] = last_test_perf\n", 270 | " loss_track['epoch'] = epoch\n", 271 | " loss_track['batch_idx'] = batch_idx\n", 272 | " loss_track['tau'] = tau\n", 273 | " loss_track['pred_len'] = pred_len\n", 274 | " with open(perf_file, 'a+') as fp:\n", 275 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 276 | " if fp.tell() == 0:\n", 277 | " csv_writer.writeheader()\n", 278 | " csv_writer.writerow(loss_track)\n", 279 | " fp.flush()\n", 280 | " if best_test_perf > last_test_perf:\n", 281 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 282 | " best_test_perf = last_test_perf\n", 283 | " if not (prog_bar is None):\n", 284 | " # Update progress_bar\n", 285 | " s = \"{}:{} Loss: {:.4f},valid: {:.4f}\"\n", 286 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n", 287 | " last_test_perf] \n", 288 | " s = s.format(*format_list)\n", 289 | " prog_bar.set_description(s)\n", 290 | " \n", 291 | "def test_model(model, device, test_loader):\n", 292 | " # Test the Model\n", 293 | " nrmsd = []\n", 294 | " with torch.no_grad():\n", 295 | " for x, y in test_loader:\n", 296 | " data = x.to(device)\n", 297 | " target = y.to(device)\n", 298 | " optimizer.zero_grad()\n", 299 | " out = model(data)\n", 300 | " nrmsd.append(nengolib.signal.nrmse(out.detach().cpu().numpy().flatten(), \n", 301 | " target=target.detach().cpu().numpy().flatten()))\n", 302 | " perf = np.array(nrmsd).mean()\n", 303 | " return perf" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": { 310 | "ExecuteTime": { 311 | "end_time": "2021-02-03T05:20:46.823977Z", 312 | "start_time": "2021-02-02T19:40:22.179973Z" 313 | } 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "start_tau = 17\n", 318 | "start_pd = 15\n", 319 | "diffs = [1,2,3,4,5]\n", 320 | "for diff in diffs:\n", 321 | " (train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000, tau=start_tau*diff, \n", 322 | " predict_length=start_pd*diff)\n", 323 | "\n", 324 | " dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n", 325 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 326 | "\n", 327 | " dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n", 328 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n", 329 | " lmu_params = [dict(input_dim=1, hidden_size=49, order=4, theta=4),\n", 330 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n", 331 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n", 332 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n", 333 | " ]\n", 334 | " model = LMUModel(1, 
lmu_params).cuda()\n", 335 | " optimizer = torch.optim.Adam(model.parameters())\n", 336 | " loss_func = nn.MSELoss()\n", 337 | "\n", 338 | " epochs = 1000\n", 339 | " batch_size = 32\n", 340 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 341 | " last_perf = 1000\n", 342 | " for e in progress_bar:\n", 343 | " train(model, ttype, dataset, dataset_valid, \n", 344 | " optimizer, loss_func, batch_size=batch_size, loss_buffer_size=64,\n", 345 | " epoch=e, perf_file=join('perf','mackeyglass_lmu_ratio_3.csv'),\n", 346 | " prog_bar=progress_bar, tau=start_tau*diff, pred_len=start_pd*diff)" 347 | ] 348 | } 349 | ], 350 | "metadata": { 351 | "kernelspec": { 352 | "display_name": "Python 3", 353 | "language": "python", 354 | "name": "python3" 355 | }, 356 | "language_info": { 357 | "codemirror_mode": { 358 | "name": "ipython", 359 | "version": 3 360 | }, 361 | "file_extension": ".py", 362 | "mimetype": "text/x-python", 363 | "name": "python", 364 | "nbconvert_exporter": "python", 365 | "pygments_lexer": "ipython3", 366 | "version": "3.6.10" 367 | }, 368 | "toc": { 369 | "nav_menu": {}, 370 | "number_sections": true, 371 | "sideBar": true, 372 | "skip_h1_title": false, 373 | "title_cell": "Table of Contents", 374 | "title_sidebar": "Contents", 375 | "toc_cell": false, 376 | "toc_position": {}, 377 | "toc_section_display": true, 378 | "toc_window_display": false 379 | } 380 | }, 381 | "nbformat": 4, 382 | "nbformat_minor": 4 383 | } 384 | -------------------------------------------------------------------------------- /experiments/AddingProblem/addingproblem-LSTM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2020-12-15T17:01:06.036477Z", 9 | "start_time": "2020-12-15T17:01:05.842293Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2020-12-15T17:01:06.655046Z", 23 | "start_time": "2020-12-15T17:01:06.090606Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pylab as plt\n", 29 | "import torch\n", 30 | "import numpy as np\n", 31 | "import seaborn as sn\n", 32 | "sn.set_context(\"poster\")\n", 33 | "import os\n", 34 | "import torch\n", 35 | "import torch.nn as nn\n", 36 | "import torch.nn.functional as F\n", 37 | "from torch.nn.utils import weight_norm\n", 38 | "from torchvision import transforms, datasets\n", 39 | "\n", 40 | "from deepsith import DeepSITH\n", 41 | "\n", 42 | "import numpy as np\n", 43 | "import scipy\n", 44 | "import scipy.stats as st\n", 45 | "import scipy.special\n", 46 | "import scipy.signal\n", 47 | "import scipy.interpolate\n", 48 | "\n", 49 | "import pandas as pd\n", 50 | "\n", 51 | "from os.path import join\n", 52 | "import random\n", 53 | "from csv import DictWriter\n", 54 | "\n", 55 | "from tqdm.notebook import tqdm\n", 56 | "import pickle\n", 57 | "# if gpu is to be used\n", 58 | "use_cuda = torch.cuda.is_available()\n", 59 | "print(use_cuda)\n", 60 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 61 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n", 62 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n", 63 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 64 | 
"ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n", 65 | "ttype = FloatTensor" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "ExecuteTime": { 73 | "end_time": "2020-12-15T17:01:06.661735Z", 74 | "start_time": "2020-12-15T17:01:06.656351Z" 75 | } 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "import torch\n", 80 | "import numpy as np\n", 81 | "\n", 82 | "def get_batch(batch_size, T, ttype):\n", 83 | " values = torch.rand(T, batch_size, requires_grad=False)\n", 84 | " indices = torch.zeros_like(values)\n", 85 | " half = int(T / 2)\n", 86 | " for i in range(batch_size):\n", 87 | " half_1 = np.random.randint(half)\n", 88 | " hals_2 = np.random.randint(half, T)\n", 89 | " indices[half_1, i] = 1\n", 90 | " indices[hals_2, i] = 1\n", 91 | "\n", 92 | " data = torch.stack((values, indices), dim=-1).type(ttype)\n", 93 | " targets = torch.mul(values, indices).sum(dim=0).type(ttype)\n", 94 | " return data, targets\n" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "ExecuteTime": { 102 | "end_time": "2020-12-15T17:01:09.231546Z", 103 | "start_time": "2020-12-15T17:01:09.229034Z" 104 | } 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "torch.manual_seed(1111)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": { 115 | "ExecuteTime": { 116 | "end_time": "2020-12-15T17:01:09.249505Z", 117 | "start_time": "2020-12-15T17:01:09.232698Z" 118 | } 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "def train(model, ttype, seq_length, optimizer, loss_func, \n", 123 | " epoch, perf_file, loss_buffer_size=20, batch_size=1, test_size=10,\n", 124 | " device='cuda', prog_bar=None):\n", 125 | " assert(loss_buffer_size%batch_size==0)\n", 126 | "\n", 127 | " losses = []\n", 128 | " perfs = []\n", 129 | " last_test_perf = 0\n", 130 | " for batch_idx in range(20000):\n", 131 | " model.train()\n", 132 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n", 133 | " \n", 134 | " target = target.unsqueeze(1)\n", 135 | " optimizer.zero_grad()\n", 136 | " out = model(sig)\n", 137 | " loss = loss_func(out[-1, :, :],\n", 138 | " target)\n", 139 | " \n", 140 | " loss.backward()\n", 141 | " optimizer.step()\n", 142 | "\n", 143 | " losses.append(loss.detach().cpu().numpy())\n", 144 | " losses = losses[-loss_buffer_size:]\n", 145 | " if not (prog_bar is None):\n", 146 | " # Update progress_bar\n", 147 | " s = \"{}:{} Loss: {:.8f}\"\n", 148 | " format_list = [e, int(batch_idx/(50/batch_size)), np.mean(losses)] \n", 149 | " s = s.format(*format_list)\n", 150 | " prog_bar.set_description(s)\n", 151 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n", 152 | " loss_track = {}\n", 153 | " #last_test_perf = test_norm(model, 'cuda', test_sig, test_class,\n", 154 | " # batch_size=test_size, \n", 155 | " # )\n", 156 | " loss_track['avg_loss'] = np.mean(losses)\n", 157 | " #loss_track['last_test'] = last_test_perf\n", 158 | " loss_track['epoch'] = epoch\n", 159 | " loss_track['batch_idx'] = batch_idx\n", 160 | " with open(perf_file, 'a+') as fp:\n", 161 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 162 | " if fp.tell() == 0:\n", 163 | " csv_writer.writeheader()\n", 164 | " csv_writer.writerow(loss_track)\n", 165 | " fp.flush()\n", 166 | "def test_norm(model, device, seq_length, loss_func, batch_size=100):\n", 167 | " model.eval()\n", 168 | " correct = 0\n", 169 | " count = 0\n", 170 | " 
with torch.no_grad():\n", 171 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n", 172 | " target = target.unsqueeze(1)\n", 173 | " out = model(sig)\n", 174 | " loss = loss_func(out[-1, :, :],\n", 175 | " target)\n", 176 | " return loss" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": { 183 | "ExecuteTime": { 184 | "end_time": "2020-12-15T17:01:09.258255Z", 185 | "start_time": "2020-12-15T17:01:09.250462Z" 186 | } 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "class LSTM_Predictor(nn.Module):\n", 191 | " def __init__(self, out_features, lstm_params):\n", 192 | " super(LSTM_Predictor, self).__init__()\n", 193 | " self.lstm = nn.LSTM(**lstm_params)\n", 194 | " self.to_out = nn.Linear(lstm_params['hidden_size'], \n", 195 | " out_features)\n", 196 | " def forward(self, inp):\n", 197 | " x = self.lstm(inp)[0]\n", 198 | " x = self.to_out(x)\n", 199 | " return x" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "# T = 100" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": { 213 | "ExecuteTime": { 214 | "end_time": "2020-12-15T17:01:23.033948Z", 215 | "start_time": "2020-12-15T17:01:23.026221Z" 216 | } 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "lstm_params = dict(input_size=2,\n", 221 | " hidden_size=128, \n", 222 | " num_layers=1)\n", 223 | "model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n", 224 | "\n", 225 | "tot_weights = 0\n", 226 | "for p in model.parameters():\n", 227 | " tot_weights += p.numel()\n", 228 | "print(\"Total Weights:\", tot_weights)\n", 229 | "print(model)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": { 236 | "ExecuteTime": { 237 | "end_time": "2020-12-15T17:03:23.646330Z", 238 | "start_time": "2020-12-15T17:01:30.304957Z" 239 | } 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "seq_length=100\n", 244 | "\n", 245 | "loss_func = nn.MSELoss()\n", 246 | "optimizer = torch.optim.Adam(model.parameters())\n", 247 | "epochs = 1\n", 248 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 249 | "for e in progress_bar:\n", 250 | " train(model, ttype, seq_length,\n", 251 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 252 | " epoch=e, perf_file=join('perf','adding100_lstm_1.csv'),\n", 253 | " prog_bar=progress_bar)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "# T = 500" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": { 267 | "ExecuteTime": { 268 | "end_time": "2020-11-17T18:59:57.170037Z", 269 | "start_time": "2020-11-17T18:59:57.165816Z" 270 | } 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "lstm_params = dict(input_size=2,\n", 275 | " hidden_size=128, \n", 276 | " num_layers=1)\n", 277 | "model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n", 278 | "\n", 279 | "tot_weights = 0\n", 280 | "for p in model.parameters():\n", 281 | " tot_weights += p.numel()\n", 282 | "print(\"Total Weights:\", tot_weights)\n", 283 | "print(model)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": { 290 | "ExecuteTime": { 291 | "end_time": "2020-11-17T19:06:41.299334Z", 292 | "start_time": "2020-11-17T19:00:01.467672Z" 293 | } 294 | }, 295 | "outputs": [], 296 | "source": [ 297 | "seq_length=500\n", 298 | "\n", 299 | 
"loss_func = nn.MSELoss()\n", 300 | "optimizer = torch.optim.Adam(model.parameters())\n", 301 | "epochs = 1\n", 302 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 303 | "for e in progress_bar:\n", 304 | " train(model, ttype, seq_length,\n", 305 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 306 | " epoch=e, perf_file=join('perf','adding500_lstm_3.csv'),\n", 307 | " prog_bar=progress_bar)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "# T = 2000" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": { 321 | "ExecuteTime": { 322 | "end_time": "2020-11-17T19:09:55.459792Z", 323 | "start_time": "2020-11-17T19:09:55.454074Z" 324 | } 325 | }, 326 | "outputs": [], 327 | "source": [ 328 | "lstm_params = dict(input_size=2,\n", 329 | " hidden_size=128, \n", 330 | " num_layers=1)\n", 331 | "model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n", 332 | "\n", 333 | "tot_weights = 0\n", 334 | "for p in model.parameters():\n", 335 | " tot_weights += p.numel()\n", 336 | "print(\"Total Weights:\", tot_weights)\n", 337 | "print(model)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": { 344 | "ExecuteTime": { 345 | "end_time": "2020-11-17T19:33:57.545928Z", 346 | "start_time": "2020-11-17T19:10:03.115405Z" 347 | } 348 | }, 349 | "outputs": [], 350 | "source": [ 351 | "seq_length=2000\n", 352 | "\n", 353 | "loss_func = nn.MSELoss()\n", 354 | "optimizer = torch.optim.Adam(model.parameters())\n", 355 | "epochs = 1\n", 356 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 357 | "for e in progress_bar:\n", 358 | " train(model, ttype, seq_length,\n", 359 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 360 | " epoch=e, perf_file=join('perf','adding2000_lstm_2.csv'),\n", 361 | " prog_bar=progress_bar)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": {}, 367 | "source": [ 368 | "# T = 5000" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "metadata": { 375 | "ExecuteTime": { 376 | "end_time": "2020-11-17T19:33:57.551416Z", 377 | "start_time": "2020-11-17T19:33:57.546926Z" 378 | } 379 | }, 380 | "outputs": [], 381 | "source": [ 382 | "lstm_params = dict(input_size=2,\n", 383 | " hidden_size=128, \n", 384 | " num_layers=1)\n", 385 | "model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n", 386 | "\n", 387 | "tot_weights = 0\n", 388 | "for p in model.parameters():\n", 389 | " tot_weights += p.numel()\n", 390 | "print(\"Total Weights:\", tot_weights)\n", 391 | "print(model)" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": { 398 | "ExecuteTime": { 399 | "end_time": "2020-11-17T20:41:33.252877Z", 400 | "start_time": "2020-11-17T19:33:57.552380Z" 401 | } 402 | }, 403 | "outputs": [], 404 | "source": [ 405 | "seq_length=5000\n", 406 | "\n", 407 | "loss_func = nn.MSELoss()\n", 408 | "optimizer = torch.optim.Adam(model.parameters())\n", 409 | "epochs = 1\n", 410 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 411 | "for e in progress_bar:\n", 412 | " train(model, ttype, seq_length,\n", 413 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 414 | " epoch=e, perf_file=join('perf','adding5000_lstm_1.csv'),\n", 415 | " prog_bar=progress_bar)" 416 | ] 417 | }, 418 | { 
419 | "cell_type": "code", 420 | "execution_count": null, 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [] 424 | } 425 | ], 426 | "metadata": { 427 | "kernelspec": { 428 | "display_name": "Python 3", 429 | "language": "python", 430 | "name": "python3" 431 | }, 432 | "language_info": { 433 | "codemirror_mode": { 434 | "name": "ipython", 435 | "version": 3 436 | }, 437 | "file_extension": ".py", 438 | "mimetype": "text/x-python", 439 | "name": "python", 440 | "nbconvert_exporter": "python", 441 | "pygments_lexer": "ipython3", 442 | "version": "3.6.10" 443 | }, 444 | "toc": { 445 | "nav_menu": { 446 | "height": "141px", 447 | "width": "160px" 448 | }, 449 | "number_sections": true, 450 | "sideBar": true, 451 | "skip_h1_title": false, 452 | "title_cell": "Table of Contents", 453 | "title_sidebar": "Contents", 454 | "toc_cell": false, 455 | "toc_position": {}, 456 | "toc_section_display": true, 457 | "toc_window_display": false 458 | } 459 | }, 460 | "nbformat": 4, 461 | "nbformat_minor": 4 462 | } 463 | -------------------------------------------------------------------------------- /experiments/AddingProblem/addingproblem-LMU.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-03-11T17:02:45.247984Z", 9 | "start_time": "2021-03-11T17:02:45.076329Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-03-11T17:02:46.318135Z", 23 | "start_time": "2021-03-11T17:02:45.266959Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pylab as plt\n", 29 | "import torch\n", 30 | "import numpy as np\n", 31 | "import seaborn as sn\n", 32 | "sn.set_context(\"poster\")\n", 33 | "import os\n", 34 | "import torch\n", 35 | "import torch.nn as nn\n", 36 | "import torch.nn.functional as F\n", 37 | "from torch.nn.utils import weight_norm\n", 38 | "from torchvision import transforms, datasets\n", 39 | "\n", 40 | "#from lmu import LegendreMemoryUnit\n", 41 | "from LMU import LegendreMemoryUnit\n", 42 | "import numpy as np\n", 43 | "import scipy\n", 44 | "import scipy.stats as st\n", 45 | "import scipy.special\n", 46 | "import scipy.signal\n", 47 | "import scipy.interpolate\n", 48 | "\n", 49 | "import pandas as pd\n", 50 | "\n", 51 | "from os.path import join\n", 52 | "import random\n", 53 | "from csv import DictWriter\n", 54 | "\n", 55 | "from tqdm.notebook import tqdm\n", 56 | "import pickle\n", 57 | "# if gpu is to be used\n", 58 | "use_cuda = torch.cuda.is_available()\n", 59 | "print(use_cuda)\n", 60 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 61 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n", 62 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n", 63 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 64 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n", 65 | "ttype = FloatTensor" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "ExecuteTime": { 73 | "end_time": "2021-03-11T17:02:46.322341Z", 74 | "start_time": "2021-03-11T17:02:46.319117Z" 75 | } 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "import torch\n", 80 | "import numpy as np\n", 81 | 
"\n", 82 | "def get_batch(batch_size, T, ttype):\n", 83 | " values = torch.rand(T, batch_size, requires_grad=False)\n", 84 | " indices = torch.zeros_like(values)\n", 85 | " half = int(T / 2)\n", 86 | " for i in range(batch_size):\n", 87 | " half_1 = np.random.randint(half)\n", 88 | " hals_2 = np.random.randint(half, T)\n", 89 | " indices[half_1, i] = 1\n", 90 | " indices[hals_2, i] = 1\n", 91 | "\n", 92 | " data = torch.stack((values, indices), dim=-1).type(ttype)\n", 93 | " targets = torch.mul(values, indices).sum(dim=0).type(ttype)\n", 94 | " return data, targets\n" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "ExecuteTime": { 102 | "end_time": "2021-03-11T17:02:51.926506Z", 103 | "start_time": "2021-03-11T17:02:51.923167Z" 104 | } 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "torch.manual_seed(1111)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": { 115 | "ExecuteTime": { 116 | "end_time": "2021-03-11T17:02:52.144538Z", 117 | "start_time": "2021-03-11T17:02:52.137566Z" 118 | } 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "def train(model, ttype, seq_length, optimizer, loss_func, \n", 123 | " epoch, perf_file, loss_buffer_size=20, batch_size=1, test_size=10,\n", 124 | " device='cuda', prog_bar=None):\n", 125 | " assert(loss_buffer_size%batch_size==0)\n", 126 | "\n", 127 | " losses = []\n", 128 | " perfs = []\n", 129 | " last_test_perf = 0\n", 130 | " for batch_idx in range(20000):\n", 131 | " model.train()\n", 132 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n", 133 | " \n", 134 | " target = target.unsqueeze(1)\n", 135 | " optimizer.zero_grad()\n", 136 | " out = model(sig.transpose(1,0))\n", 137 | " loss = loss_func(out[:, -1],\n", 138 | " target)\n", 139 | " \n", 140 | " loss.backward()\n", 141 | " optimizer.step()\n", 142 | "\n", 143 | " losses.append(loss.detach().cpu().numpy())\n", 144 | " losses = losses[-loss_buffer_size:]\n", 145 | " if not (prog_bar is None):\n", 146 | " # Update progress_bar\n", 147 | " s = \"{}:{} Loss: {:.8f}\"\n", 148 | " format_list = [e, int(batch_idx/(50/batch_size)), np.mean(losses)] \n", 149 | " s = s.format(*format_list)\n", 150 | " prog_bar.set_description(s)\n", 151 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n", 152 | " loss_track = {}\n", 153 | " #last_test_perf = test_norm(model, 'cuda', test_sig, test_class,\n", 154 | " # batch_size=test_size, \n", 155 | " # )\n", 156 | " loss_track['avg_loss'] = np.mean(losses)\n", 157 | " #loss_track['last_test'] = last_test_perf\n", 158 | " loss_track['epoch'] = epoch\n", 159 | " loss_track['batch_idx'] = batch_idx\n", 160 | " with open(perf_file, 'a+') as fp:\n", 161 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 162 | " if fp.tell() == 0:\n", 163 | " csv_writer.writeheader()\n", 164 | " csv_writer.writerow(loss_track)\n", 165 | " fp.flush()\n", 166 | "def test_norm(model, device, seq_length, loss_func, batch_size=100):\n", 167 | " model.eval()\n", 168 | " correct = 0\n", 169 | " count = 0\n", 170 | " with torch.no_grad():\n", 171 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n", 172 | " target = target.unsqueeze(1)\n", 173 | " out = model(sig.transpose(1,0))\n", 174 | " loss = loss_func(out[:, -1],\n", 175 | " target)\n", 176 | " return loss" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": { 183 | "ExecuteTime": { 184 | "end_time": 
"2021-03-11T17:02:52.334043Z", 185 | "start_time": "2021-03-11T17:02:52.330865Z" 186 | } 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "class LMUModel(nn.Module):\n", 191 | " def __init__(self, n_out, layer_params):\n", 192 | " super(LMUModel, self).__init__()\n", 193 | " self.layers = nn.ModuleList([LegendreMemoryUnit(**layer_params[i])\n", 194 | " for i in range(len(layer_params))])\n", 195 | " self.dense = nn.Linear(layer_params[-1]['units'], n_out)\n", 196 | "\n", 197 | " \n", 198 | " def forward(self, x):\n", 199 | " for l in self.layers:\n", 200 | " x, _ = l(x) \n", 201 | " x = self.dense(x)\n", 202 | " return x" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "# T = 100" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": { 216 | "ExecuteTime": { 217 | "end_time": "2021-03-11T17:31:48.855245Z", 218 | "start_time": "2021-03-11T17:31:48.375043Z" 219 | } 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "seq_length=100\n", 224 | "\n", 225 | "lmu_params = [dict(input_dim=2, units=25, order=1000, theta=5000),\n", 226 | " \n", 227 | " ]\n", 228 | "model = LMUModel(1, lmu_params).cuda()\n", 229 | "\n", 230 | "tot_weights = 0\n", 231 | "for p in model.parameters():\n", 232 | " tot_weights += p.numel()\n", 233 | "print(\"Total Weights:\", tot_weights)\n", 234 | "print(model)\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": { 241 | "ExecuteTime": { 242 | "end_time": "2021-03-11T17:11:05.011736Z", 243 | "start_time": "2021-03-11T17:03:10.989965Z" 244 | } 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "\n", 249 | "loss_func = nn.MSELoss()\n", 250 | "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)\n", 251 | "epochs = 1\n", 252 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 253 | "for e in progress_bar:\n", 254 | " train(model, ttype, seq_length,\n", 255 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 256 | " epoch=e, perf_file=join('perf','adding100_lmu_03_weird.csv'),\n", 257 | " prog_bar=progress_bar)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "# T = 500" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": { 271 | "ExecuteTime": { 272 | "end_time": "2021-03-11T18:50:03.354758Z", 273 | "start_time": "2021-03-11T18:50:02.981515Z" 274 | } 275 | }, 276 | "outputs": [], 277 | "source": [ 278 | "seq_length=500\n", 279 | "\n", 280 | "lmu_params = [dict(input_dim=2, units=25, order=1000, theta=100),\n", 281 | " \n", 282 | " ]\n", 283 | "model = LMUModel(1, lmu_params).cuda()\n", 284 | "\n", 285 | "tot_weights = 0\n", 286 | "for p in model.parameters():\n", 287 | " tot_weights += p.numel()\n", 288 | "print(\"Total Weights:\", tot_weights)\n", 289 | "print(model)\n" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": { 296 | "ExecuteTime": { 297 | "end_time": "2021-03-11T20:16:17.551864Z", 298 | "start_time": "2021-03-11T18:50:08.232460Z" 299 | } 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "\n", 304 | "loss_func = nn.MSELoss()\n", 305 | "optimizer = torch.optim.AdamW(model.parameters())\n", 306 | "epochs = 1\n", 307 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 308 | "for e in progress_bar:\n", 309 | " train(model, ttype, seq_length,\n", 310 | " 
optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 311 | " epoch=e, perf_file=join('perf','adding500_lmu_04_weird.csv'),\n", 312 | " prog_bar=progress_bar)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "# T = 2000" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": { 326 | "ExecuteTime": { 327 | "end_time": "2021-02-22T20:50:34.258707Z", 328 | "start_time": "2021-02-22T20:50:34.234459Z" 329 | } 330 | }, 331 | "outputs": [], 332 | "source": [ 333 | "seq_length=2000\n", 334 | "\n", 335 | "lmu_params = [dict(input_dim=2, units=25, order=100, theta=5000),\n", 336 | " \n", 337 | " ]\n", 338 | "model = LMUModel(1, lmu_params).cuda()\n", 339 | "\n", 340 | "tot_weights = 0\n", 341 | "for p in model.parameters():\n", 342 | " tot_weights += p.numel()\n", 343 | "print(\"Total Weights:\", tot_weights)\n", 344 | "print(model)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": { 351 | "ExecuteTime": { 352 | "end_time": "2021-02-23T02:23:42.992553Z", 353 | "start_time": "2021-02-22T20:50:34.259745Z" 354 | } 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "\n", 359 | "loss_func = nn.MSELoss()\n", 360 | "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)\n", 361 | "epochs = 1\n", 362 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 363 | "for e in progress_bar:\n", 364 | " train(model, ttype, seq_length,\n", 365 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 366 | " epoch=e, perf_file=join('perf','adding2000_lmu_01_weird.csv'),\n", 367 | " prog_bar=progress_bar)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": {}, 373 | "source": [ 374 | "# T = 5000" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": { 381 | "ExecuteTime": { 382 | "end_time": "2021-02-23T02:23:42.998412Z", 383 | "start_time": "2021-02-23T02:23:42.993559Z" 384 | } 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "seq_length=5000\n", 389 | "\n", 390 | "lmu_params = [dict(input_dim=2, units=25, order=100, theta=5000),\n", 391 | " \n", 392 | " ]\n", 393 | "model = LMUModel(1, lmu_params).cuda()\n", 394 | "\n", 395 | "tot_weights = 0\n", 396 | "for p in model.parameters():\n", 397 | " tot_weights += p.numel()\n", 398 | "print(\"Total Weights:\", tot_weights)\n", 399 | "print(model)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": { 406 | "ExecuteTime": { 407 | "end_time": "2021-02-23T02:58:44.760408Z", 408 | "start_time": "2021-02-23T02:23:42.999179Z" 409 | } 410 | }, 411 | "outputs": [], 412 | "source": [ 413 | "\n", 414 | "loss_func = nn.MSELoss()\n", 415 | "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)\n", 416 | "epochs = 1\n", 417 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 418 | "for e in progress_bar:\n", 419 | " train(model, ttype, seq_length,\n", 420 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 421 | " epoch=e, perf_file=join('perf','adding5000_lmu_01_weird.csv'),\n", 422 | " prog_bar=progress_bar)" 423 | ] 424 | } 425 | ], 426 | "metadata": { 427 | "kernelspec": { 428 | "display_name": "Python 3", 429 | "language": "python", 430 | "name": "python3" 431 | }, 432 | "language_info": { 433 | "codemirror_mode": { 434 | "name": "ipython", 435 | "version": 3 436 | }, 437 | 
"file_extension": ".py", 438 | "mimetype": "text/x-python", 439 | "name": "python", 440 | "nbconvert_exporter": "python", 441 | "pygments_lexer": "ipython3", 442 | "version": "3.6.10" 443 | }, 444 | "toc": { 445 | "nav_menu": { 446 | "height": "141px", 447 | "width": "160px" 448 | }, 449 | "number_sections": true, 450 | "sideBar": true, 451 | "skip_h1_title": false, 452 | "title_cell": "Table of Contents", 453 | "title_sidebar": "Contents", 454 | "toc_cell": false, 455 | "toc_position": {}, 456 | "toc_section_display": true, 457 | "toc_window_display": false 458 | } 459 | }, 460 | "nbformat": 4, 461 | "nbformat_minor": 4 462 | } 463 | -------------------------------------------------------------------------------- /experiments/MackeyGlass/MackeyGlass-DeepSITH-increasing_tau.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-04-16T15:10:34.280995Z", 9 | "start_time": "2021-04-16T15:10:34.099264Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-04-16T15:10:34.999978Z", 23 | "start_time": "2021-04-16T15:10:34.333348Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import torch\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import torchvision\n", 31 | "import nengolib\n", 32 | "import numpy as np\n", 33 | "import torch.nn as nn\n", 34 | "import torch.nn.functional as F\n", 35 | "from tqdm import tqdm_notebook\n", 36 | "from deepsith import DeepSITH\n", 37 | "import PIL\n", 38 | "from torch.nn.utils import weight_norm\n", 39 | "\n", 40 | "from os.path import join\n", 41 | "import scipy.special\n", 42 | "import pandas as pd\n", 43 | "import seaborn as sn\n", 44 | "import scipy\n", 45 | "from scipy.spatial.distance import euclidean\n", 46 | "from scipy.interpolate import interp1d\n", 47 | "from tqdm.notebook import tqdm\n", 48 | "import random\n", 49 | "from csv import DictWriter\n", 50 | "# if gpu is to be used\n", 51 | "use_cuda = torch.cuda.is_available()\n", 52 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 53 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n", 54 | "\n", 55 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n", 56 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 57 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n", 58 | "ttype = FloatTensor\n", 59 | "\n", 60 | "import seaborn as sns\n", 61 | "print(use_cuda)\n", 62 | "import pickle\n" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "# Load Stimuli" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "ExecuteTime": { 77 | "end_time": "2021-04-16T15:10:36.859945Z", 78 | "start_time": "2021-04-16T15:10:36.852930Z" 79 | } 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "import collections\n", 84 | "\n", 85 | "def mackey_glass(sample_len=1000, tau=17, delta_t=10, seed=None, n_samples=1):\n", 86 | " # Adapted from https://github.com/mila-iqia/summerschool2015/blob/master/rnn_tutorial/synthetic.py\n", 87 | " '''\n", 88 | " mackey_glass(sample_len=1000, tau=17, seed = None, n_samples = 1) -> input\n", 89 | " Generate the Mackey Glass 
time-series. Parameters are:\n", 90 | " - sample_len: length of the time-series in timesteps. Default is 1000.\n", 91 | " - tau: delay of the MG - system. Commonly used values are tau=17 (mild \n", 92 | " chaos) and tau=30 (moderate chaos). Default is 17.\n", 93 | " - seed: to seed the random generator, can be used to generate the same\n", 94 | " timeseries at each invocation.\n", 95 | " - n_samples : number of samples to generate\n", 96 | " '''\n", 97 | " history_len = tau * delta_t \n", 98 | " # Initial conditions for the history of the system\n", 99 | " timeseries = 1.2\n", 100 | " \n", 101 | " if seed is not None:\n", 102 | " np.random.seed(seed)\n", 103 | "\n", 104 | " samples = []\n", 105 | "\n", 106 | " for _ in range(n_samples):\n", 107 | " history = collections.deque(1.2 * np.ones(history_len) + 0.2 * \\\n", 108 | " (np.random.rand(history_len) - 0.5))\n", 109 | " # Preallocate the array for the time-series\n", 110 | " inp = np.zeros((sample_len,1))\n", 111 | " \n", 112 | " for timestep in range(sample_len):\n", 113 | " for _ in range(delta_t):\n", 114 | " xtau = history.popleft()\n", 115 | " history.append(timeseries)\n", 116 | " timeseries = history[-1] + (0.2 * xtau / (1.0 + xtau ** 10) - \\\n", 117 | " 0.1 * history[-1]) / delta_t\n", 118 | " inp[timestep] = timeseries\n", 119 | " \n", 120 | " # Squash timeseries through tanh\n", 121 | " inp = np.tanh(inp - 1)\n", 122 | " samples.append(inp)\n", 123 | " return samples\n", 124 | "\n", 125 | "\n", 126 | "def generate_data(n_batches, length, split=0.5, seed=0,\n", 127 | " predict_length=15, tau=17, washout=100, delta_t=1,\n", 128 | " center=True):\n", 129 | " X = np.asarray(mackey_glass(\n", 130 | " sample_len=length+predict_length+washout, tau=tau,\n", 131 | " seed=seed, n_samples=n_batches))\n", 132 | " X = X[:, washout:, :]\n", 133 | " cutoff = int(split*n_batches)\n", 134 | " if center:\n", 135 | " X -= np.mean(X) # global mean over all batches, approx -0.066\n", 136 | " Y = X[:, predict_length:, :]\n", 137 | " X = X[:, :-predict_length, :]\n", 138 | " assert X.shape == Y.shape\n", 139 | " return ((X[:cutoff], Y[:cutoff]),\n", 140 | " (X[cutoff:], Y[cutoff:]))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": { 147 | "ExecuteTime": { 148 | "end_time": "2021-04-16T15:10:38.927028Z", 149 | "start_time": "2021-04-16T15:10:37.026625Z" 150 | } 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "(train_X, train_Y), (test_X, test_Y) = generate_data(2, 5000)\n", 155 | "\n", 156 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n", 157 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 158 | "\n", 159 | "dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n", 160 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## Setup for Model" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": { 174 | "ExecuteTime": { 175 | "end_time": "2021-04-16T15:10:39.071373Z", 176 | "start_time": "2021-04-16T15:10:39.063007Z" 177 | } 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "\n", 182 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n", 183 | " loss_buffer_size=800, batch_size=4, device='cuda',\n", 184 | " 
prog_bar=None, tau=17, pred_len=15):\n", 185 | "    \n", 186 | "    assert(loss_buffer_size%batch_size==0)\n", 187 | "    \n", 188 | "    losses = []\n", 189 | "    last_test_perf = 0\n", 190 | "    best_test_perf = np.inf  # lower NRMSE is better, so start high or the checkpoint below never saves\n", 191 | "    \n", 192 | "    for batch_idx, (data, target) in enumerate(train_loader):\n", 193 | "        model.train()\n", 194 | "        data = data.to(device).view(data.shape[0],1,1,-1)\n", 195 | "        target = target.to(device)\n", 196 | "        optimizer.zero_grad()\n", 197 | "        out = model(data)\n", 198 | "        loss = loss_func(out,\n", 199 | "                         target)\n", 200 | "\n", 201 | "        loss.backward()\n", 202 | "        optimizer.step()\n", 203 | "\n", 204 | "        losses.append(loss.detach().cpu().numpy())\n", 205 | "        losses = losses[int(-loss_buffer_size/batch_size):]\n", 206 | "        \n", 207 | "        if ((batch_idx*batch_size)%loss_buffer_size == 0):\n", 208 | "            loss_track = {}\n", 209 | "            last_test_perf = test_model(model, 'cuda', test_loader, \n", 210 | "                                        )\n", 211 | "            loss_track['avg_loss'] = np.mean(losses)\n", 212 | "            loss_track['last_test'] = last_test_perf\n", 213 | "            loss_track['epoch'] = epoch\n", 214 | "            loss_track['batch_idx'] = batch_idx\n", 215 | "            loss_track['tau'] = tau\n", 216 | "            loss_track['pred_len'] = pred_len\n", 217 | "            with open(perf_file, 'a+') as fp:\n", 218 | "                csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 219 | "                if fp.tell() == 0:\n", 220 | "                    csv_writer.writeheader()\n", 221 | "                csv_writer.writerow(loss_track)\n", 222 | "                fp.flush()\n", 223 | "            if best_test_perf > last_test_perf:\n", 224 | "                torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 225 | "                best_test_perf = last_test_perf\n", 226 | "        if not (prog_bar is None):\n", 227 | "            # Update progress_bar\n", 228 | "            s = \"{}:{} Loss: {:.4f}, valid: {:.4f}\"\n", 229 | "            format_list = [e,batch_idx*batch_size, np.mean(losses), \n", 230 | "                           last_test_perf] \n", 231 | "            s = s.format(*format_list)\n", 232 | "            prog_bar.set_description(s)\n", 233 | "    \n", 234 | "def test_model(model, device, test_loader):\n", 235 | "    # Test the Model\n", 236 | "    nrmsd = []\n", 237 | "    with torch.no_grad():\n", 238 | "        for x, y in test_loader:\n", 239 | "            data = x.to(device).view(x.shape[0],1,1,-1)\n", 240 | "            target = y.to(device)\n", 241 | "            optimizer.zero_grad()\n", 242 | "            out = model(data)\n", 243 | "            nrmsd.append(nengolib.signal.nrmse(out.detach().cpu().numpy().flatten(), \n", 244 | "                                               target=target.detach().cpu().numpy().flatten()))\n", 245 | "\n", 246 | "    perf = np.array(nrmsd).mean()\n", 247 | "    return perf" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": { 254 | "ExecuteTime": { 255 | "end_time": "2021-04-16T15:10:39.230440Z", 256 | "start_time": "2021-04-16T15:10:39.227260Z" 257 | } 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "\n", 262 | " \n", 263 | "class DeepSITH_Tracker(nn.Module):\n", 264 | "    def __init__(self, out, layer_params, dropout=.5):\n", 265 | "        super(DeepSITH_Tracker, self).__init__()\n", 266 | "        last_hidden = layer_params[-1]['hidden_size']\n", 267 | "        self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)\n", 268 | "        self.to_out = nn.Linear(last_hidden, out)\n", 269 | "    def forward(self, inp):\n", 270 | "        x = self.hs(inp)\n", 271 | "        #x = torch.tanh(self.to_out(x))\n", 272 | "        x = self.to_out(x)\n", 273 | "        return x" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": { 280 | "ExecuteTime": { 281 | "end_time": "2021-04-16T15:11:06.128309Z", 282 | "start_time": "2021-04-16T15:10:59.478750Z" 283 | } 284 | }, 285 | "outputs": 
[], 286 | "source": [ 287 | "start_tau = 17\n", 288 | "start_pd = 15\n", 289 | "diffs = [1, 2, 3, 4, 5]\n", 290 | "for diff in diffs:\n", 291 | " print(start_tau*diff, start_pd*diff)\n", 292 | " (train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000, \n", 293 | " tau=start_tau*diff, \n", 294 | " predict_length=start_pd*diff)\n", 295 | "\n", 296 | " dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n", 297 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 298 | "\n", 299 | " dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n", 300 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n", 301 | " \n", 302 | " sith_params1 = {\"in_features\":1, \n", 303 | " \"tau_min\":1, \"tau_max\":25.0, \n", 304 | " \"k\":15, 'dt':1,\n", 305 | " \"ntau\":8, 'g':0.0, \n", 306 | " \"ttype\":ttype, 'batch_norm':False,\n", 307 | " \"hidden_size\":25, \"act_func\":nn.ReLU()}\n", 308 | " sith_params2 = {\"in_features\":sith_params1['hidden_size'], \n", 309 | " \"tau_min\":1, \"tau_max\":50.0, \n", 310 | " \"k\":8, 'dt':1,\n", 311 | " \"ntau\":8, 'g':0.0, \n", 312 | " \"ttype\":ttype, 'batch_norm':False,\n", 313 | " \"hidden_size\":25, \"act_func\":nn.ReLU()}\n", 314 | " sith_params3 = {\"in_features\":sith_params2['hidden_size'], \n", 315 | " \"tau_min\":1, \"tau_max\":150.0, 'buff_max':600, \n", 316 | " \"k\":4, 'dt':1,\n", 317 | " \"ntau\":8, 'g':0.0, \n", 318 | " \"ttype\":ttype, 'batch_norm':False,\n", 319 | " \"hidden_size\":25, \"act_func\":nn.ReLU()}\n", 320 | " layer_params = [sith_params1, sith_params2, sith_params3]\n", 321 | "\n", 322 | " model = DeepSITH_Tracker(out=1,\n", 323 | " layer_params=layer_params, \n", 324 | " dropout=0.).cuda()\n", 325 | " optimizer = torch.optim.Adam(model.parameters(), lr=.004)\n", 326 | " loss_func = nn.MSELoss()\n", 327 | "\n", 328 | " tot = 0\n", 329 | " for p in model.parameters():\n", 330 | " tot += p.numel()\n", 331 | " print(\"tot_weights\", tot)\n", 332 | " #print(model)\n", 333 | " \n", 334 | " epochs = 1000\n", 335 | " batch_size = 32\n", 336 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 337 | " last_perf = 1000\n", 338 | " for e in progress_bar:\n", 339 | " train(model, ttype, dataset, dataset_valid, \n", 340 | " optimizer, loss_func, batch_size=batch_size, loss_buffer_size=64,\n", 341 | " epoch=e, perf_file=join('perf','mackeyglass_deepsith_ratio_1.csv'),\n", 342 | " prog_bar=progress_bar, tau=start_tau*diff, pred_len=start_pd*diff)" 343 | ] 344 | } 345 | ], 346 | "metadata": { 347 | "kernelspec": { 348 | "display_name": "Python 3", 349 | "language": "python", 350 | "name": "python3" 351 | }, 352 | "language_info": { 353 | "codemirror_mode": { 354 | "name": "ipython", 355 | "version": 3 356 | }, 357 | "file_extension": ".py", 358 | "mimetype": "text/x-python", 359 | "name": "python", 360 | "nbconvert_exporter": "python", 361 | "pygments_lexer": "ipython3", 362 | "version": "3.6.10" 363 | }, 364 | "toc": { 365 | "nav_menu": {}, 366 | "number_sections": true, 367 | "sideBar": true, 368 | "skip_h1_title": false, 369 | "title_cell": "Table of Contents", 370 | "title_sidebar": "Contents", 371 | "toc_cell": false, 372 | "toc_position": {}, 373 | "toc_section_display": true, 374 | "toc_window_display": false 375 | } 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 4 379 | } 380 | 
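NOTE: the MackeyGlass notebooks on either side of this point import nengolib solely for nengolib.signal.nrmse inside test_model, and that package is no longer actively maintained and can be hard to install on recent Python versions. A minimal drop-in sketch, assuming nengolib's convention of dividing the RMSE by the RMS of the target signal (that normalization is an assumption, not something this repository pins down):

import numpy as np

def nrmse(actual, target):
    # RMSE between the flattened signals, normalized by the RMS of the
    # target (assumed to match nengolib.signal.nrmse's convention).
    actual = np.asarray(actual).ravel()
    target = np.asarray(target).ravel()
    rmse = np.sqrt(np.mean((actual - target) ** 2))
    return rmse / np.sqrt(np.mean(target ** 2))

With this helper, nengolib.signal.nrmse(out, target=target) in test_model can be swapped for nrmse(out, target) without touching the rest of the training loop.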
-------------------------------------------------------------------------------- /experiments/MackeyGlass/MackeyGlass-coRNN-increasing_tau.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-01-31T16:13:20.250852Z", 9 | "start_time": "2021-01-31T16:13:20.076923Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-01-31T16:13:20.997573Z", 23 | "start_time": "2021-01-31T16:13:20.311754Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import torch\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import torchvision\n", 31 | "import numpy as np\n", 32 | "import torch.nn as nn\n", 33 | "import torch.nn.functional as F\n", 34 | "from tqdm import tqdm_notebook\n", 35 | "import PIL\n", 36 | "from torch.nn.utils import weight_norm\n", 37 | "from torch.autograd import Variable\n", 38 | "\n", 39 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 40 | "\n", 41 | "from os.path import join\n", 42 | "import scipy.special\n", 43 | "import pandas as pd\n", 44 | "import seaborn as sn\n", 45 | "import scipy\n", 46 | "from scipy.spatial.distance import euclidean\n", 47 | "from scipy.interpolate import interp1d\n", 48 | "from tqdm.notebook import tqdm\n", 49 | "import random\n", 50 | "from csv import DictWriter\n", 51 | "import nengolib\n", 52 | "# if gpu is to be used\n", 53 | "use_cuda = torch.cuda.is_available()\n", 54 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 55 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n", 56 | "\n", 57 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n", 58 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 59 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n", 60 | "ttype = FloatTensor\n", 61 | "\n", 62 | "import seaborn as sns\n", 63 | "print(use_cuda)\n", 64 | "import pickle\n" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "ExecuteTime": { 72 | "end_time": "2021-01-31T16:13:21.000738Z", 73 | "start_time": "2021-01-31T16:13:20.998684Z" 74 | } 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "sn.set_context(\"poster\")" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "ExecuteTime": { 86 | "end_time": "2021-01-31T16:13:21.016467Z", 87 | "start_time": "2021-01-31T16:13:21.001752Z" 88 | } 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "\n", 93 | "class coRNNCell(nn.Module):\n", 94 | " def __init__(self, n_inp, n_hid, dt, gamma, epsilon):\n", 95 | " super(coRNNCell, self).__init__()\n", 96 | " self.dt = dt\n", 97 | " self.gamma = gamma\n", 98 | " self.epsilon = epsilon\n", 99 | " self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)\n", 100 | "\n", 101 | " def forward(self,x,hy,hz):\n", 102 | " hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy),1)))\n", 103 | " - self.gamma * hy - self.epsilon * hz)\n", 104 | " hy = hy + self.dt * hz\n", 105 | "\n", 106 | " return hy, hz\n", 107 | "\n", 108 | "class coRNN(nn.Module):\n", 109 | " def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):\n", 110 | " super(coRNN, self).__init__()\n", 111 | " self.n_hid 
= n_hid\n", 112 | " self.cell = coRNNCell(n_inp,n_hid,dt,gamma,epsilon)\n", 113 | " self.readout = nn.Linear(n_hid, n_out)\n", 114 | "\n", 115 | " def forward(self, x):\n", 116 | " outputs = []\n", 117 | " ## initialize hidden states\n", 118 | " hy = Variable(torch.zeros(x.size(1),self.n_hid)).to(device)\n", 119 | " hz = Variable(torch.zeros(x.size(1),self.n_hid)).to(device)\n", 120 | "\n", 121 | " for t in range(x.size(0)):\n", 122 | " hy, hz = self.cell(x[t],hy,hz)\n", 123 | " outputs.append(hy)\n", 124 | " outputs = torch.stack(outputs)\n", 125 | " \n", 126 | " output = self.readout(outputs)\n", 127 | "\n", 128 | " return output" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "# Load Stimuli" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "ExecuteTime": { 143 | "end_time": "2021-01-31T16:13:25.543469Z", 144 | "start_time": "2021-01-31T16:13:25.536756Z" 145 | } 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "import collections\n", 150 | "\n", 151 | "def mackey_glass(sample_len=1000, tau=17, delta_t=10, seed=None, n_samples=1):\n", 152 | " # Adapted from https://github.com/mila-iqia/summerschool2015/blob/master/rnn_tutorial/synthetic.py\n", 153 | " '''\n", 154 | " mackey_glass(sample_len=1000, tau=17, seed = None, n_samples = 1) -> input\n", 155 | " Generate the Mackey Glass time-series. Parameters are:\n", 156 | " - sample_len: length of the time-series in timesteps. Default is 1000.\n", 157 | " - tau: delay of the MG - system. Commonly used values are tau=17 (mild \n", 158 | " chaos) and tau=30 (moderate chaos). Default is 17.\n", 159 | " - seed: to seed the random generator, can be used to generate the same\n", 160 | " timeseries at each invocation.\n", 161 | " - n_samples : number of samples to generate\n", 162 | " '''\n", 163 | " history_len = tau * delta_t \n", 164 | " # Initial conditions for the history of the system\n", 165 | " timeseries = 1.2\n", 166 | " \n", 167 | " if seed is not None:\n", 168 | " np.random.seed(seed)\n", 169 | "\n", 170 | " samples = []\n", 171 | "\n", 172 | " for _ in range(n_samples):\n", 173 | " history = collections.deque(1.2 * np.ones(history_len) + 0.2 * \\\n", 174 | " (np.random.rand(history_len) - 0.5))\n", 175 | " # Preallocate the array for the time-series\n", 176 | " inp = np.zeros((sample_len,1))\n", 177 | " \n", 178 | " for timestep in range(sample_len):\n", 179 | " for _ in range(delta_t):\n", 180 | " xtau = history.popleft()\n", 181 | " history.append(timeseries)\n", 182 | " timeseries = history[-1] + (0.2 * xtau / (1.0 + xtau ** 10) - \\\n", 183 | " 0.1 * history[-1]) / delta_t\n", 184 | " inp[timestep] = timeseries\n", 185 | " \n", 186 | " # Squash timeseries through tanh\n", 187 | " inp = np.tanh(inp - 1)\n", 188 | " samples.append(inp)\n", 189 | " return samples\n", 190 | "\n", 191 | "\n", 192 | "def generate_data(n_batches, length, split=0.5, seed=0,\n", 193 | " predict_length=15, tau=17, washout=100, delta_t=1,\n", 194 | " center=True):\n", 195 | " X = np.asarray(mackey_glass(\n", 196 | " sample_len=length+predict_length+washout, tau=tau,\n", 197 | " seed=seed, n_samples=n_batches))\n", 198 | " X = X[:, washout:, :]\n", 199 | " cutoff = int(split*n_batches)\n", 200 | " if center:\n", 201 | " X -= np.mean(X) # global mean over all batches, approx -0.066\n", 202 | " Y = X[:, predict_length:, :]\n", 203 | " X = X[:, :-predict_length, :]\n", 204 | " assert X.shape == Y.shape\n", 205 | " return ((X[:cutoff], 
Y[:cutoff]),\n", 206 | " (X[cutoff:], Y[cutoff:]))" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": { 213 | "ExecuteTime": { 214 | "end_time": "2020-11-23T17:56:09.308282Z", 215 | "start_time": "2020-11-23T17:56:00.723134Z" 216 | } 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "(train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000)\n", 221 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n", 222 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 223 | "\n", 224 | "dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n", 225 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n", 226 | "\n", 227 | "print(train_X.shape, train_Y.shape, test_X.shape)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "## Setup for Model" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": { 241 | "ExecuteTime": { 242 | "end_time": "2021-01-31T16:13:46.359075Z", 243 | "start_time": "2021-01-31T16:13:46.346791Z" 244 | } 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "\n", 249 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n", 250 | " loss_buffer_size=800, batch_size=4, device='cuda',\n", 251 | " prog_bar=None, tau=17, pred_len=15):\n", 252 | " \n", 253 | " assert(loss_buffer_size%batch_size==0)\n", 254 | " \n", 255 | " losses = []\n", 256 | " last_test_perf = 0\n", 257 | " best_test_perf = 1000\n", 258 | " \n", 259 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 260 | " model.train()\n", 261 | " data = data.to(device).transpose(1,0)\n", 262 | " target = target.to(device)\n", 263 | " optimizer.zero_grad()\n", 264 | " out = model(data).transpose(1,0)\n", 265 | " loss = loss_func(out,\n", 266 | " target)\n", 267 | "\n", 268 | " loss.backward()\n", 269 | " optimizer.step()\n", 270 | "\n", 271 | " losses.append(loss.detach().cpu().numpy())\n", 272 | " losses = losses[int(-loss_buffer_size/batch_size):]\n", 273 | " \n", 274 | " if ((batch_idx*batch_size)%loss_buffer_size == 0):\n", 275 | " loss_track = {}\n", 276 | " last_test_perf = test_model(model, 'cuda', test_loader, \n", 277 | " )\n", 278 | " loss_track['avg_loss'] = np.mean(losses)\n", 279 | " loss_track['last_test'] = last_test_perf\n", 280 | " loss_track['epoch'] = epoch\n", 281 | " loss_track['batch_idx'] = batch_idx\n", 282 | " loss_track['tau'] = tau\n", 283 | " loss_track['pred_len'] = pred_len\n", 284 | " with open(perf_file, 'a+') as fp:\n", 285 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 286 | " if fp.tell() == 0:\n", 287 | " csv_writer.writeheader()\n", 288 | " csv_writer.writerow(loss_track)\n", 289 | " fp.flush()\n", 290 | " if best_test_perf > last_test_perf:\n", 291 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 292 | " best_test_perf = last_test_perf\n", 293 | " if not (prog_bar is None):\n", 294 | " # Update progress_bar\n", 295 | " s = \"{}:{} Loss: {:.4f},valid: {:.4f}\"\n", 296 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n", 297 | " last_test_perf] \n", 298 | " s = s.format(*format_list)\n", 299 | " prog_bar.set_description(s)\n", 300 | " \n", 301 | "def test_model(model, device, test_loader):\n", 302 | " # Test the Model\n", 303 | " nrmsd = []\n", 304 | " with 
torch.no_grad():\n", 305 | " for x, y in test_loader:\n", 306 | " data = x.to(device).transpose(1,0)\n", 307 | " target = y.to(device)\n", 308 | " optimizer.zero_grad()\n", 309 | " out = model(data).transpose(1,0)\n", 310 | " nrmsd.append(nengolib.signal.nrmse(out.detach().cpu().numpy().flatten(), \n", 311 | " target=target.detach().cpu().numpy().flatten()))\n", 312 | "\n", 313 | " perf = np.array(nrmsd).mean()\n", 314 | " return perf" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": { 321 | "ExecuteTime": { 322 | "end_time": "2021-01-31T22:27:47.613782Z", 323 | "start_time": "2021-01-31T16:15:15.395163Z" 324 | } 325 | }, 326 | "outputs": [], 327 | "source": [ 328 | "start_tau = 17\n", 329 | "start_pd = 15\n", 330 | "diffs = [1, 2, 3, 4, 5]\n", 331 | "for diff in diffs:\n", 332 | " (train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000, tau=start_tau*diff, \n", 333 | " predict_length=start_pd*diff)#tau)\n", 334 | " dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n", 335 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 336 | "\n", 337 | " dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n", 338 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n", 339 | " \n", 340 | " cornn_params = dict(n_inp=1,\n", 341 | " n_hid=128, \n", 342 | " n_out=1,\n", 343 | " dt=1.6e-2,\n", 344 | " gamma=94.5,\n", 345 | " epsilon=9.5)\n", 346 | " model = coRNN(**cornn_params).cuda()\n", 347 | " \n", 348 | " optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)\n", 349 | " loss_func = nn.MSELoss()\n", 350 | "\n", 351 | " epochs = 1000\n", 352 | " batch_size = 32\n", 353 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 354 | " last_perf = 1000\n", 355 | " for e in progress_bar:\n", 356 | " train(model, ttype, dataset, dataset_valid, \n", 357 | " optimizer, loss_func, batch_size=batch_size, loss_buffer_size=64,\n", 358 | " epoch=e, perf_file=join('perf','mackeyglass_coRNN_ratio_1.csv'),\n", 359 | " prog_bar=progress_bar, tau=start_tau*diff, pred_len=start_pd*diff)" 360 | ] 361 | } 362 | ], 363 | "metadata": { 364 | "kernelspec": { 365 | "display_name": "Python 3", 366 | "language": "python", 367 | "name": "python3" 368 | }, 369 | "language_info": { 370 | "codemirror_mode": { 371 | "name": "ipython", 372 | "version": 3 373 | }, 374 | "file_extension": ".py", 375 | "mimetype": "text/x-python", 376 | "name": "python", 377 | "nbconvert_exporter": "python", 378 | "pygments_lexer": "ipython3", 379 | "version": "3.6.10" 380 | }, 381 | "toc": { 382 | "nav_menu": {}, 383 | "number_sections": true, 384 | "sideBar": true, 385 | "skip_h1_title": false, 386 | "title_cell": "Table of Contents", 387 | "title_sidebar": "Contents", 388 | "toc_cell": false, 389 | "toc_position": {}, 390 | "toc_section_display": true, 391 | "toc_window_display": false 392 | } 393 | }, 394 | "nbformat": 4, 395 | "nbformat_minor": 4 396 | } 397 | -------------------------------------------------------------------------------- /experiments/AddingProblem/addingproblem-coRNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-02-22T16:53:09.063623Z", 9 | "start_time": 
"2021-02-22T16:53:08.886901Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-02-22T16:53:09.671516Z", 23 | "start_time": "2021-02-22T16:53:09.100488Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pylab as plt\n", 29 | "import torch\n", 30 | "import numpy as np\n", 31 | "import seaborn as sn\n", 32 | "sn.set_context(\"poster\")\n", 33 | "import os\n", 34 | "import torch\n", 35 | "import torch.nn as nn\n", 36 | "import torch.nn.functional as F\n", 37 | "from torch.nn.utils import weight_norm\n", 38 | "from torchvision import transforms, datasets\n", 39 | "\n", 40 | "from cornn import coRNN\n", 41 | "import numpy as np\n", 42 | "import scipy\n", 43 | "import scipy.stats as st\n", 44 | "import scipy.special\n", 45 | "import scipy.signal\n", 46 | "import scipy.interpolate\n", 47 | "\n", 48 | "import pandas as pd\n", 49 | "\n", 50 | "from os.path import join\n", 51 | "import random\n", 52 | "from csv import DictWriter\n", 53 | "\n", 54 | "from tqdm.notebook import tqdm\n", 55 | "import pickle\n", 56 | "# if gpu is to be used\n", 57 | "use_cuda = torch.cuda.is_available()\n", 58 | "print(use_cuda)\n", 59 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 60 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n", 61 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n", 62 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 63 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n", 64 | "ttype = FloatTensor" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "ExecuteTime": { 72 | "end_time": "2021-02-22T16:53:09.677199Z", 73 | "start_time": "2021-02-22T16:53:09.672636Z" 74 | } 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "import torch\n", 79 | "import numpy as np\n", 80 | "\n", 81 | "def get_batch(batch_size, T, ttype):\n", 82 | " values = torch.rand(T, batch_size, requires_grad=False)\n", 83 | " indices = torch.zeros_like(values)\n", 84 | " half = int(T / 2)\n", 85 | " for i in range(batch_size):\n", 86 | " half_1 = np.random.randint(half)\n", 87 | " hals_2 = np.random.randint(half, T)\n", 88 | " indices[half_1, i] = 1\n", 89 | " indices[hals_2, i] = 1\n", 90 | "\n", 91 | " data = torch.stack((values, indices), dim=-1).type(ttype)\n", 92 | " targets = torch.mul(values, indices).sum(dim=0).type(ttype)\n", 93 | " return data, targets\n" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "ExecuteTime": { 101 | "end_time": "2021-02-22T16:53:13.510567Z", 102 | "start_time": "2021-02-22T16:53:13.508399Z" 103 | } 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "torch.manual_seed(1111)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "ExecuteTime": { 115 | "end_time": "2021-02-22T16:53:13.523219Z", 116 | "start_time": "2021-02-22T16:53:13.511392Z" 117 | } 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "def train(model, ttype, seq_length, optimizer, loss_func, \n", 122 | " epoch, perf_file, loss_buffer_size=20, batch_size=1, test_size=10,\n", 123 | " device='cuda', prog_bar=None):\n", 124 | " assert(loss_buffer_size%batch_size==0)\n", 125 | "\n", 126 | " losses = []\n", 127 | " perfs = []\n", 128 | " 
last_test_perf = 0\n", 129 | " for batch_idx in range(20000):\n", 130 | " model.train()\n", 131 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n", 132 | " \n", 133 | " target = target.unsqueeze(1)\n", 134 | " optimizer.zero_grad()\n", 135 | " out = model(sig)\n", 136 | " loss = loss_func(out,\n", 137 | " target)\n", 138 | " \n", 139 | " loss.backward()\n", 140 | " optimizer.step()\n", 141 | "\n", 142 | " losses.append(loss.detach().cpu().numpy())\n", 143 | " losses = losses[-loss_buffer_size:]\n", 144 | " if not (prog_bar is None):\n", 145 | " # Update progress_bar\n", 146 | " s = \"{}:{} Loss: {:.8f}\"\n", 147 | " format_list = [e, int(batch_idx/(50/batch_size)), np.mean(losses)] \n", 148 | " s = s.format(*format_list)\n", 149 | " prog_bar.set_description(s)\n", 150 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n", 151 | " loss_track = {}\n", 152 | " #last_test_perf = test_norm(model, 'cuda', test_sig, test_class,\n", 153 | " # batch_size=test_size, \n", 154 | " # )\n", 155 | " loss_track['avg_loss'] = np.mean(losses)\n", 156 | " #loss_track['last_test'] = last_test_perf\n", 157 | " loss_track['epoch'] = epoch\n", 158 | " loss_track['batch_idx'] = batch_idx\n", 159 | " with open(perf_file, 'a+') as fp:\n", 160 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 161 | " if fp.tell() == 0:\n", 162 | " csv_writer.writeheader()\n", 163 | " csv_writer.writerow(loss_track)\n", 164 | " fp.flush()\n", 165 | "def test_norm(model, device, seq_length, loss_func, batch_size=100):\n", 166 | " model.eval()\n", 167 | " correct = 0\n", 168 | " count = 0\n", 169 | " with torch.no_grad():\n", 170 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n", 171 | " target = target.unsqueeze(1)\n", 172 | " out = model(sig)\n", 173 | " loss = loss_func(out,\n", 174 | " target)\n", 175 | " return loss" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": { 182 | "ExecuteTime": { 183 | "end_time": "2021-02-22T16:53:13.536036Z", 184 | "start_time": "2021-02-22T16:53:13.524035Z" 185 | } 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "from torch import nn\n", 190 | "import torch\n", 191 | "from torch.autograd import Variable\n", 192 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 193 | "\n", 194 | "class coRNNCell(nn.Module):\n", 195 | " def __init__(self, n_inp, n_hid, dt, gamma, epsilon):\n", 196 | " super(coRNNCell, self).__init__()\n", 197 | " self.dt = dt\n", 198 | " self.gamma = gamma\n", 199 | " self.epsilon = epsilon\n", 200 | " self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)\n", 201 | "\n", 202 | " def forward(self,x,hy,hz):\n", 203 | " hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy),1)))\n", 204 | " - self.gamma * hy - self.epsilon * hz)\n", 205 | " hy = hy + self.dt * hz\n", 206 | "\n", 207 | " return hy, hz\n", 208 | "\n", 209 | "class coRNN(nn.Module):\n", 210 | " def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):\n", 211 | " super(coRNN, self).__init__()\n", 212 | " self.n_hid = n_hid\n", 213 | " self.cell = coRNNCell(n_inp,n_hid,dt,gamma,epsilon)\n", 214 | " self.readout = nn.Linear(n_hid, n_out)\n", 215 | "\n", 216 | " def forward(self, x):\n", 217 | " ## initialize hidden states\n", 218 | " hy = Variable(torch.zeros(x.size(1),self.n_hid)).to(device)\n", 219 | " hz = Variable(torch.zeros(x.size(1),self.n_hid)).to(device)\n", 220 | "\n", 221 | " for t in range(x.size(0)):\n", 222 | " hy, hz = 
self.cell(x[t],hy,hz)\n", 223 | " output = self.readout(hy)\n", 224 | "\n", 225 | " return output" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "# T = 100" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": { 239 | "ExecuteTime": { 240 | "end_time": "2021-02-22T16:53:16.576467Z", 241 | "start_time": "2021-02-22T16:53:16.564088Z" 242 | } 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "cornn_params = dict(n_inp=2,\n", 247 | " n_hid=128, \n", 248 | " n_out=1,\n", 249 | " dt=1.6e-2,\n", 250 | " gamma=94.5,\n", 251 | " epsilon=9.5)\n", 252 | "model = coRNN(**cornn_params).cuda()\n", 253 | "\n", 254 | "tot_weights = 0\n", 255 | "for p in model.parameters():\n", 256 | " tot_weights += p.numel()\n", 257 | "print(\"Total Weights:\", tot_weights)\n", 258 | "print(model)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": { 265 | "ExecuteTime": { 266 | "start_time": "2021-02-22T16:53:19.447Z" 267 | } 268 | }, 269 | "outputs": [], 270 | "source": [ 271 | "seq_length=100\n", 272 | "\n", 273 | "loss_func = nn.MSELoss()\n", 274 | "optimizer = torch.optim.AdamW(model.parameters())\n", 275 | "epochs = 1\n", 276 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 277 | "for e in progress_bar:\n", 278 | " train(model, ttype, seq_length,\n", 279 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 280 | " epoch=e, perf_file=join('perf','adding100_cornn_big.csv'),\n", 281 | " prog_bar=progress_bar)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "# T = 500" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": { 295 | "ExecuteTime": { 296 | "end_time": "2020-12-15T17:16:54.804717Z", 297 | "start_time": "2020-12-15T17:16:54.799148Z" 298 | } 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "cornn_params = dict(n_inp=2,\n", 303 | " n_hid=128, \n", 304 | " n_out=1,\n", 305 | " dt=6e-2,\n", 306 | " gamma=66,\n", 307 | " epsilon=15)\n", 308 | "model = coRNN(**cornn_params).cuda()\n", 309 | "\n", 310 | "tot_weights = 0\n", 311 | "for p in model.parameters():\n", 312 | " tot_weights += p.numel()\n", 313 | "print(\"Total Weights:\", tot_weights)\n", 314 | "print(model)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": { 321 | "ExecuteTime": { 322 | "start_time": "2020-12-15T17:17:01.879Z" 323 | } 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "seq_length=500\n", 328 | "\n", 329 | "loss_func = nn.MSELoss()\n", 330 | "optimizer = torch.optim.Adam(model.parameters(), lr=3e-2)\n", 331 | "epochs = 1\n", 332 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 333 | "for e in progress_bar:\n", 334 | " train(model, ttype, seq_length,\n", 335 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 336 | " epoch=e, perf_file=join('perf','adding500_cornn_6.csv'),\n", 337 | " prog_bar=progress_bar)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "metadata": {}, 343 | "source": [ 344 | "# T = 2000" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": { 351 | "ExecuteTime": { 352 | "end_time": "2020-12-02T05:15:54.017640Z", 353 | "start_time": "2020-12-02T05:15:54.014072Z" 354 | } 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | 
"cornn_params = dict(n_inp=2,\n", 359 | " n_hid=128, \n", 360 | " n_out=1,\n", 361 | " dt=3e-2,\n", 362 | " gamma=80,\n", 363 | " epsilon=12)\n", 364 | "model = coRNN(**cornn_params).cuda()\n", 365 | "\n", 366 | "tot_weights = 0\n", 367 | "for p in model.parameters():\n", 368 | " tot_weights += p.numel()\n", 369 | "print(\"Total Weights:\", tot_weights)\n", 370 | "print(model)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "metadata": { 377 | "ExecuteTime": { 378 | "end_time": "2020-12-02T09:28:58.690751Z", 379 | "start_time": "2020-12-02T05:15:54.018817Z" 380 | } 381 | }, 382 | "outputs": [], 383 | "source": [ 384 | "seq_length=2000\n", 385 | "\n", 386 | "loss_func = nn.MSELoss()\n", 387 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)\n", 388 | "epochs = 1\n", 389 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 390 | "for e in progress_bar:\n", 391 | " train(model, ttype, seq_length,\n", 392 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 393 | " epoch=e, perf_file=join('perf','adding2000_cornn_1.csv'),\n", 394 | " prog_bar=progress_bar)" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "# T = 5000" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": { 408 | "ExecuteTime": { 409 | "end_time": "2020-12-01T18:39:48.793500Z", 410 | "start_time": "2020-12-01T18:39:48.775403Z" 411 | } 412 | }, 413 | "outputs": [], 414 | "source": [ 415 | "cornn_params = dict(n_inp=2,\n", 416 | " n_hid=128, \n", 417 | " n_out=1,\n", 418 | " dt=1.6e-2,\n", 419 | " gamma=94.5,\n", 420 | " epsilon=9.5)\n", 421 | "model = coRNN(**cornn_params).cuda()\n", 422 | "\n", 423 | "tot_weights = 0\n", 424 | "for p in model.parameters():\n", 425 | " tot_weights += p.numel()\n", 426 | "print(\"Total Weights:\", tot_weights)\n", 427 | "print(model)" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": null, 433 | "metadata": { 434 | "ExecuteTime": { 435 | "end_time": "2020-12-02T05:15:54.012718Z", 436 | "start_time": "2020-12-01T18:39:52.788917Z" 437 | } 438 | }, 439 | "outputs": [], 440 | "source": [ 441 | "seq_length=5000\n", 442 | "\n", 443 | "loss_func = nn.MSELoss()\n", 444 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)\n", 445 | "epochs = 1\n", 446 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 447 | "for e in progress_bar:\n", 448 | " train(model, ttype, seq_length,\n", 449 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n", 450 | " epoch=e, perf_file=join('perf','adding5000_cornn_1.csv'),\n", 451 | " prog_bar=progress_bar)" 452 | ] 453 | } 454 | ], 455 | "metadata": { 456 | "kernelspec": { 457 | "display_name": "Python 3", 458 | "language": "python", 459 | "name": "python3" 460 | }, 461 | "language_info": { 462 | "codemirror_mode": { 463 | "name": "ipython", 464 | "version": 3 465 | }, 466 | "file_extension": ".py", 467 | "mimetype": "text/x-python", 468 | "name": "python", 469 | "nbconvert_exporter": "python", 470 | "pygments_lexer": "ipython3", 471 | "version": "3.6.10" 472 | }, 473 | "toc": { 474 | "nav_menu": { 475 | "height": "141px", 476 | "width": "160px" 477 | }, 478 | "number_sections": true, 479 | "sideBar": true, 480 | "skip_h1_title": false, 481 | "title_cell": "Table of Contents", 482 | "title_sidebar": "Contents", 483 | "toc_cell": false, 484 | "toc_position": {}, 485 | "toc_section_display": 
true, 486 | "toc_window_display": false 487 | } 488 | }, 489 | "nbformat": 4, 490 | "nbformat_minor": 4 491 | } 492 | -------------------------------------------------------------------------------- /experiments/psMNIST/psMNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2020-12-03T16:57:39.184101Z", 9 | "start_time": "2020-12-03T16:57:38.997252Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2020-12-05T18:29:47.565802Z", 23 | "start_time": "2020-12-05T18:29:47.560378Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pyplot as plt\n", 29 | "from torch.optim.lr_scheduler import StepLR\n", 30 | "import torch\n", 31 | "from sklearn.metrics import confusion_matrix\n", 32 | "import pandas as pd\n", 33 | "import numpy as np\n", 34 | "\n", 35 | "import torch.nn as nn\n", 36 | "\n", 37 | "from tqdm import tqdm_notebook\n", 38 | "\n", 39 | "from torchvision import transforms\n", 40 | "from torchvision import datasets\n", 41 | "from os.path import join\n", 42 | "\n", 43 | "from deepsith import DeepSITH\n", 44 | "\n", 45 | "from tqdm.notebook import tqdm\n", 46 | "\n", 47 | "import random\n", 48 | "\n", 49 | "from csv import DictWriter\n", 50 | "# if gpu is to be used\n", 51 | "use_cuda = torch.cuda.is_available()\n", 52 | "\n", 53 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n", 54 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n", 55 | "\n", 56 | "ttype =FloatTensor\n", 57 | "\n", 58 | "import seaborn as sn\n", 59 | "print(use_cuda)\n", 60 | "import pickle\n", 61 | "\n", 62 | "sn.set_context(\"poster\")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "ExecuteTime": { 70 | "end_time": "2020-12-03T16:57:40.061632Z", 71 | "start_time": "2020-12-03T16:57:40.058700Z" 72 | } 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "class DeepSITH_Classifier(nn.Module):\n", 77 | " def __init__(self, out_features, layer_params, dropout=.1):\n", 78 | " super(DeepSITH_Classifier, self).__init__()\n", 79 | " last_hidden = layer_params[-1]['hidden_size']\n", 80 | " self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)\n", 81 | " self.to_out = nn.Linear(last_hidden, out_features)\n", 82 | " def forward(self, inp):\n", 83 | " x = self.hs(inp)\n", 84 | " x = self.to_out(x)\n", 85 | " return x" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "# Load Stimuli" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": { 99 | "ExecuteTime": { 100 | "end_time": "2020-12-03T16:57:41.468289Z", 101 | "start_time": "2020-12-03T16:57:41.465864Z" 102 | } 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "# Same seed and supposed Permutation as the coRNN paper\n", 107 | "torch.manual_seed(12008)\n", 108 | "permute = torch.randperm(784)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": { 115 | "ExecuteTime": { 116 | "end_time": "2020-12-03T16:57:42.535384Z", 117 | "start_time": "2020-12-03T16:57:42.527542Z" 118 | } 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "print(permute)" 123 | ] 124 | }, 125 | 
{ 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": { 129 | "ExecuteTime": { 130 | "end_time": "2020-12-03T16:57:45.892964Z", 131 | "start_time": "2020-12-03T16:57:45.890797Z" 132 | } 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "norm = transforms.Normalize((.1307,), (.3081,), )" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": { 143 | "ExecuteTime": { 144 | "end_time": "2020-12-03T16:57:46.269650Z", 145 | "start_time": "2020-12-03T16:57:46.255928Z" 146 | } 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "batch_size = 64\n", 151 | "transform = transforms.Compose([transforms.ToTensor(),\n", 152 | " transforms.Normalize((.1307,), (.3081,))\n", 153 | " ])\n", 154 | "ds1 = datasets.MNIST('../data', train=True, download=True, transform=transform)\n", 155 | "ds2 = datasets.MNIST('../data', train=False, download=True, transform=transform)\n", 156 | "train_loader=torch.utils.data.DataLoader(ds1,batch_size=batch_size, \n", 157 | " num_workers=1, pin_memory=True, shuffle=True)\n", 158 | "test_loader=torch.utils.data.DataLoader(ds2, batch_size=batch_size, \n", 159 | " num_workers=1, pin_memory=True, shuffle=True)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "ExecuteTime": { 167 | "end_time": "2020-12-03T16:57:51.113709Z", 168 | "start_time": "2020-12-03T16:57:47.283835Z" 169 | } 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "testi = next(iter(test_loader))[0]\n", 174 | "\n", 175 | "plt.imshow(testi[0].reshape(-1)[permute].reshape(28,28))\n", 176 | "\n", 177 | "plt.colorbar()" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "ExecuteTime": { 185 | "end_time": "2020-12-03T16:57:51.164838Z", 186 | "start_time": "2020-12-03T16:57:51.162192Z" 187 | } 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "testi.shape" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "# Define test and train" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": { 205 | "ExecuteTime": { 206 | "end_time": "2020-12-04T15:09:47.998604Z", 207 | "start_time": "2020-12-04T15:09:47.988228Z" 208 | } 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "\n", 213 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n", 214 | " permute=None, loss_buffer_size=800, batch_size=4, device='cuda',\n", 215 | " prog_bar=None):\n", 216 | " \n", 217 | " assert(loss_buffer_size%batch_size==0)\n", 218 | " if permute is None:\n", 219 | " permute = torch.LongTensor(list(range(784)))\n", 220 | " \n", 221 | " losses = []\n", 222 | " perfs = []\n", 223 | " last_test_perf = 0\n", 224 | " best_test_perf = -1\n", 225 | " \n", 226 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 227 | " model.train()\n", 228 | " data = data.to(device).view(data.shape[0],1,1,-1)\n", 229 | " target = target.to(device)\n", 230 | " optimizer.zero_grad()\n", 231 | " out = model(data[:, :, :, permute])\n", 232 | " loss = loss_func(out[:, -1, :],\n", 233 | " target)\n", 234 | "\n", 235 | " loss.backward()\n", 236 | " optimizer.step()\n", 237 | "\n", 238 | " perfs.append((torch.argmax(out[:, -1, :], dim=-1) == \n", 239 | " target).sum().item())\n", 240 | " perfs = perfs[int(-loss_buffer_size/batch_size):]\n", 241 | " losses.append(loss.detach().cpu().numpy())\n", 242 | " losses = 
losses[int(-loss_buffer_size/batch_size):]\n", 243 | " if not (prog_bar is None):\n", 244 | " # Update progress_bar\n", 245 | " s = \"{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}\"\n", 246 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n", 247 | " np.sum(perfs)/((len(perfs))*batch_size), last_test_perf] \n", 248 | " s = s.format(*format_list)\n", 249 | " prog_bar.set_description(s)\n", 250 | " \n", 251 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n", 252 | " loss_track = {}\n", 253 | " # last_test_perf = test(model, 'cuda', test_loader, \n", 254 | " # batch_size=batch_size, \n", 255 | " # permute=permute)\n", 256 | " loss_track['avg_loss'] = np.mean(losses)\n", 257 | " #loss_track['last_test'] = last_test_perf\n", 258 | " loss_track['epoch'] = epoch\n", 259 | " loss_track['batch_idx'] = batch_idx\n", 260 | " loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)\n", 261 | " with open(perf_file, 'a+') as fp:\n", 262 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 263 | " if fp.tell() == 0:\n", 264 | " csv_writer.writeheader()\n", 265 | " csv_writer.writerow(loss_track)\n", 266 | " fp.flush()\n", 267 | " #if best_test_perf < last_test_perf:\n", 268 | " # torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 269 | " # best_test_perf = last_test_perf\n", 270 | "\n", 271 | " \n", 272 | "def test(model, device, test_loader, batch_size=4, permute=None):\n", 273 | " if permute is None:\n", 274 | " permute = torch.LongTensor(list(range(784)))\n", 275 | " \n", 276 | " model.eval()\n", 277 | " correct = 0\n", 278 | " count = 0\n", 279 | " with torch.no_grad():\n", 280 | " for data, target in test_loader:\n", 281 | " data = data.to(device).view(data.shape[0],1,1,-1)\n", 282 | " target = target.to(device)\n", 283 | " \n", 284 | " out = model(data[:,:,:, permute])\n", 285 | " pred = out[:, -1].argmax(dim=-1, keepdim=True)\n", 286 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 287 | " count += 1\n", 288 | " return correct / len(test_loader.dataset)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "# Setup the model" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": { 302 | "ExecuteTime": { 303 | "end_time": "2020-12-04T15:09:48.936759Z", 304 | "start_time": "2020-12-04T15:09:48.926780Z" 305 | }, 306 | "scrolled": true 307 | }, 308 | "outputs": [], 309 | "source": [ 310 | "g = 0.0\n", 311 | "sith_params1 = {\"in_features\":1, \n", 312 | " \"tau_min\":1, \"tau_max\":30.0, \"buff_max\":50,\n", 313 | " \"k\":125, 'dt':1,\n", 314 | " \"ntau\":20, 'g':g, \n", 315 | " \"ttype\":ttype, \"batch_norm\":True,\n", 316 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n", 317 | " }\n", 318 | "sith_params2 = {\"in_features\":sith_params1['hidden_size'], \n", 319 | " \"tau_min\":1, \"tau_max\":150.0, \"buff_max\":250,\n", 320 | " \"k\":61, 'dt':1,\n", 321 | " \"ntau\":20, 'g':g, \n", 322 | " \"ttype\":ttype, \"batch_norm\":True,\n", 323 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n", 324 | " }\n", 325 | "sith_params3 = {\"in_features\":sith_params2['hidden_size'], \n", 326 | " \"tau_min\":1, \"tau_max\":750.0, \"buff_max\":1500,\n", 327 | " \"k\":35, 'dt':1,\n", 328 | " \"ntau\":20, 'g':g, \n", 329 | " \"ttype\":ttype, \"batch_norm\":True,\n", 330 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n", 331 | " }\n", 332 | "\n", 333 | "\n", 334 | "\n", 335 | "layer_params = [sith_params1, sith_params2, sith_params3]\n", 
336 | "\n", 337 | "\n", 338 | "model = DeepSITH_Classifier(10,\n", 339 | " layer_params=layer_params, \n", 340 | " dropout=0.2).cuda()\n", 341 | "\n", 342 | "tot_weights = 0\n", 343 | "for p in model.parameters():\n", 344 | " tot_weights += p.numel()\n", 345 | "print(\"Total Weights:\", tot_weights)\n", 346 | "print(model)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": { 353 | "ExecuteTime": { 354 | "end_time": "2020-12-05T00:17:49.489853Z", 355 | "start_time": "2020-12-04T15:09:54.978362Z" 356 | }, 357 | "scrolled": true 358 | }, 359 | "outputs": [], 360 | "source": [ 361 | "epochs = 90\n", 362 | "loss_func = nn.CrossEntropyLoss()\n", 363 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)\n", 364 | "sched = StepLR(optimizer, step_size=30, gamma=0.1)\n", 365 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 366 | "test_perf = []\n", 367 | "for e in progress_bar:\n", 368 | " train(model, ttype, train_loader, test_loader, optimizer, loss_func, batch_size=batch_size,\n", 369 | " epoch=e, perf_file=join('perf','pmnist_deepsith_78.csv'),loss_buffer_size=64*32, \n", 370 | " prog_bar=progress_bar, permute=permute)\n", 371 | " last_test_perf = test(model, 'cuda', test_loader, \n", 372 | " batch_size=batch_size, \n", 373 | " permute=permute)\n", 374 | " test_perf.append({\"epoch\":e,\n", 375 | " 'test':last_test_perf})\n", 376 | " sched.step()" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": { 382 | "ExecuteTime": { 383 | "end_time": "2020-11-02T01:31:22.277851Z", 384 | "start_time": "2020-11-02T01:31:22.275272Z" 385 | } 386 | }, 387 | "source": [ 388 | "# Find Errors" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": { 395 | "ExecuteTime": { 396 | "end_time": "2020-12-05T00:26:42.879054Z", 397 | "start_time": "2020-12-05T00:26:29.019237Z" 398 | } 399 | }, 400 | "outputs": [], 401 | "source": [ 402 | "test(model, 'cuda', test_loader, \n", 403 | " batch_size=batch_size, \n", 404 | " permute=permute)" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "metadata": { 411 | "ExecuteTime": { 412 | "end_time": "2020-11-10T14:51:23.948941Z", 413 | "start_time": "2020-11-10T14:51:23.944246Z" 414 | } 415 | }, 416 | "outputs": [], 417 | "source": [ 418 | "def conf_mat_gen(model, device, test_loader, batch_size=4, permute=None):\n", 419 | " if permute is None:\n", 420 | " permute = torch.LongTensor(list(range(784)))\n", 421 | " evals = {'pred':[],\n", 422 | " 'actual':[]}\n", 423 | " model.eval()\n", 424 | " correct = 0\n", 425 | " count = 0\n", 426 | " with torch.no_grad():\n", 427 | " for data, target in test_loader:\n", 428 | " data = data.to(device).view(data.shape[0],1,1,-1)\n", 429 | " target = target.to(device)\n", 430 | " for x in target:\n", 431 | " evals['actual'].append(x.detach().cpu().numpy())\n", 432 | " out = model(data[:,:,:, permute])\n", 433 | " for x in out[:, -1].argmax(dim=-1, keepdim=True):\n", 434 | " evals['pred'].append(x.detach().cpu().numpy())\n", 435 | " return evals" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "metadata": { 442 | "ExecuteTime": { 443 | "end_time": "2020-11-10T14:51:36.918179Z", 444 | "start_time": "2020-11-10T14:51:25.179622Z" 445 | } 446 | }, 447 | "outputs": [], 448 | "source": [ 449 | "evals = conf_mat_gen(model, 'cuda', test_loader, batch_size=4, permute=permute)" 450 | ] 451 | }, 452 | { 453 | 
"cell_type": "code", 454 | "execution_count": null, 455 | "metadata": { 456 | "ExecuteTime": { 457 | "end_time": "2020-11-10T14:53:19.261434Z", 458 | "start_time": "2020-11-10T14:53:18.926905Z" 459 | } 460 | }, 461 | "outputs": [], 462 | "source": [ 463 | "fig = plt.figure(figsize=(10,11))\n", 464 | "plt.imshow(confusion_matrix(np.array(evals['pred'])[:, 0], \n", 465 | " np.array(evals['actual'])), cmap='coolwarm')\n", 466 | "plt.colorbar()\n", 467 | "plt.xticks(list(range(10)));\n", 468 | "plt.yticks(list(range(10)));\n", 469 | "plt.savefig(join('figs', 'sMNIST_LoLa_conf'), dpi=200, bboxinches='tight')" 470 | ] 471 | } 472 | ], 473 | "metadata": { 474 | "kernelspec": { 475 | "display_name": "Python 3", 476 | "language": "python", 477 | "name": "python3" 478 | }, 479 | "language_info": { 480 | "codemirror_mode": { 481 | "name": "ipython", 482 | "version": 3 483 | }, 484 | "file_extension": ".py", 485 | "mimetype": "text/x-python", 486 | "name": "python", 487 | "nbconvert_exporter": "python", 488 | "pygments_lexer": "ipython3", 489 | "version": "3.6.10" 490 | }, 491 | "toc": { 492 | "nav_menu": {}, 493 | "number_sections": true, 494 | "sideBar": true, 495 | "skip_h1_title": false, 496 | "title_cell": "Table of Contents", 497 | "title_sidebar": "Contents", 498 | "toc_cell": false, 499 | "toc_position": {}, 500 | "toc_section_display": true, 501 | "toc_window_display": false 502 | } 503 | }, 504 | "nbformat": 4, 505 | "nbformat_minor": 4 506 | } 507 | -------------------------------------------------------------------------------- /experiments/Hateful8/hateful8-LSTM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2020-12-18T15:21:53.650363Z", 9 | "start_time": "2020-12-18T15:21:53.466128Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2020-12-18T15:21:54.221984Z", 23 | "start_time": "2020-12-18T15:21:53.678357Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pyplot as plt\n", 29 | "import torch\n", 30 | "from torch import nn as nn\n", 31 | "from math import factorial\n", 32 | "import random\n", 33 | "import torch.nn.functional as F\n", 34 | "import numpy as np\n", 35 | "import seaborn as sn\n", 36 | "import pandas as pd\n", 37 | "import os \n", 38 | "from os.path import join\n", 39 | "import glob\n", 40 | "from math import factorial\n", 41 | "ttype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n", 42 | "ctype = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor\n", 43 | "print(ttype)\n", 44 | "from torch.nn.utils import weight_norm\n", 45 | "\n", 46 | "from tqdm.notebook import tqdm\n", 47 | "import pickle\n", 48 | "sn.set_context(\"poster\")\n", 49 | "import itertools\n", 50 | "from csv import DictWriter\n", 51 | "import matplotlib.pylab as plt\n", 52 | "import csv\n", 53 | "import numpy as np\n", 54 | "import pandas as pd\n", 55 | "import os\n", 56 | "import seaborn as sn\n", 57 | "sn.set_context('talk')" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "ExecuteTime": { 65 | "end_time": "2020-12-18T15:24:24.946369Z", 66 | "start_time": "2020-12-18T15:24:24.822539Z" 67 | } 68 | }, 69 | "outputs": [], 70 | "source": [ 71 
| "def generate_noise(maxn=18):\n", 72 | " \"\"\"Generates dot and dash based noise.\"\"\"\n", 73 | " \n", 74 | " threes = np.random.randint(int(.5*maxn), int(.75*maxn))\n", 75 | " ones = (maxn - threes) * 2\n", 76 | " noise = list(itertools.repeat([1,1,1,0], threes))\n", 77 | " noise[:int(len(noise)/3)] = list(itertools.repeat([0,0], int(len(noise)/3)))\n", 78 | " ones = ones + int(len(noise)/3)\n", 79 | " noise.extend(list(itertools.repeat([1,0], ones)))\n", 80 | " random.shuffle(noise)\n", 81 | " noise = np.concatenate(noise)\n", 82 | " return noise\n", 83 | "noise = generate_noise()\n", 84 | "print(noise.shape)\n", 85 | "plt.plot(noise)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "ExecuteTime": { 93 | "end_time": "2020-12-18T15:24:32.691623Z", 94 | "start_time": "2020-12-18T15:24:30.795867Z" 95 | } 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "sig_lets = [\"A\",\"B\",\"C\",\"D\",\"E\",\"F\",\"G\",\"H\",]\n", 100 | "\n", 101 | "signals = ttype([[0,1,1,1,0,1,1,1,0,1,0,1,0,1,0,0,0],\n", 102 | " [0,1,1,1,0,1,0,1,1,1,0,1,0,1,0,0,0],\n", 103 | " [0,1,1,1,0,1,0,1,0,1,1,1,0,1,0,0,0],\n", 104 | " [0,1,1,1,0,1,0,1,0,1,0,1,1,1,0,0,0],\n", 105 | " [0,1,0,1,1,1,0,1,1,1,0,1,1,1,0,0,0],\n", 106 | " [0,1,1,1,0,1,0,1,1,1,0,1,1,1,0,0,0],\n", 107 | " [0,1,1,1,0,1,1,1,0,1,0,1,1,1,0,0,0],\n", 108 | " [0,1,1,1,0,1,1,1,0,1,1,1,0,1,0,0,0]]\n", 109 | " ).view(8, 1, 1, -1)\n", 110 | "#signals = ms\n", 111 | "key2id = {k:i for i, k in enumerate(sig_lets)}\n", 112 | "\n", 113 | "print(key2id)\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": { 120 | "ExecuteTime": { 121 | "end_time": "2020-12-11T18:54:00.513324Z", 122 | "start_time": "2020-12-11T18:54:00.368662Z" 123 | } 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "torch.manual_seed(12345)\n", 128 | "np.random.seed(12345)\n", 129 | "training_samples = 32\n", 130 | "\n", 131 | "training_signals = []\n", 132 | "training_class = []\n", 133 | "\n", 134 | "for i, sig in enumerate(signals):\n", 135 | " temp_signals = []\n", 136 | " temp_class = []\n", 137 | " for x in range(training_samples):\n", 138 | " noise = ttype(generate_noise())\n", 139 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 140 | " while(any([(temp == c_).all() for c_ in temp_signals])):\n", 141 | " noise = ttype(generate_noise())\n", 142 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 143 | " temp_signals.append(temp)\n", 144 | " temp_class.append(i)\n", 145 | " training_signals.extend(temp_signals)\n", 146 | " training_class.extend(temp_class)\n", 147 | "\n", 148 | "batch_rand = torch.randperm(training_samples*signals.shape[0]) \n", 149 | "training_signals = torch.cat(training_signals).cuda().unsqueeze(-1)[batch_rand]\n", 150 | "training_class = ctype(training_class).cuda().unsqueeze(-1)[batch_rand]\n", 151 | "\n", 152 | "dataset = torch.utils.data.TensorDataset(training_signals, training_class)\n", 153 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 154 | "\n" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "ExecuteTime": { 162 | "end_time": "2020-12-11T18:54:02.761255Z", 163 | "start_time": "2020-12-11T18:54:02.121822Z" 164 | } 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "testing_samples = 10\n", 169 | "testing_signals = []\n", 170 | "testing_class = []\n", 171 | "\n", 172 | "for i, sig in enumerate(signals):\n", 173 | " temp_signals = []\n", 174 | " 
temp_class = []\n", 175 | " for x in range(testing_samples):\n", 176 | " noise = ttype(generate_noise())\n", 177 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 178 | " while(any([(temp == c_).all() for c_ in temp_signals]) or any([(temp == c_.view(1, -1)).all() for c_ in training_signals])): # stored rows are (L, 1); flatten to match temp's (1, L)\n", 179 | " noise = ttype(generate_noise())\n", 180 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 181 | " temp_signals.append(temp)\n", 182 | " temp_class.append(i)\n", 183 | " testing_signals.extend(temp_signals)\n", 184 | " testing_class.extend(temp_class)\n", 185 | "batch_rand = torch.randperm(testing_samples*signals.shape[0])\n", 186 | "\n", 187 | "testing_signals = torch.cat(testing_signals).cuda().unsqueeze(-1)[batch_rand]\n", 188 | "testing_class = ctype(testing_class).cuda().unsqueeze(-1)[batch_rand]\n", 189 | "\n", 190 | "\n", 191 | "dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)\n", 192 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)\n" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": { 199 | "ExecuteTime": { 200 | "end_time": "2020-12-18T15:24:33.846618Z", 201 | "start_time": "2020-12-18T15:24:33.842381Z" 202 | } 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "class LSTM_Predictor(nn.Module):\n", 207 | " def __init__(self, out_features, lstm_params):\n", 208 | " super(LSTM_Predictor, self).__init__()\n", 209 | " self.lstm = nn.LSTM(**lstm_params)\n", 210 | " self.to_out = nn.Linear(lstm_params['hidden_size'], \n", 211 | " out_features)\n", 212 | " def forward(self, inp):\n", 213 | " x = self.lstm(inp)[0].transpose(1,0) # (seq, batch, hidden) -> (batch, seq, hidden)\n", 214 | " x = torch.tanh(self.to_out(x))\n", 215 | " return x" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": { 222 | "ExecuteTime": { 223 | "end_time": "2020-12-18T15:24:48.479338Z", 224 | "start_time": "2020-12-18T15:24:48.463241Z" 225 | } 226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n", 230 | " permute=None, loss_buffer_size=64, batch_size=4, device='cuda',\n", 231 | " prog_bar=None, maxn=6):\n", 232 | " \n", 233 | " assert(loss_buffer_size%batch_size==0)\n", 234 | " \n", 235 | " losses = []\n", 236 | " perfs = []\n", 237 | " last_test_perf = 0\n", 238 | " best_test_perf = -1\n", 239 | " \n", 240 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 241 | " model.train()\n", 242 | " data = data.to(device).transpose(1,0) # nn.LSTM expects seq-first input\n", 243 | " target = target.to(device)\n", 244 | " optimizer.zero_grad()\n", 245 | " out = model(data)\n", 246 | " loss = loss_func(out[:, -1, :],\n", 247 | " target[:, 0])\n", 248 | " \n", 249 | " loss.backward()\n", 250 | " optimizer.step()\n", 251 | "\n", 252 | " perfs.append((torch.argmax(out[:, -1, :], dim=-1) == \n", 253 | " target[:, 0]).sum().item())\n", 254 | " perfs = perfs[int(-loss_buffer_size/batch_size):]\n", 255 | " losses.append(loss.detach().cpu().numpy())\n", 256 | " losses = losses[int(-loss_buffer_size/batch_size):]\n", 257 | " if not (prog_bar is None):\n", 258 | " # Update progress_bar\n", 259 | " s = \"{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}\"\n", 260 | " format_list = [epoch, batch_idx*batch_size, np.mean(losses), \n", 261 | " np.sum(perfs)/((len(perfs))*batch_size), last_test_perf] \n", 262 | " s = s.format(*format_list)\n", 263 | " prog_bar.set_description(s)\n", 264 | " \n", 265 | " if 
((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n", 266 | " loss_track = {}\n", 267 | " last_test_perf = test(model, 'cuda', test_loader, \n", 268 | " batch_size=batch_size, \n", 269 | " permute=permute)\n", 270 | " loss_track['avg_loss'] = np.mean(losses)\n", 271 | " loss_track['last_test'] = last_test_perf\n", 272 | " loss_track['epoch'] = epoch\n", 273 | " loss_track['maxn'] = maxn\n", 274 | " loss_track['batch_idx'] = batch_idx\n", 275 | " loss_track['pres_num'] = batch_idx*batch_size + epoch*len(train_loader.dataset)\n", 276 | " loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)\n", 277 | " with open(perf_file, 'a+') as fp:\n", 278 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 279 | " if fp.tell() == 0:\n", 280 | " csv_writer.writeheader()\n", 281 | " csv_writer.writerow(loss_track)\n", 282 | " fp.flush()\n", 283 | " #if best_test_perf < last_test_perf:\n", 284 | " # torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 285 | " # best_test_perf = last_test_perf\n", 286 | "\n", 287 | " \n", 288 | "def test(model, device, test_loader, batch_size=4, permute=None):\n", 289 | " model.eval()\n", 290 | " correct = 0\n", 291 | " count = 0\n", 292 | " with torch.no_grad():\n", 293 | " for data, target in test_loader:\n", 294 | " data = data.to(device).transpose(1,0)\n", 295 | " target = target.to(device)\n", 296 | " \n", 297 | " out = model(data)\n", 298 | " pred = out[:, -1, :].argmax(dim=-1, keepdim=True)\n", 299 | " \n", 300 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 301 | " count += 1\n", 302 | " return correct / len(test_loader.dataset)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "# Training and testing" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": { 323 | "ExecuteTime": { 324 | "end_time": "2020-12-11T18:58:50.330501Z", 325 | "start_time": "2020-12-11T18:58:50.327920Z" 326 | } 327 | }, 328 | "outputs": [], 329 | "source": [ 330 | "# You likely don't need this to be this long, but just in case.\n", 331 | "epochs = 1000\n", 332 | "\n", 333 | "# Just for visualizing average loss through time. 
\n", 334 | "loss_buffer_size = 100" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "metadata": { 341 | "ExecuteTime": { 342 | "end_time": "2020-12-18T16:14:26.814579Z", 343 | "start_time": "2020-12-18T16:13:00.732980Z" 344 | }, 345 | "scrolled": true 346 | }, 347 | "outputs": [], 348 | "source": [ 349 | "test_noise_lengths = [6,7,9,13,21,37]\n", 350 | "for maxn in test_noise_lengths:\n", 351 | " torch.manual_seed(12345)\n", 352 | " np.random.seed(12345)\n", 353 | " training_samples = 32\n", 354 | "\n", 355 | " training_signals = []\n", 356 | " training_class = []\n", 357 | "\n", 358 | " for i, sig in enumerate(signals):\n", 359 | " temp_signals = []\n", 360 | " temp_class = []\n", 361 | " for x in range(training_samples):\n", 362 | " noise = ttype(generate_noise(maxn))\n", 363 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 364 | " while(any([(temp == c_).all() for c_ in temp_signals])):\n", 365 | " noise = ttype(generate_noise(maxn))\n", 366 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 367 | " temp_signals.append(temp)\n", 368 | " temp_class.append(i)\n", 369 | " training_signals.extend(temp_signals)\n", 370 | " training_class.extend(temp_class)\n", 371 | "\n", 372 | " batch_rand = torch.randperm(training_samples*signals.shape[0]) \n", 373 | " training_signals = torch.cat(training_signals).cuda().unsqueeze(-1)[batch_rand]\n", 374 | " training_class = ctype(training_class).cuda().unsqueeze(-1)[batch_rand]\n", 375 | "\n", 376 | " dataset = torch.utils.data.TensorDataset(training_signals, training_class)\n", 377 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 378 | " testing_samples = 10\n", 379 | " testing_signals = []\n", 380 | " testing_class = []\n", 381 | "\n", 382 | " for i, sig in enumerate(signals):\n", 383 | " temp_signals = []\n", 384 | " temp_class = []\n", 385 | " for x in range(testing_samples):\n", 386 | " noise = ttype(generate_noise(maxn))\n", 387 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 388 | " while(any([(temp == c_).all() for c_ in temp_signals]) or any([(temp == c_.view(1, -1)).all() for c_ in training_signals])): # stored rows are (L, 1); flatten to match temp's (1, L)\n", 389 | " noise = ttype(generate_noise(maxn))\n", 390 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 391 | " temp_signals.append(temp)\n", 392 | " temp_class.append(i)\n", 393 | " testing_signals.extend(temp_signals)\n", 394 | " testing_class.extend(temp_class)\n", 395 | " batch_rand = torch.randperm(testing_samples*signals.shape[0])\n", 396 | "\n", 397 | " testing_signals = torch.cat(testing_signals).cuda().unsqueeze(-1)[batch_rand]\n", 398 | " testing_class = ctype(testing_class).cuda().unsqueeze(-1)[batch_rand]\n", 399 | "\n", 400 | "\n", 401 | " dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)\n", 402 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)\n", 403 | " lstm_params = dict(input_size=1,\n", 404 | " hidden_size=38, \n", 405 | " num_layers=3)\n", 406 | " model = LSTM_Predictor(8, lstm_params=lstm_params).cuda()\n", 407 | "\n", 408 | " tot_weights = 0\n", 409 | " for p in model.parameters():\n", 410 | " tot_weights += p.numel()\n", 411 | " print(\"Total Weights:\", tot_weights)\n", 412 | " print(model)\n", 413 | " loss_func = torch.nn.CrossEntropyLoss()\n", 414 | " optimizer = torch.optim.Adam(model.parameters())\n", 415 | " epochs = 400 # overrides the epochs set earlier in the notebook\n", 416 | " batch_size = 32\n", 417 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 418 
| " for e in progress_bar:\n", 419 | " train(model, ttype, dataset, dataset_valid, \n", 420 | " optimizer, loss_func, batch_size=batch_size,\n", 421 | " epoch=e, perf_file=join('perf','h8_lstm_length_0.csv'),\n", 422 | " prog_bar=progress_bar, maxn=maxn)" 423 | ] 424 | } 425 | ], 426 | "metadata": { 427 | "kernelspec": { 428 | "display_name": "Python 3", 429 | "language": "python", 430 | "name": "python3" 431 | }, 432 | "language_info": { 433 | "codemirror_mode": { 434 | "name": "ipython", 435 | "version": 3 436 | }, 437 | "file_extension": ".py", 438 | "mimetype": "text/x-python", 439 | "name": "python", 440 | "nbconvert_exporter": "python", 441 | "pygments_lexer": "ipython3", 442 | "version": "3.6.10" 443 | }, 444 | "toc": { 445 | "nav_menu": {}, 446 | "number_sections": true, 447 | "sideBar": true, 448 | "skip_h1_title": false, 449 | "title_cell": "Table of Contents", 450 | "title_sidebar": "Contents", 451 | "toc_cell": false, 452 | "toc_position": {}, 453 | "toc_section_display": true, 454 | "toc_window_display": false 455 | } 456 | }, 457 | "nbformat": 4, 458 | "nbformat_minor": 4 459 | } 460 | -------------------------------------------------------------------------------- /experiments/Hateful8/hateful8-LMU.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2021-05-12T18:04:53.607992Z", 9 | "start_time": "2021-05-12T18:04:53.437658Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "%matplotlib inline\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2021-05-12T18:04:54.763726Z", 23 | "start_time": "2021-05-12T18:04:53.682215Z" 24 | } 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pyplot as plt\n", 29 | "import torch\n", 30 | "from torch import nn as nn\n", 31 | "from math import factorial\n", 32 | "import random\n", 33 | "import torch.nn.functional as F\n", 34 | "import numpy as np\n", 35 | "import seaborn as sn\n", 36 | "import pandas as pd\n", 37 | "import os \n", 38 | "from os.path import join\n", 39 | "import glob\n", 40 | "\n", 41 | "ttype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n", 42 | "ctype = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor\n", 43 | "print(ttype)\n", 44 | "from torch.nn.utils import weight_norm\n", 45 | "from lmu import LegendreMemoryUnit\n", 46 | "\n", 47 | "from tqdm.notebook import tqdm\n", 48 | "import pickle\n", 49 | "\n", 50 | "import itertools\n", 51 | "from csv import DictWriter\n", 52 | "\n", 53 | "import csv\n", 54 | "\n", 55 | "\n", 56 | "\n", 57 | "\n", 58 | "sn.set_context('talk')" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "ExecuteTime": { 66 | "end_time": "2021-05-12T18:04:54.848136Z", 67 | "start_time": "2021-05-12T18:04:54.764814Z" 68 | } 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "def generate_noise(maxn=18):\n", 73 | " \"\"\"Generates dot- and dash-based noise to append after a signal.\"\"\"\n", 74 | " \n", 75 | " threes = np.random.randint(int(.5*maxn), int(.75*maxn))\n", 76 | " ones = (maxn - threes) * 2\n", 77 | " noise = list(itertools.repeat([1,1,1,0], threes))\n", 78 | " noise[:int(len(noise)/3)] = 
list(itertools.repeat([0,0], int(len(noise)/3)))\n", 79 | " ones = ones + int(len(noise)/3)\n", 80 | " noise.extend(list(itertools.repeat([1,0], ones)))\n", 81 | " random.shuffle(noise)\n", 82 | " noise = np.concatenate(noise)\n", 83 | " return noise\n", 84 | "noise = generate_noise()\n", 85 | "print(noise.shape)\n", 86 | "plt.plot(noise)\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": { 93 | "ExecuteTime": { 94 | "end_time": "2021-05-12T18:04:55.827538Z", 95 | "start_time": "2021-05-12T18:04:54.849239Z" 96 | } 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "sig_lets = [\"A\",\"B\",\"C\",\"D\",\"E\",\"F\",\"G\",\"H\",]\n", 101 | "\n", 102 | "signals = ttype([[0,1,1,1,0,1,1,1,0,1,0,1,0,1,0,0,0],\n", 103 | " [0,1,1,1,0,1,0,1,1,1,0,1,0,1,0,0,0],\n", 104 | " [0,1,1,1,0,1,0,1,0,1,1,1,0,1,0,0,0],\n", 105 | " [0,1,1,1,0,1,0,1,0,1,0,1,1,1,0,0,0],\n", 106 | " \n", 107 | " [0,1,0,1,1,1,0,1,1,1,0,1,1,1,0,0,0],\n", 108 | " [0,1,1,1,0,1,0,1,1,1,0,1,1,1,0,0,0],\n", 109 | " [0,1,1,1,0,1,1,1,0,1,0,1,1,1,0,0,0],\n", 110 | " [0,1,1,1,0,1,1,1,0,1,1,1,0,1,0,0,0],\n", 111 | "\n", 112 | " ]\n", 113 | " ).view(8, 1, 1, -1)\n", 114 | "\n", 115 | "plt.imshow(signals[:,0,0,:].detach().cpu())\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "ExecuteTime": { 123 | "end_time": "2021-05-12T18:04:55.987629Z", 124 | "start_time": "2021-05-12T18:04:55.828588Z" 125 | } 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "torch.manual_seed(12345)\n", 130 | "np.random.seed(12345)\n", 131 | "training_samples = 32\n", 132 | "\n", 133 | "training_signals = []\n", 134 | "training_class = []\n", 135 | "\n", 136 | "for i, sig in enumerate(signals):\n", 137 | " temp_signals = []\n", 138 | " temp_class = []\n", 139 | " for x in range(training_samples):\n", 140 | " noise = ttype(generate_noise())\n", 141 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 142 | " while(any([(temp == c_).all() for c_ in temp_signals])):\n", 143 | " noise = ttype(generate_noise())\n", 144 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 145 | " temp_signals.append(temp)\n", 146 | " temp_class.append(i)\n", 147 | " training_signals.extend(temp_signals)\n", 148 | " training_class.extend(temp_class)\n", 149 | "\n", 150 | "batch_rand = torch.randperm(training_samples*signals.shape[0]) \n", 151 | "training_signals = torch.cat(training_signals).cuda().unsqueeze(-1)[batch_rand]\n", 152 | "training_class = ctype(training_class).cuda().unsqueeze(-1)[batch_rand]\n", 153 | "\n", 154 | "dataset = torch.utils.data.TensorDataset(training_signals, training_class)\n", 155 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 156 | "\n" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "ExecuteTime": { 164 | "end_time": "2021-05-12T18:04:56.934134Z", 165 | "start_time": "2021-05-12T18:04:56.388410Z" 166 | } 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "testing_samples = 10\n", 171 | "testing_signals = []\n", 172 | "testing_class = []\n", 173 | "\n", 174 | "for i, sig in enumerate(signals):\n", 175 | " temp_signals = []\n", 176 | " temp_class = []\n", 177 | " for x in range(testing_samples):\n", 178 | " noise = ttype(generate_noise())\n", 179 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 180 | " while(any([(temp == c_).all() for c_ in temp_signals]) or any([(temp == c_.view(1, -1)).all() for c_ in training_signals])): # stored rows are (L, 1); flatten to match temp's (1, L)\n", 181 | " noise = 
ttype(generate_noise())\n", 182 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 183 | " temp_signals.append(temp)\n", 184 | " temp_class.append(i)\n", 185 | " testing_signals.extend(temp_signals)\n", 186 | " testing_class.extend(temp_class)\n", 187 | "batch_rand = torch.randperm(testing_samples*signals.shape[0])\n", 188 | "\n", 189 | "testing_signals = torch.cat(testing_signals).cuda().unsqueeze(-1)[batch_rand]\n", 190 | "testing_class = ctype(testing_class).cuda().unsqueeze(-1)[batch_rand]\n", 191 | "\n", 192 | "\n", 193 | "dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)\n", 194 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)\n" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": { 201 | "ExecuteTime": { 202 | "end_time": "2021-05-12T18:04:57.496406Z", 203 | "start_time": "2021-05-12T18:04:57.489480Z" 204 | } 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "class LMUModel(nn.Module):\n", 209 | " def __init__(self, n_out, layer_params):\n", 210 | " super(LMUModel, self).__init__()\n", 211 | " self.layers = nn.ModuleList([LegendreMemoryUnit(**layer_params[i])\n", 212 | " for i in range(len(layer_params))])\n", 213 | " self.dense = nn.Linear(layer_params[-1]['hidden_size'], n_out)\n", 214 | "\n", 215 | " \n", 216 | " def forward(self, x):\n", 217 | " for l in self.layers:\n", 218 | " x, _ = l(x) \n", 219 | " x = self.dense(x)\n", 220 | " return x" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": { 227 | "ExecuteTime": { 228 | "end_time": "2021-05-12T18:04:57.965310Z", 229 | "start_time": "2021-05-12T18:04:57.950810Z" 230 | } 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n", 235 | " permute=None, loss_buffer_size=64, batch_size=4, device='cuda',\n", 236 | " prog_bar=None, maxn=6):\n", 237 | " \n", 238 | " assert(loss_buffer_size%batch_size==0)\n", 239 | " \n", 240 | " losses = []\n", 241 | " perfs = []\n", 242 | " last_test_perf = 0\n", 243 | " best_test_perf = -1\n", 244 | " \n", 245 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 246 | " model.train()\n", 247 | " data = data.to(device)\n", 248 | " target = target.to(device)\n", 249 | " optimizer.zero_grad()\n", 250 | " out = model(data)\n", 251 | " loss = loss_func(out[:,-1],\n", 252 | " target[:, 0])\n", 253 | " \n", 254 | " loss.backward()\n", 255 | " optimizer.step()\n", 256 | "\n", 257 | " perfs.append((torch.argmax(out[:,-1], dim=-1) == \n", 258 | " target[:, 0]).sum().item())\n", 259 | " perfs = perfs[int(-loss_buffer_size/batch_size):]\n", 260 | " losses.append(loss.detach().cpu().numpy())\n", 261 | " losses = losses[int(-loss_buffer_size/batch_size):]\n", 262 | " if not (prog_bar is None):\n", 263 | " # Update progress_bar\n", 264 | " s = \"{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}\"\n", 265 | " format_list = [epoch, batch_idx*batch_size, np.mean(losses), \n", 266 | " np.sum(perfs)/((len(perfs))*batch_size), last_test_perf] \n", 267 | " s = s.format(*format_list)\n", 268 | " prog_bar.set_description(s)\n", 269 | " \n", 270 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n", 271 | " loss_track = {}\n", 272 | " last_test_perf = test(model, 'cuda', test_loader, \n", 273 | " batch_size=batch_size, \n", 274 | " permute=permute)\n", 275 | " loss_track['avg_loss'] = np.mean(losses)\n", 276 | " 
loss_track['last_test'] = last_test_perf\n", 277 | " loss_track['epoch'] = epoch\n", 278 | " loss_track['maxn'] = maxn\n", 279 | " loss_track['batch_idx'] = batch_idx\n", 280 | " loss_track['pres_num'] = batch_idx*batch_size + epoch*len(train_loader.dataset)\n", 281 | " loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)\n", 282 | " with open(perf_file, 'a+') as fp:\n", 283 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n", 284 | " if fp.tell() == 0:\n", 285 | " csv_writer.writeheader()\n", 286 | " csv_writer.writerow(loss_track)\n", 287 | " fp.flush()\n", 288 | " if best_test_perf < last_test_perf:\n", 289 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n", 290 | " best_test_perf = last_test_perf\n", 291 | "\n", 292 | " \n", 293 | "def test(model, device, test_loader, batch_size=4, permute=None):\n", 294 | " model.eval()\n", 295 | " correct = 0\n", 296 | " count = 0\n", 297 | " with torch.no_grad():\n", 298 | " for data, target in test_loader:\n", 299 | " data = data.to(device)\n", 300 | " target = target.to(device)\n", 301 | " \n", 302 | " out = model(data)\n", 303 | " pred = out[:,-1].argmax(dim=-1, keepdim=True)\n", 304 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 305 | " count += 1\n", 306 | " return correct / len(test_loader.dataset)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "# Training and testing" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": { 320 | "ExecuteTime": { 321 | "end_time": "2021-05-12T18:05:02.299619Z", 322 | "start_time": "2021-05-12T18:05:02.293735Z" 323 | } 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "# You likely don't need this to be this long, but just in case.\n", 328 | "epochs = 400\n", 329 | "\n", 330 | "# Just for visualizing average loss through time. 
\n", 331 | "loss_buffer_size = 100" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "metadata": { 338 | "ExecuteTime": { 339 | "end_time": "2021-05-12T18:05:07.312643Z", 340 | "start_time": "2021-05-12T18:05:03.479007Z" 341 | }, 342 | "scrolled": true 343 | }, 344 | "outputs": [], 345 | "source": [ 346 | "test_noise_lengths = [6,7,9,13,21,37]\n", 347 | "for maxn in test_noise_lengths:\n", 348 | " torch.manual_seed(12345)\n", 349 | " np.random.seed(12345)\n", 350 | " training_samples = 32\n", 351 | "\n", 352 | " training_signals = []\n", 353 | " training_class = []\n", 354 | "\n", 355 | " for i, sig in enumerate(signals):\n", 356 | " temp_signals = []\n", 357 | " temp_class = []\n", 358 | " for x in range(training_samples):\n", 359 | " noise = ttype(generate_noise(maxn))\n", 360 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 361 | " while(any([(temp == c_).all() for c_ in temp_signals])):\n", 362 | " noise = ttype(generate_noise(maxn))\n", 363 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 364 | " temp_signals.append(temp)\n", 365 | " temp_class.append(i)\n", 366 | " training_signals.extend(temp_signals)\n", 367 | " training_class.extend(temp_class)\n", 368 | "\n", 369 | " batch_rand = torch.randperm(training_samples*signals.shape[0]) \n", 370 | " training_signals = torch.cat(training_signals).cuda().unsqueeze(-1)[batch_rand]\n", 371 | " training_class = ctype(training_class).cuda().unsqueeze(-1)[batch_rand]\n", 372 | "\n", 373 | " dataset = torch.utils.data.TensorDataset(training_signals, training_class)\n", 374 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n", 375 | " testing_samples = 10\n", 376 | " testing_signals = []\n", 377 | " testing_class = []\n", 378 | "\n", 379 | " for i, sig in enumerate(signals):\n", 380 | " temp_signals = []\n", 381 | " temp_class = []\n", 382 | " for x in range(testing_samples):\n", 383 | " noise = ttype(generate_noise(maxn))\n", 384 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 385 | " while(any([(temp == c_).all() for c_ in temp_signals]) or any([(temp == c_.view(1, -1)).all() for c_ in training_signals])): # stored rows are (L, 1); flatten to match temp's (1, L)\n", 386 | " noise = ttype(generate_noise(maxn))\n", 387 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n", 388 | " temp_signals.append(temp)\n", 389 | " temp_class.append(i)\n", 390 | " testing_signals.extend(temp_signals)\n", 391 | " testing_class.extend(temp_class)\n", 392 | " batch_rand = torch.randperm(testing_samples*signals.shape[0])\n", 393 | "\n", 394 | " testing_signals = torch.cat(testing_signals).cuda().unsqueeze(-1)[batch_rand]\n", 395 | " testing_class = ctype(testing_class).cuda().unsqueeze(-1)[batch_rand]\n", 396 | "\n", 397 | "\n", 398 | " dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)\n", 399 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)\n", 400 | "\n", 401 | " hz=125\n", 402 | "\n", 403 | " lmu_params = [dict(input_dim=1, hidden_size=hz, order=40, theta=temp.shape[-1]), # theta = full sequence length (temp is the last generated example)\n", 404 | " #dict(input_dim=hz, hidden_size=hz, order=4, theta=4),\n", 405 | " #dict(input_dim=hz, hidden_size=hz, order=4, theta=4),\n", 406 | " ]\n", 407 | " model = LMUModel(8, lmu_params).cuda()\n", 408 | "\n", 409 | " tot_weights = 0\n", 410 | " for p in model.parameters():\n", 411 | " tot_weights += p.numel()\n", 412 | " print(\"Total Weights:\", tot_weights)\n", 413 | " print(model)\n", 414 | " loss_func = torch.nn.CrossEntropyLoss()\n", 415 | " optimizer = 
torch.optim.Adam(model.parameters())\n", 416 | " epochs = 400\n", 417 | " batch_size = 32\n", 418 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n", 419 | " for e in progress_bar:\n", 420 | " train(model, ttype, dataset, dataset_valid, \n", 421 | " optimizer, loss_func, batch_size=batch_size,\n", 422 | " epoch=e, perf_file=join('perf','h8_LMU_length_6.csv'),\n", 423 | " prog_bar=progress_bar, maxn=maxn)\n", 424 | " " 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [] 433 | } 434 | ], 435 | "metadata": { 436 | "kernelspec": { 437 | "display_name": "Python 3", 438 | "language": "python", 439 | "name": "python3" 440 | }, 441 | "language_info": { 442 | "codemirror_mode": { 443 | "name": "ipython", 444 | "version": 3 445 | }, 446 | "file_extension": ".py", 447 | "mimetype": "text/x-python", 448 | "name": "python", 449 | "nbconvert_exporter": "python", 450 | "pygments_lexer": "ipython3", 451 | "version": "3.6.10" 452 | }, 453 | "toc": { 454 | "nav_menu": {}, 455 | "number_sections": true, 456 | "sideBar": true, 457 | "skip_h1_title": false, 458 | "title_cell": "Table of Contents", 459 | "title_sidebar": "Contents", 460 | "toc_cell": false, 461 | "toc_position": {}, 462 | "toc_section_display": true, 463 | "toc_window_display": false 464 | } 465 | }, 466 | "nbformat": 4, 467 | "nbformat_minor": 4 468 | } 469 | --------------------------------------------------------------------------------
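A note on the Hateful-8 data built in both notebooks above: every example is one of eight fixed Morse-style templates followed by dot/dash noise, so the class-discriminating information sits entirely in the first 17 steps, and the test_noise_lengths loop only stretches the uninformative tail. Because generate_noise always emits exactly 4*maxn steps, each sequence has length 17 + 4*maxn (89 steps at the default maxn=18). The sketch below reproduces a single labeled example on CPU for quick inspection; generate_noise and the class-"A" template are copied from the notebooks, while the names template_A, x, and y are illustrative additions, not part of the repository.

# Standalone sketch: build one labeled Hateful-8 example on CPU.
import itertools
import random
import numpy as np
import torch

def generate_noise(maxn=18):
    """Generates dot- and dash-based noise (copied from the notebooks above)."""
    threes = np.random.randint(int(.5*maxn), int(.75*maxn))
    ones = (maxn - threes) * 2
    noise = list(itertools.repeat([1, 1, 1, 0], threes))
    noise[:int(len(noise)/3)] = list(itertools.repeat([0, 0], int(len(noise)/3)))
    ones = ones + int(len(noise)/3)
    noise.extend(list(itertools.repeat([1, 0], ones)))
    random.shuffle(noise)
    return np.concatenate(noise)

# First row of the notebooks' signals tensor: the template for class "A".
template_A = torch.tensor([0,1,1,1,0,1,1,1,0,1,0,1,0,1,0,0,0], dtype=torch.float32)

noise = torch.tensor(generate_noise(), dtype=torch.float32)
x = torch.cat([template_A, noise]).unsqueeze(-1)  # shape (17 + 4*maxn, 1) time series
y = 0                                             # index of "A" in sig_lets
print(x.shape, y)                                 # torch.Size([89, 1]) 0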