├── figures
│   └── model_config.png
├── experiments
│   ├── MackeyGlass
│   │   ├── attract.dill
│   │   ├── cornn.py
│   │   ├── lmu.py
│   │   ├── MackeyGlass-LSTM-increasing_tau.ipynb
│   │   ├── MackeyGlass-LMU-increasing_tau.ipynb
│   │   ├── MackeyGlass-DeepSITH-increasing_tau.ipynb
│   │   └── MackeyGlass-coRNN-increasing_tau.ipynb
│   ├── AddingProblem
│   │   ├── cornn.py
│   │   ├── lmu.py
│   │   ├── addingproblem-LSTM.ipynb
│   │   ├── addingproblem-LMU.ipynb
│   │   └── addingproblem-coRNN.ipynb
│   ├── Hateful8
│   │   ├── cornn.py
│   │   ├── lmu.py
│   │   ├── hateful8-LSTM.ipynb
│   │   └── hateful8-LMU.ipynb
│   └── psMNIST
│       ├── visualization.ipynb
│       ├── sMNIST.ipynb
│       └── psMNIST.ipynb
├── deepsith
│   ├── __init__.py
│   ├── deepsith.py
│   └── isith.py
├── setup.py
├── .gitignore
└── README.md
/figures/model_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compmem/DeepSITH/HEAD/figures/model_config.png
--------------------------------------------------------------------------------
/experiments/MackeyGlass/attract.dill:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/compmem/DeepSITH/HEAD/experiments/MackeyGlass/attract.dill
--------------------------------------------------------------------------------
/deepsith/__init__.py:
--------------------------------------------------------------------------------
1 | from .isith import iSITH
2 | from .deepsith import DeepSITH
3 |
4 | __copyright__ = "2020, Computational Memory Lab"
5 | __version__ = "0.1"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name='deepsith',
4 | version='0.1',
5 | description='DeepSITH: Scale-Invariant Temporal History Across Hidden Layers',
6 | url='https://github.com/beegica/SITH-Con',
7 | license='Free for non-commercial use',
8 | author='Computational Memory Lab',
9 | author_email='bgj5hk@virginia.edu',
10 | packages=['deepsith'],
11 | install_requires=[
12 | 'torch>=1.1.0',
13 |
14 | ],
15 | zip_safe=False)
--------------------------------------------------------------------------------
/experiments/AddingProblem/cornn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 |
6 | class coRNNCell(nn.Module):
7 | def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
8 | super(coRNNCell, self).__init__()
9 | self.dt = dt
10 | self.gamma = gamma
11 | self.epsilon = epsilon
12 | self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)
13 |
14 | def forward(self,x,hy,hz):
15 | hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy),1)))
16 | - self.gamma * hy - self.epsilon * hz)
17 | hy = hy + self.dt * hz
18 |
19 | return hy, hz
20 |
21 | class coRNN(nn.Module):
22 | def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
23 | super(coRNN, self).__init__()
24 | self.n_hid = n_hid
25 | self.cell = coRNNCell(n_inp,n_hid,dt,gamma,epsilon)
26 | self.readout = nn.Linear(n_hid, n_out)
27 |
28 | def forward(self, x):
29 | ## initialize hidden states
30 |         hy = torch.zeros(x.size(1), self.n_hid, device=device)
31 |         hz = torch.zeros(x.size(1), self.n_hid, device=device)
32 |
33 | for t in range(x.size(0)):
34 | hy, hz = self.cell(x[t],hy,hz)
35 | output = self.readout(hy)
36 |
37 | return output
38 |
--------------------------------------------------------------------------------
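A quick smoke test for the coRNN above (a sketch: the hyperparameter values and the two-feature input, i.e. (value, marker) pairs for the adding problem, are illustrative assumptions rather than the tuned settings from the notebooks):

    import torch
    from cornn import coRNN

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # placeholder hyperparameters; see the notebooks for the values actually used
    model = coRNN(n_inp=2, n_hid=128, n_out=1, dt=0.01, gamma=1.0, epsilon=1.0).to(device)
    x = torch.randn(500, 32, 2).to(device)  # (sequence, batch, features)
    y = model(x)                            # readout of the final hidden state
    print(y.shape)                          # torch.Size([32, 1])
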
/experiments/Hateful8/cornn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torch.autograd import Variable
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 |
6 | class coRNNCell(nn.Module):
7 | def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
8 | super(coRNNCell, self).__init__()
9 | self.dt = dt
10 | self.gamma = gamma
11 | self.epsilon = epsilon
12 | self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)
13 |
14 | def forward(self,x,hy,hz):
15 | hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy),1)))
16 | - self.gamma * hy - self.epsilon * hz)
17 | hy = hy + self.dt * hz
18 |
19 | return hy, hz
20 |
21 | class coRNN(nn.Module):
22 | def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
23 | super(coRNN, self).__init__()
24 | self.n_hid = n_hid
25 | self.cell = coRNNCell(n_inp,n_hid,dt,gamma,epsilon)
26 | self.readout = nn.Linear(n_hid, n_out)
27 |
28 | def forward(self, x):
29 | ## initialize hidden states
30 |         hy = torch.zeros(x.size(1), self.n_hid, device=device)
31 |         hz = torch.zeros(x.size(1), self.n_hid, device=device)
32 |
33 | for t in range(x.size(0)):
34 | hy, hz = self.cell(x[t],hy,hz)
35 |         output = self.readout(hy)
36 | 
37 |         return output
38 | 
--------------------------------------------------------------------------------
/experiments/MackeyGlass/cornn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torch.autograd import Variable
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 |
6 | class coRNNCell(nn.Module):
7 | def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
8 | super(coRNNCell, self).__init__()
9 | self.dt = dt
10 | self.gamma = gamma
11 | self.epsilon = epsilon
12 | self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)
13 |
14 | def forward(self,x,hy,hz):
15 | hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy),1)))
16 | - self.gamma * hy - self.epsilon * hz)
17 | hy = hy + self.dt * hz
18 |
19 | return hy, hz
20 |
21 | class coRNN(nn.Module):
22 | def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
23 | super(coRNN, self).__init__()
24 | self.n_hid = n_hid
25 | self.cell = coRNNCell(n_inp,n_hid,dt,gamma,epsilon)
26 | self.readout = nn.Linear(n_hid, n_out)
27 |
28 | def forward(self, x):
29 | ## initialize hidden states
30 |         hy = torch.zeros(x.size(1), self.n_hid, device=device)
31 |         hz = torch.zeros(x.size(1), self.n_hid, device=device)
32 |
33 | for t in range(x.size(0)):
34 | hy, hz = self.cell(x[t],hy,hz)
35 |         output = self.readout(hy)
36 | 
37 |         return output
38 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.png
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | data/
132 | perf/
133 |
--------------------------------------------------------------------------------
/deepsith/deepsith.py:
--------------------------------------------------------------------------------
1 | # Deep SITH
2 | # PyTorch version 0.1.0
3 | # Authors: Brandon G. Jacques and Per B. Sederberg
4 |
5 | import torch
6 | from torch import nn
7 | from .isith import iSITH
8 |
9 |
10 | 
11 | class _DeepSITH_core(nn.Module):
12 | def __init__(self, layer_params):
13 | super(_DeepSITH_core, self).__init__()
14 |
15 | hidden_size = layer_params.pop('hidden_size', layer_params['in_features'])
16 | in_features = layer_params.pop('in_features', None)
17 | batch_norm = layer_params.pop('batch_norm', True)
18 | act_func = layer_params.pop('act_func', None)
19 | self.batch_norm = batch_norm
20 | self.act_func = not (act_func is None)
21 | self.sith = iSITH(**layer_params)
22 |
23 | self.linear = nn.Linear(layer_params['ntau']*in_features,
24 | hidden_size)
25 | nn.init.kaiming_normal_(self.linear.weight.data)
26 |
27 | if not (act_func is None):
28 | self.act_func = act_func
29 | if batch_norm:
30 | self.dense_bn = nn.BatchNorm1d(hidden_size)
31 |
32 | def forward(self, inp):
33 |         # iSITH outputs as: [Batch, taustar, features, sequence]
34 | x = self.sith(inp)
35 |
36 | x = x.transpose(3,2).transpose(2,1)
37 | x = x.view(x.shape[0], x.shape[1], -1)
38 | x = self.linear(x)
39 | if self.act_func:
40 | x = self.act_func(x)
41 | if self.batch_norm:
42 | x = x.transpose(2,1)
43 | x = self.dense_bn(x).transpose(2,1)
44 | return x
45 |
46 | class DeepSITH(nn.Module):
47 | """A Module built for SITH like an LSTM
48 |
49 | Parameters
50 | ----------
51 | layer_params: list
52 | A list of dictionaries for each layer in the desired DeepSITH. All
53 | of the parameters needed for the SITH part of the Layers, as well as
54 | a hidden_size and optional act_func are required to be present.
55 |
56 | layer_params keys
57 | -----------------
58 | hidden_size: int (default in_features)
59 | The size of the output of the hidden layer. Please note that the
60 | in_features parameter for the next layer's SITH representation should be
61 | equal to the previous layer's hidden_size. This parameter will default
62 | to the in_features of the current SITH layer if not specified.
63 | act_func: torch.nn.Module (default None)
64 |             The torch layer of the desired activation function, or None if
65 |             there is no desired activation function between layers.
66 |
67 | In addition to these keys, you must include all of the non-optional SITH
68 | layer keys in each dictionary. Please see the SITH docstring for
69 | suggestions.
70 |
71 | """
72 | def __init__(self, layer_params, dropout=.5):
73 | super(DeepSITH, self).__init__()
74 | self.layers = nn.ModuleList([_DeepSITH_core(layer_params[i])
75 | for i in range(len(layer_params))])
76 | self.dropouts = nn.ModuleList([nn.Dropout(dropout) for i in range(len(layer_params) - 1)])
77 |
78 | def forward(self, inp):
79 | x = inp
80 | for i, l in enumerate(self.layers[:-1]):
81 | x = l(x)
82 | x = self.dropouts[i](x)
83 | x = x.unsqueeze(1).transpose(3,2)
84 | x = self.layers[-1](x)
85 | return x
--------------------------------------------------------------------------------
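As a reading aid, here is the shape trace through one _DeepSITH_core layer and the reshaping between layers (a sketch; the shapes follow from the iSITH convolution and the transposes in forward above):

    # inp                              (batch, 1, in_features, seq)
    # self.sith(inp)                   (batch, ntau, in_features, seq)
    # .transpose(3,2).transpose(2,1)   (batch, seq, ntau, in_features)
    # .view(b, s, -1)                  (batch, seq, ntau*in_features)
    # self.linear(...)                 (batch, seq, hidden_size)
    #
    # Between layers, DeepSITH.forward restores the 4-D layout expected by the
    # next layer's iSITH:
    # x.unsqueeze(1).transpose(3,2)    (batch, 1, hidden_size, seq)
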
/deepsith/isith.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import factorial, log
3 | # Impulse-based SITH class
4 | class iSITH(torch.nn.Module):
5 | def __init__(self, tau_min=.1, tau_max=100., buff_max=None, k=50, ntau=50, dt=1, g=0.0,
6 | ttype=torch.FloatTensor):
7 | super(iSITH, self).__init__()
8 | """A SITH module using the perfect equation for the resulting ftilde
9 |
10 | Parameters
11 | ----------
12 |
13 | - tau_min: float
14 | The center of the temporal receptive field for the first taustar produced.
15 | - tau_max: float
16 | The center of the temporal receptive field for the last taustar produced.
17 | - buff_max: int
18 |             How far into the past the filters extend. NOTE: In order to
19 |             achieve as few edge effects as possible, buff_max needs to be bigger than
20 |             tau_max (by an amount dependent on k), so that the filters have enough
21 |             time to decay very close to 0.0. Plot the filters and you will see them go to 0.
22 | - k: int
23 |             Temporal specificity of the taustars. The higher this number, the
24 |             narrower each taustar's temporal receptive field.
25 | - ntau: int
26 | Number of taustars produced, spread out logarithmically.
27 | - dt: float
28 |             The time step of the model. There will be int(buff_max/dt) filter taps per
29 |             taustar. Essentially, this is the base rate at which information is presented to the model.
30 | - g: float
31 | Typically between 0 and 1. This parameter is the scaling factor of the output
32 | of the module. If set to 1, the output amplitude for a delta function will be
33 | identical through time. If set to 0, the amplitude will decay into the past,
34 |             getting smaller and smaller. This value should be picked on an
35 |             application-by-application basis.
36 | - ttype: Torch Tensor
37 | This is the type we set the internal mechanism of the model to before running.
38 | In order to calculate the filters, we must use a DoubleTensor, but this is no
39 | longer necessary after they are calculated. By default we set the filters to
40 | be FloatTensors. NOTE: If you plan to use CUDA, you need to pass in a
41 | cuda.FloatTensor as the ttype, as using .cuda() will not put these filters on
42 | the gpu.
43 |
44 |
45 | """
46 | self.k = k
47 | self.tau_min = tau_min
48 | self.tau_max = tau_max
49 | if buff_max is None:
50 | buff_max = 3*tau_max
51 | self.buff_max = buff_max
52 | self.ntau = ntau
53 | self.dt = dt
54 | self.g = g
55 |
56 | self.c = (tau_max/tau_min)**(1./(ntau-1))-1
57 |
58 | self.tau_star = tau_min*(1+self.c)**torch.arange(ntau).type(torch.DoubleTensor)
59 |
60 | self.times = torch.arange(dt, buff_max+dt, dt).type(torch.DoubleTensor)
61 |
62 | a = log(k)*k
63 | b = torch.log(torch.arange(2,k).type(torch.DoubleTensor)).sum()
64 |
65 | #A = ((1/self.tau_star)*(k**(k+1)/factorial(k))*(self.tau_star**self.g)).unsqueeze(1)
66 | A = ((1/self.tau_star)*(torch.exp(a-b))*(self.tau_star**self.g)).unsqueeze(1)
67 |
68 | self.filters = A*torch.exp((torch.log(self.times.unsqueeze(0)/self.tau_star.unsqueeze(1))*(k+1)) + \
69 | (k*(-self.times.unsqueeze(0)/self.tau_star.unsqueeze(1))))
70 |
71 | self.filters = torch.flip(self.filters, [-1]).unsqueeze(1).unsqueeze(1)
72 | self.filters = self.filters.type(ttype)
73 |
74 | def extra_repr(self):
75 | s = "ntau={ntau}, tau_min={tau_min}, tau_max={tau_max}, buff_max={buff_max}, dt={dt}, k={k}, g={g}"
76 | s = s.format(**self.__dict__)
77 | return s
78 |
79 | def forward(self, inp):
80 | """Takes in (Batch, 1, features, sequence) and returns (Batch, Taustar, features, sequence)"""
81 | assert(len(inp.shape) >= 4)
82 | out = torch.conv2d(inp, self.filters[:, :, :, -inp.shape[-1]:],
83 | padding=[0, self.filters[:, :, :, -inp.shape[-1]:].shape[-1]])
84 | #padding=[0, self.filters.shape[-1]])
85 | # note we're scaling the output by both dt and the k/(k+1)
86 | # Off by 1 introduced by the conv2d
87 | return out[:, :, :, 1:inp.shape[-1]+1]*self.dt*self.k/(self.k+1)
--------------------------------------------------------------------------------
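A quick sanity check of the iSITH shapes (a sketch; the parameter values are the first-layer settings from the README example):

    import torch
    from deepsith import iSITH

    sith = iSITH(tau_min=1, tau_max=25., buff_max=40, k=84, ntau=15, dt=1, g=0.0)
    print(sith.tau_star[0].item(), sith.tau_star[-1].item())  # 1.0 ... 25.0, log-spaced
    print(sith.filters.shape)          # (ntau, 1, 1, buff_max/dt) -> torch.Size([15, 1, 1, 40])

    inp = torch.rand(2, 1, 3, 100)     # (batch, 1, features, sequence)
    out = sith(inp)
    print(out.shape)                   # (batch, taustar, features, sequence) -> torch.Size([2, 15, 3, 100])
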
/experiments/Hateful8/lmu.py:
--------------------------------------------------------------------------------
1 | """https://github.com/AbdouJaouhar/LMU-Legendre-Memory-Unit"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from sympy.matrices import Matrix, eye, zeros, ones, diag, GramSchmidt
6 | import numpy as np
7 | from functools import partial
8 | import torch.nn.functional as F
9 | import math
10 |
11 | from nengolib.signal import Identity, cont2discrete
12 | from nengolib.synapses import LegendreDelay
13 | 
14 | 
15 | '''
16 | LeCun uniform initialization:
17 | - tensor: the tensor to fill in place
18 | - fan_in: the input dimension size
19 | '''
20 | def lecun_uniform(tensor):
21 | fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
22 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
23 |
24 |
25 | class LMUCell(nn.Module):
26 |
27 | def __init__(self, input_size, hidden_size,
28 | order,
29 | theta=100, # relative to dt=1
30 | method='zoh',
31 | trainable_input_encoders=True,
32 | trainable_hidden_encoders=True,
33 | trainable_memory_encoders=True,
34 | trainable_input_kernel=True,
35 | trainable_hidden_kernel=True,
36 | trainable_memory_kernel=True,
37 | trainable_A=False,
38 | trainable_B=False,
39 | input_encoders_initializer=lecun_uniform,
40 | hidden_encoders_initializer=lecun_uniform,
41 | memory_encoders_initializer=partial(torch.nn.init.constant_, val=0),
42 | input_kernel_initializer=torch.nn.init.xavier_normal_,
43 | hidden_kernel_initializer=torch.nn.init.xavier_normal_,
44 | memory_kernel_initializer=torch.nn.init.xavier_normal_,
45 |
46 | hidden_activation='tanh',
47 | ):
48 | super(LMUCell, self).__init__()
49 |
50 | self.input_size = input_size
51 | self.hidden_size = hidden_size
52 | self.order = order
53 |
54 | if hidden_activation == 'tanh':
55 | self.hidden_activation = torch.tanh
56 | elif hidden_activation == 'relu':
57 | self.hidden_activation = torch.relu
58 | else:
59 | raise NotImplementedError("hidden activation '{}' is not implemented".format(hidden_activation))
60 |
61 | realizer = Identity()
62 | self._realizer_result = realizer(
63 | LegendreDelay(theta=theta, order=self.order))
64 | self._ss = cont2discrete(
65 | self._realizer_result.realization, dt=1., method=method)
66 | self._A = self._ss.A - np.eye(order) # puts into form: x += Ax
67 | self._B = self._ss.B
68 | self._C = self._ss.C
69 | assert np.allclose(self._ss.D, 0) # proper LTI
70 |
71 | self.input_encoders = nn.Parameter(torch.Tensor(1, input_size), requires_grad=trainable_input_encoders)
72 | self.hidden_encoders = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=trainable_hidden_encoders)
73 | self.memory_encoders = nn.Parameter(torch.Tensor(1, order), requires_grad=trainable_memory_encoders)
74 | self.input_kernel = nn.Parameter(torch.Tensor(hidden_size, input_size), requires_grad=trainable_input_kernel)
75 | self.hidden_kernel = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=trainable_hidden_kernel)
76 | self.memory_kernel = nn.Parameter(torch.Tensor(hidden_size, order), requires_grad=trainable_memory_kernel)
77 | self.AT = nn.Parameter(torch.Tensor(self._A), requires_grad=trainable_A)
78 | self.BT = nn.Parameter(torch.Tensor(self._B), requires_grad=trainable_B)
79 |
80 | # Initialize parameters
81 | input_encoders_initializer(self.input_encoders)
82 | hidden_encoders_initializer(self.hidden_encoders)
83 | memory_encoders_initializer(self.memory_encoders)
84 | input_kernel_initializer(self.input_kernel)
85 | hidden_kernel_initializer(self.hidden_kernel)
86 | memory_kernel_initializer(self.memory_kernel)
87 |
88 | def forward(self, input, hx):
89 |
90 | h, m = hx
91 |
92 | u = (F.linear(input, self.input_encoders) +
93 | F.linear(h, self.hidden_encoders) +
94 | F.linear(m, self.memory_encoders))
95 |
96 | m = m + F.linear(m, self.AT) + F.linear(u, self.BT)
97 |
98 | h = self.hidden_activation(
99 | F.linear(input, self.input_kernel) +
100 | F.linear(h, self.hidden_kernel) +
101 | F.linear(m, self.memory_kernel))
102 |
103 | return h, [h, m]
104 |
105 |
106 | class LegendreMemoryUnit(nn.Module):
107 | """
108 |     Implementation of the LMU using LMUCell, so it can be used like an LSTM or GRU in PyTorch (per-timestep Python loop; no fused cuDNN-style acceleration)
109 | """
110 | def __init__(self, input_dim, hidden_size, order, theta):
111 | super(LegendreMemoryUnit, self).__init__()
112 |
113 | self.hidden_size = hidden_size
114 | self.order = order
115 |
116 | self.lmucell = LMUCell(input_dim, hidden_size, order, theta)
117 |
118 | def forward(self, xt):
119 | outputs = []
120 |
121 | h0 = torch.zeros(xt.size(0),self.hidden_size).cuda()
122 | m0 = torch.zeros(xt.size(0),self.order).cuda()
123 | states = (h0,m0)
124 | for i in range(xt.size(1)):
125 | out, states = self.lmucell(xt[:,i,:], states)
126 | outputs += [out]
127 | return torch.stack(outputs).permute(1,0,2), states
128 |
129 |
--------------------------------------------------------------------------------
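A minimal usage sketch for the LMU wrapper above (the sizes are illustrative; note that forward creates its hidden states with .cuda(), so a CUDA device is required as written):

    import torch
    from lmu import LegendreMemoryUnit

    lmu = LegendreMemoryUnit(input_dim=1, hidden_size=212, order=256, theta=784).cuda()
    x = torch.randn(16, 784, 1).cuda()  # (batch, sequence, features)
    out, states = lmu(x)                # states holds [h, m] from the last step
    print(out.shape)                    # torch.Size([16, 784, 212])
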
/experiments/MackeyGlass/lmu.py:
--------------------------------------------------------------------------------
1 | """https://github.com/AbdouJaouhar/LMU-Legendre-Memory-Unit"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from sympy.matrices import Matrix, eye, zeros, ones, diag, GramSchmidt
6 | import numpy as np
7 | from functools import partial
8 | import torch.nn.functional as F
9 | import math
10 |
11 | from nengolib.signal import Identity, cont2discrete
12 | from nengolib.synapses import LegendreDelay
13 | 
14 | 
15 | '''
16 | LeCun uniform initialization:
17 | - tensor: the tensor to fill in place
18 | - fan_in: the input dimension size
19 | '''
20 | def lecun_uniform(tensor):
21 | fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
22 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
23 |
24 |
25 | class LMUCell(nn.Module):
26 |
27 | def __init__(self, input_size, hidden_size,
28 | order,
29 | theta=100, # relative to dt=1
30 | method='zoh',
31 | trainable_input_encoders=True,
32 | trainable_hidden_encoders=True,
33 | trainable_memory_encoders=True,
34 | trainable_input_kernel=True,
35 | trainable_hidden_kernel=True,
36 | trainable_memory_kernel=True,
37 | trainable_A=False,
38 | trainable_B=False,
39 | input_encoders_initializer=lecun_uniform,
40 | hidden_encoders_initializer=lecun_uniform,
41 | memory_encoders_initializer=partial(torch.nn.init.constant_, val=0),
42 | input_kernel_initializer=torch.nn.init.xavier_normal_,
43 | hidden_kernel_initializer=torch.nn.init.xavier_normal_,
44 | memory_kernel_initializer=torch.nn.init.xavier_normal_,
45 |
46 | hidden_activation='tanh',
47 | ):
48 | super(LMUCell, self).__init__()
49 |
50 | self.input_size = input_size
51 | self.hidden_size = hidden_size
52 | self.order = order
53 |
54 | if hidden_activation == 'tanh':
55 | self.hidden_activation = torch.tanh
56 | elif hidden_activation == 'relu':
57 | self.hidden_activation = torch.relu
58 | else:
59 | raise NotImplementedError("hidden activation '{}' is not implemented".format(hidden_activation))
60 |
61 | realizer = Identity()
62 | self._realizer_result = realizer(
63 | LegendreDelay(theta=theta, order=self.order))
64 | self._ss = cont2discrete(
65 | self._realizer_result.realization, dt=1., method=method)
66 | self._A = self._ss.A - np.eye(order) # puts into form: x += Ax
67 | self._B = self._ss.B
68 | self._C = self._ss.C
69 | assert np.allclose(self._ss.D, 0) # proper LTI
70 |
71 | self.input_encoders = nn.Parameter(torch.Tensor(1, input_size), requires_grad=trainable_input_encoders)
72 | self.hidden_encoders = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=trainable_hidden_encoders)
73 | self.memory_encoders = nn.Parameter(torch.Tensor(1, order), requires_grad=trainable_memory_encoders)
74 | self.input_kernel = nn.Parameter(torch.Tensor(hidden_size, input_size), requires_grad=trainable_input_kernel)
75 | self.hidden_kernel = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=trainable_hidden_kernel)
76 | self.memory_kernel = nn.Parameter(torch.Tensor(hidden_size, order), requires_grad=trainable_memory_kernel)
77 | self.AT = nn.Parameter(torch.Tensor(self._A), requires_grad=trainable_A)
78 | self.BT = nn.Parameter(torch.Tensor(self._B), requires_grad=trainable_B)
79 |
80 | # Initialize parameters
81 | input_encoders_initializer(self.input_encoders)
82 | hidden_encoders_initializer(self.hidden_encoders)
83 | memory_encoders_initializer(self.memory_encoders)
84 | input_kernel_initializer(self.input_kernel)
85 | hidden_kernel_initializer(self.hidden_kernel)
86 | memory_kernel_initializer(self.memory_kernel)
87 |
88 | def forward(self, input, hx):
89 |
90 | h, m = hx
91 |
92 | u = (F.linear(input, self.input_encoders) +
93 | F.linear(h, self.hidden_encoders) +
94 | F.linear(m, self.memory_encoders))
95 |
96 | m = m + F.linear(m, self.AT) + F.linear(u, self.BT)
97 |
98 | h = self.hidden_activation(
99 | F.linear(input, self.input_kernel) +
100 | F.linear(h, self.hidden_kernel) +
101 | F.linear(m, self.memory_kernel))
102 |
103 | return h, [h, m]
104 |
105 |
106 | class LegendreMemoryUnit(nn.Module):
107 | """
108 |     Implementation of the LMU using LMUCell, so it can be used like an LSTM or GRU in PyTorch (per-timestep Python loop; no fused cuDNN-style acceleration)
109 | """
110 | def __init__(self, input_dim, hidden_size, order, theta):
111 | super(LegendreMemoryUnit, self).__init__()
112 |
113 | self.hidden_size = hidden_size
114 | self.order = order
115 |
116 | self.lmucell = LMUCell(input_dim, hidden_size, order, theta)
117 |
118 | def forward(self, xt):
119 | outputs = []
120 |
121 | h0 = torch.zeros(xt.size(0),self.hidden_size).cuda()
122 | m0 = torch.zeros(xt.size(0),self.order).cuda()
123 | states = (h0,m0)
124 | for i in range(xt.size(1)):
125 | out, states = self.lmucell(xt[:,i,:], states)
126 | outputs += [out]
127 | return torch.stack(outputs).permute(1,0,2), states
128 |
129 |
--------------------------------------------------------------------------------
/experiments/AddingProblem/lmu.py:
--------------------------------------------------------------------------------
1 | """https://github.com/AbdouJaouhar/LMU-Legendre-Memory-Unit"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from sympy.matrices import Matrix, eye, zeros, ones, diag, GramSchmidt
6 | import numpy as np
7 | from functools import partial
8 | import torch.nn.functional as F
9 | import math
10 |
11 | from nengolib.signal import Identity, cont2discrete
12 | from nengolib.synapses import LegendreDelay
13 | 
14 | 
15 | '''
16 | LeCun uniform initialization:
17 | - tensor: the tensor to fill in place
18 | - fan_in: the input dimension size
19 | '''
20 | def lecun_uniform(tensor):
21 | fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
22 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
23 |
24 |
25 | class LMUCell(nn.Module):
26 |
27 | def __init__(self, input_size, hidden_size,
28 | order,
29 | theta=100, # relative to dt=1
30 | method='zoh',
31 | trainable_input_encoders=True,
32 | trainable_hidden_encoders=True,
33 | trainable_memory_encoders=True,
34 | trainable_input_kernel=True,
35 | trainable_hidden_kernel=True,
36 | trainable_memory_kernel=True,
37 | trainable_A=False,
38 | trainable_B=False,
39 | input_encoders_initializer=lecun_uniform,
40 | hidden_encoders_initializer=lecun_uniform,
41 | memory_encoders_initializer=partial(torch.nn.init.constant_, val=0),
42 | input_kernel_initializer=torch.nn.init.xavier_normal_,
43 | hidden_kernel_initializer=torch.nn.init.xavier_normal_,
44 | memory_kernel_initializer=torch.nn.init.xavier_normal_,
45 |
46 | hidden_activation='tanh',
47 | ):
48 | super(LMUCell, self).__init__()
49 |
50 | self.input_size = input_size
51 | self.hidden_size = hidden_size
52 | self.order = order
53 |
54 | if hidden_activation == 'tanh':
55 | self.hidden_activation = torch.tanh
56 | elif hidden_activation == 'relu':
57 | self.hidden_activation = torch.relu
58 | else:
59 | raise NotImplementedError("hidden activation '{}' is not implemented".format(hidden_activation))
60 |
61 | realizer = Identity()
62 | self._realizer_result = realizer(
63 | LegendreDelay(theta=theta, order=self.order))
64 | self._ss = cont2discrete(
65 | self._realizer_result.realization, dt=1., method=method)
66 | self._A = self._ss.A - np.eye(order) # puts into form: x += Ax
67 | self._B = self._ss.B
68 | self._C = self._ss.C
69 | assert np.allclose(self._ss.D, 0) # proper LTI
70 |
71 | self.input_encoders = nn.Parameter(torch.Tensor(1, input_size), requires_grad=trainable_input_encoders)
72 | self.hidden_encoders = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=trainable_hidden_encoders)
73 | self.memory_encoders = nn.Parameter(torch.Tensor(1, order), requires_grad=trainable_memory_encoders)
74 | self.input_kernel = nn.Parameter(torch.Tensor(hidden_size, input_size), requires_grad=trainable_input_kernel)
75 | self.hidden_kernel = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=trainable_hidden_kernel)
76 | self.memory_kernel = nn.Parameter(torch.Tensor(hidden_size, order), requires_grad=trainable_memory_kernel)
77 | self.AT = nn.Parameter(torch.Tensor(self._A), requires_grad=trainable_A)
78 | self.BT = nn.Parameter(torch.Tensor(self._B), requires_grad=trainable_B)
79 |
80 | # Initialize parameters
81 | input_encoders_initializer(self.input_encoders)
82 | hidden_encoders_initializer(self.hidden_encoders)
83 | memory_encoders_initializer(self.memory_encoders)
84 | input_kernel_initializer(self.input_kernel)
85 | hidden_kernel_initializer(self.hidden_kernel)
86 | memory_kernel_initializer(self.memory_kernel)
87 |
88 | def forward(self, input, hx):
89 |
90 | h, m = hx
91 |
92 | u = (F.linear(input, self.input_encoders) +
93 | F.linear(h, self.hidden_encoders) +
94 | F.linear(m, self.memory_encoders))
95 |
96 | m = m + F.linear(m, self.AT) + F.linear(u, self.BT)
97 |
98 | h = self.hidden_activation(
99 | F.linear(input, self.input_kernel) +
100 | F.linear(h, self.hidden_kernel) +
101 | F.linear(m, self.memory_kernel))
102 |
103 | return h, [h, m]
104 |
105 |
106 | class LegendreMemoryUnit(nn.Module):
107 | """
108 |     Implementation of the LMU using LMUCell, so it can be used like an LSTM or GRU in PyTorch (per-timestep Python loop; no fused cuDNN-style acceleration)
109 | """
110 | def __init__(self, input_dim, hidden_size, order, theta):
111 | super(LegendreMemoryUnit, self).__init__()
112 |
113 | self.hidden_size = hidden_size
114 | self.order = order
115 |
116 | self.lmucell = LMUCell(input_dim, hidden_size, order, theta)
117 |
118 | def forward(self, xt):
119 | outputs = []
120 |
121 | h0 = torch.zeros(xt.size(0),self.hidden_size).cuda()
122 | m0 = torch.zeros(xt.size(0),self.order).cuda()
123 | states = (h0,m0)
124 | for i in range(xt.size(1)):
125 | out, states = self.lmucell(xt[:,i,:], states)
126 | outputs += [out]
127 | return torch.stack(outputs).permute(1,0,2), states
128 |
129 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
5 |
6 |
12 |
13 |
14 | ## Overview
15 |
16 | 
17 |
18 | Here, we introduce DeepSITH, a network comprising biologically-inspired Scale-Invariant Temporal History (SITH) modules in series with dense connections between layers. SITH modules respond to their inputs with a geometrically-spaced set of time constants, enabling the DeepSITH network to learn problems along a continuum of time-scales.
19 |
20 | ## Installation
21 |
22 | The easiest way to install the DeepSITH module is with pip.
23 |
24 | pip install .
25 |
26 | ### Requirements
27 |
28 | DeepSITH requires at least PyTorch 1.8.1. It works with CUDA; please follow PyTorch's official installation instructions to set up CUDA support.
29 |
30 | ## DeepSITH
31 | DeepSITH is a pytorch module implementing the neurally inspired SITH representation of working memory for use in neural networks. The paper outlining the work detailed in this repository was published at NeurIPS 2021.
32 |
33 | Jacques, B., Tiganj, Z., Howard, M., & Sederberg, P. (2021, December 6). DeepSITH: Efficient Learning via Decomposition of What and When Across Time Scales. Advances in Neural Information Processing Systems.
34 |
35 | Primarily, this module utilizes the Scale-Invariant Temporal History (SITH) representation. With SITH, we are able to compress the history of a time series in the same way that human working memory might. For more information, please refer to the paper.
36 |
37 | ### DeepSITH use
38 |
39 | The DeepSITH module in pytorch initializes as a series of DeepSITH layers, parameterized by the argument `layer_params`, which is a list of dictionaries. Below is an example initializing a two-layer DeepSITH module whose input time series has only 1 feature.
40 |
41 |     import torch
42 |     from deepsith import DeepSITH
43 |     from torch import nn
44 | # Tensor Type. Use torch.cuda.FloatTensor to put all SITH math
45 | # on the GPU.
46 | ttype = torch.FloatTensor
47 |
48 | sith_params1 = {"in_features":1,
49 | "tau_min":1, "tau_max":25.0, 'buff_max':40,
50 | "k":84, 'dt':1, "ntau":15, 'g':.0,
51 | "ttype":ttype, 'batch_norm':True,
52 | "hidden_size":35, "act_func":nn.ReLU()}
53 | sith_params2 = {"in_features":sith_params1['hidden_size'],
54 | "tau_min":1, "tau_max":100.0, 'buff_max':175,
55 | "k":40, 'dt':1, "ntau":15, 'g':.0,
56 | "ttype":ttype, 'batch_norm':True,
57 | "hidden_size":35, "act_func":nn.ReLU()}
58 | lp = [sith_params1, sith_params2]
59 | deepsith_layers = DeepSITH(layer_params=lp, dropout=0.2)
60 |
61 | Here, the first layer has only 15 taustars, spanning `tau_min=1.0` to `tau_max=25.0`. The second layer is set up to go from `1.0` to `100.0`, which gives it 4 times the temporal range. We found this logarithmic increase in temporal range from layer to layer to work well for the experiments in this repository.
62 |
63 | The DeepSITH module expects an input signal of size (batch_size, 1, sith_params1["in_features"], Time).
64 |
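For example, a forward pass through the two-layer module defined above (a sketch; the output shape follows from the iSITH convolution and each layer's dense projection):

    sig = torch.rand(8, 1, 1, 200).type(ttype)  # (batch_size, 1, in_features, Time)
    out = deepsith_layers(sig)                  # (batch_size, Time, hidden_size) = (8, 200, 35)
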
65 | If you want to use **only** the SITH module, which is a part of any DeepSITH layer, you can initialize a SITH with the following parameters. Note that these parameters are also used in the dictionaries above.
66 |
67 | #### SITH Parameters
68 | - tau_min: float
69 | The center of the temporal receptive field for the first taustar produced.
70 | - tau_max: float
71 | The center of the temporal receptive field for the last taustar produced.
72 | - buff_max: int
73 |     How far into the past the filters extend. NOTE: In order to
74 |     achieve as few edge effects as possible, buff_max needs to be bigger than
75 |     tau_max (by an amount dependent on k), so that the filters have enough
76 |     time to decay very close to 0.0. Plot the filters and you will see them go to 0.
77 | - k: int
78 |     Temporal specificity of the taustars. The higher this number, the
79 |     narrower each taustar's temporal receptive field.
80 | - ntau: int
81 | Number of taustars produced, spread out logarithmically.
82 | - dt: float
83 |     The time step of the model. There will be int(buff_max/dt) filter taps per
84 |     taustar. Essentially, this is the base rate at which information is presented to the model.
85 | - g: float
86 | Typically between 0 and 1. This parameter is the scaling factor of the output
87 | of the module. If set to 1, the output amplitude for a delta function will be
88 | identical through time. If set to 0, the amplitude will decay into the past,
89 |     getting smaller and smaller. This value should be picked on an
90 |     application-by-application basis.
91 | - ttype: Torch Tensor
92 | This is the type we set the internal mechanism of the model to before running.
93 | In order to calculate the filters, we must use a DoubleTensor, but this is no
94 | longer necessary after they are calculated. By default we set the filters to
95 | be FloatTensors. NOTE: If you plan to use CUDA, you need to pass in a
96 | cuda.FloatTensor as the ttype, as using .cuda() will not put these filters on
97 | the gpu.
98 |
99 | Initializing SITH will generate several attributes that depend heavily on the values of the parameters.
100 |
101 | - c: float
102 |     `c = (tau_max/tau_min)**(1./(ntau-1))-1`. This describes how the spacing
103 |     between taustars grows.
104 | - tau_star: DoubleTensor
105 | `tau_star = tau_min*(1+c)**torch.arange(ntau)`. This is the array filled with all of the
106 | centers of all the tau_star receptive fields.
107 | - filters: ttype
108 |     The convolutional filters that produce the SITH output. They are applied
109 |     as a convolution to the input time-series.
110 |
111 | Importantly, this module should be embedded in a larger pytorch model, where the final layers transform the output of the last **SITH->Dense Layer** into the shape required for a particular task. We, for instance, use a DeepSITH_Classifier model for most of the tasks within this repository.
112 |
113 | class DeepSITH_Classifier(nn.Module):
114 | def __init__(self, out_features, layer_params, dropout=.5):
115 | super(DeepSITH_Classifier, self).__init__()
116 | last_hidden = layer_params[-1]['hidden_size']
117 | self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)
118 | self.to_out = nn.Linear(last_hidden, out_features)
119 | def forward(self, inp):
120 | x = self.hs(inp)
121 | x = self.to_out(x)
122 | return x
123 |
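A quick shape check of the classifier (a sketch; `out_features=10` stands in for a hypothetical 10-class task, and the prediction is read from the final time step). Note that DeepSITH pops several keys out of the layer_params dictionaries during construction, so rebuild `lp` before constructing a second model:

    model = DeepSITH_Classifier(out_features=10, layer_params=lp, dropout=.2)
    sig = torch.rand(8, 1, 1, 200).type(ttype)  # (batch_size, 1, in_features, Time)
    scores = model(sig)[:, -1, :]               # class scores at the last time step: (8, 10)
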
124 | ## Examples
125 |
126 | The `experiments` folder contains the experiments included in the paper, each implemented as a Jupyter notebook, along with everything needed to recreate the results therein. We have also included everything needed to recreate the figures from the paper; with minor file-name changes, the same code will plot your own results instead.
127 |
128 |
129 |
--------------------------------------------------------------------------------
/experiments/psMNIST/visualization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-05-12T18:22:42.830086Z",
9 | "start_time": "2021-05-12T18:22:42.658411Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-05-12T18:22:43.466784Z",
23 | "start_time": "2021-05-12T18:22:42.878529Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pylab as plt\n",
29 | "from matplotlib import gridspec\n",
30 | "import csv\n",
31 | "import numpy as np\n",
32 | "import pandas as pd\n",
33 | "import os\n",
34 | "import torch\n",
35 | "from torchvision import transforms\n",
36 | "from torchvision import datasets\n",
37 | "import seaborn as sn\n",
38 | "sn.set_context('poster')"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {
45 | "ExecuteTime": {
46 | "end_time": "2021-05-12T18:22:43.492461Z",
47 | "start_time": "2021-05-12T18:22:43.469951Z"
48 | }
49 | },
50 | "outputs": [],
51 | "source": [
52 | "dats = pd.read_csv(os.path.join('perf', 'smnist_deepsith_11.csv'))\n",
53 | "dats.columns = ['loss', 'test_perf', 'epoch', 'presnum', 'perf']\n",
54 | "maxpres = 60000\n",
55 | "dats['presnum_epoch'] = ((dats.presnum*64) + maxpres*dats.epoch)/maxpres\n",
56 | "test_dats = pd.read_csv(\"perf/smnist_deepsith_test_11.csv\")\n",
57 | "test_dats['epoch'] = np.arange(test_dats.shape[0]) + 1\n",
58 | "test_dats['test'] = test_dats['0']"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {
65 | "ExecuteTime": {
66 | "end_time": "2021-05-12T18:22:43.509426Z",
67 | "start_time": "2021-05-12T18:22:43.493696Z"
68 | }
69 | },
70 | "outputs": [],
71 | "source": [
72 | "dato = pd.read_csv(os.path.join('perf', 'pmnist_deepsith_78.csv'))\n",
73 | "dato.columns = ['loss', 'epoch', 'presnum', 'perf']\n",
74 | "maxpres = 60000\n",
75 | "dato['presnum_epoch'] = ((dato.presnum*64) + maxpres*dato.epoch)/maxpres\n",
76 | "test_dato = pd.read_csv(\"perf/pmnist_deesith_test_78.csv\")\n",
77 | "test_dato.epoch = test_dato.epoch+1\n",
78 | "test_dato.head()"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {
85 | "ExecuteTime": {
86 | "end_time": "2021-05-12T18:23:02.475471Z",
87 | "start_time": "2021-05-12T18:23:02.391851Z"
88 | }
89 | },
90 | "outputs": [],
91 | "source": [
92 | "norm = transforms.Normalize((.1307,), (.3081,), )\n",
93 | "batch_size = 400\n",
94 | "transform = transforms.Compose([transforms.ToTensor(),\n",
95 | " transforms.Normalize((.1307,), (.3081,))\n",
96 | " ])\n",
97 | "ds1 = datasets.MNIST('../data', train=True, download=True, transform=transform)\n",
98 | "train_loader=torch.utils.data.DataLoader(ds1,batch_size=batch_size, \n",
99 | " num_workers=1, pin_memory=True, shuffle=True)\n"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {
106 | "ExecuteTime": {
107 | "end_time": "2021-05-12T18:22:49.041702Z",
108 | "start_time": "2021-05-12T18:22:49.032510Z"
109 | }
110 | },
111 | "outputs": [],
112 | "source": [
113 | "# Same seed and supposed Permutation as the coRNN paper\n",
114 | "torch.manual_seed(12008)\n",
115 | "permute = torch.randperm(784)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {
122 | "ExecuteTime": {
123 | "end_time": "2021-05-12T18:22:49.856527Z",
124 | "start_time": "2021-05-12T18:22:49.847567Z"
125 | }
126 | },
127 | "outputs": [],
128 | "source": [
129 | "dat = next(enumerate(train_loader))[1]\n",
130 | "dat[0].shape, dat[1].shape"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {
137 | "ExecuteTime": {
138 | "end_time": "2021-05-12T18:22:54.876834Z",
139 | "start_time": "2021-05-12T18:22:54.867463Z"
140 | }
141 | },
142 | "outputs": [],
143 | "source": [
144 | "fig_dat = []\n",
145 | "for i in range(10):\n",
146 | " fig_dat.append(dat[0][dat[1] == i][:10])\n",
147 | "fig_dat = torch.cat(fig_dat, dim=0)\n",
148 | "fig_dat = fig_dat.view(100,-1)#[:, permute]\n"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {
155 | "ExecuteTime": {
156 | "end_time": "2021-05-12T18:23:07.689070Z",
157 | "start_time": "2021-05-12T18:23:06.492357Z"
158 | }
159 | },
160 | "outputs": [],
161 | "source": [
162 | "linew = 4\n",
163 | "with sn.plotting_context(\"notebook\", font_scale=2.8):\n",
164 | " fig2 = plt.figure(figsize=(20,18), )\n",
165 | " spec2 = gridspec.GridSpec(nrows=4, ncols=2, wspace=0.05, figure=fig2)\n",
166 | "\n",
167 | "\n",
168 | " ax = fig2.add_subplot(spec2[-2, 1])\n",
169 | " #fig, ax= plt.subplots(2,2,sharex='col', figsize=(12,10), sharey='row', )\n",
170 | " sn.lineplot(data=dato, x=dato.presnum_epoch, y='perf', ax=ax, linewidth=linew,\n",
171 | " color='darkblue', )\n",
172 | " sn.lineplot(data=test_dato, x='epoch', y='test', ax=ax, linewidth=linew,\n",
173 | " )\n",
174 | " ax.grid(True)\n",
175 | " ax.legend([\"Training\", \"Test\", \n",
176 | " ],loc='lower right')\n",
177 | " ax.set_ylabel('Accuracy')\n",
178 | " ax.set_xlabel('')\n",
179 | " ax.yaxis.tick_right()\n",
180 | " ax.yaxis.set_label_position(\"right\")\n",
181 | " plt.setp(ax.get_xticklabels(), visible=False)\n",
182 | " ax.set_xlim(0,90)#)\n",
183 | " ax.set_ylim(.99, 1.0005)\n",
184 | "\n",
185 | " ax = fig2.add_subplot(spec2[-1, 1], sharex=ax)\n",
186 | " sn.lineplot(data=dato, x=dato.presnum_epoch, y='loss', ax=ax, linewidth=linew,\n",
187 | " color='darkblue', legend=False)\n",
188 | " ax.set_ylabel('Loss')\n",
189 | "\n",
190 | " ax.set_xlabel('Epoch')\n",
191 | " ax.yaxis.tick_right()\n",
192 | " \n",
193 | " ax.yaxis.set_label_position(\"right\")\n",
194 | " ax.set_ylim(0,.01)#)\n",
195 | " ax.grid(True)\n",
196 | "\n",
197 | "\n",
198 | " \n",
199 | " ax = fig2.add_subplot(spec2[-2, 0])\n",
200 | " sn.lineplot(data=dats, x=dats.presnum_epoch, y='perf', ax=ax, linewidth=linew,\n",
201 | " color='darkblue', )\n",
202 | " sn.lineplot(data=test_dats, x='epoch', y='test', ax=ax, linewidth=linew,\n",
203 | " )\n",
204 | " ax.legend([\"Training\", \"Test\", \n",
205 | " ],loc='lower right')\n",
206 | " plt.setp(ax.get_xticklabels(), visible=False)\n",
207 | " ax.grid(True)\n",
208 | " ax.set_ylabel('Accuracy')\n",
209 | " ax.set_xlabel('')\n",
210 | " ax.set_xlim(0,60)#)\n",
211 | " ax.set_ylim(.99, 1.0005)\n",
212 | "\n",
213 | " ax = fig2.add_subplot(spec2[-1, 0], sharex=ax)\n",
214 | " sn.lineplot(data=dats, x=dats.presnum_epoch, y='loss', ax=ax, linewidth=linew,\n",
215 | " color='darkblue', legend=False)\n",
216 | " ax.set_ylabel('Loss')\n",
217 | " ax.set_xlabel('Epoch')\n",
218 | " ax.set_ylim(0,.01)#)\n",
219 | " ax.grid(True)\n",
220 | "\n",
221 | "\n",
222 | " \n",
223 | " ax = fig2.add_subplot(spec2[:-2, 0])\n",
224 | " ax.imshow(fig_dat.detach().cpu().numpy(), aspect='auto')\n",
225 | " ax.tick_params(axis=u'both', which=u'both',length=0)\n",
226 | " ax.set_yticks(np.arange(10,100,10))\n",
227 | " ax.set_xticks([])\n",
228 | " plt.setp(ax.get_yticklabels(), visible=False)\n",
229 | " ax.grid(True)\n",
230 | " ax.set_title('sMNIST')\n",
231 | "\n",
232 | " ax = fig2.add_subplot(spec2[:-2, 1])\n",
233 | " ax.imshow(fig_dat[:, permute].detach().cpu().numpy(), aspect='auto')\n",
234 | " ax.tick_params(axis=u'both', which=u'both',length=0)\n",
235 | " ax.set_yticks(np.arange(10,100,10))\n",
236 | " ax.set_xticks([])\n",
237 | " ax.yaxis.tick_right()\n",
238 | " plt.setp(ax.get_yticklabels(), visible=False)\n",
239 | " ax.grid(True)\n",
240 | " ax.set_title('psMNIST')\n",
241 | "\n",
242 | " plt.savefig('MNIST.pdf',\n",
243 | " bbox='tight',\n",
244 | " edgecolor=fig2.get_edgecolor(),\n",
245 | " facecolor=fig2.get_facecolor(),\n",
246 | " dpi=150\n",
247 | " )\n",
248 | " plt.savefig('MNIST.svg',\n",
249 | " bbox='tight',\n",
250 | " edgecolor=fig2.get_edgecolor(),\n",
251 | " facecolor=fig2.get_facecolor(),\n",
252 | " dpi=150\n",
253 | " )\n",
254 | " \n",
255 | "\n",
256 | "\n",
257 | "\n"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": []
266 | }
267 | ],
268 | "metadata": {
269 | "kernelspec": {
270 | "display_name": "Python 3",
271 | "language": "python",
272 | "name": "python3"
273 | },
274 | "language_info": {
275 | "codemirror_mode": {
276 | "name": "ipython",
277 | "version": 3
278 | },
279 | "file_extension": ".py",
280 | "mimetype": "text/x-python",
281 | "name": "python",
282 | "nbconvert_exporter": "python",
283 | "pygments_lexer": "ipython3",
284 | "version": "3.6.10"
285 | },
286 | "toc": {
287 | "nav_menu": {},
288 | "number_sections": true,
289 | "sideBar": true,
290 | "skip_h1_title": false,
291 | "title_cell": "Table of Contents",
292 | "title_sidebar": "Contents",
293 | "toc_cell": false,
294 | "toc_position": {},
295 | "toc_section_display": true,
296 | "toc_window_display": false
297 | }
298 | },
299 | "nbformat": 4,
300 | "nbformat_minor": 4
301 | }
302 |
--------------------------------------------------------------------------------
/experiments/MackeyGlass/MackeyGlass-LSTM-increasing_tau.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-01-31T16:16:49.475627Z",
9 | "start_time": "2021-01-31T16:16:49.297073Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-01-31T16:16:50.214114Z",
23 | "start_time": "2021-01-31T16:16:49.515165Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import torch\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import torchvision\n",
31 | "import nengolib\n",
32 | "import numpy as np\n",
33 | "import torch.nn as nn\n",
34 | "import torch.nn.functional as F\n",
35 | "from tqdm import tqdm_notebook\n",
36 | "import PIL\n",
37 | "from torch.nn.utils import weight_norm\n",
38 | "\n",
39 | "from os.path import join\n",
40 | "import scipy.special\n",
41 | "import pandas as pd\n",
42 | "import seaborn as sn\n",
43 | "import scipy\n",
44 | "from scipy.spatial.distance import euclidean\n",
45 | "from scipy.interpolate import interp1d\n",
46 | "from tqdm.notebook import tqdm\n",
47 | "import random\n",
48 | "from csv import DictWriter\n",
49 | "# if gpu is to be used\n",
50 | "use_cuda = torch.cuda.is_available()\n",
51 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
52 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n",
53 | "\n",
54 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
55 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
56 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
57 | "ttype = FloatTensor\n",
58 | "\n",
59 | "import seaborn as sns\n",
60 | "print(use_cuda)\n",
61 | "import pickle\n"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "# Load Stimuli"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {
75 | "ExecuteTime": {
76 | "end_time": "2021-01-31T16:16:52.360839Z",
77 | "start_time": "2021-01-31T16:16:52.350178Z"
78 | }
79 | },
80 | "outputs": [],
81 | "source": [
82 | "import collections\n",
83 | "\n",
84 | "def mackey_glass(sample_len=1000, tau=17, delta_t=10, seed=None, n_samples=1):\n",
85 | " # Adapted from https://github.com/mila-iqia/summerschool2015/blob/master/rnn_tutorial/synthetic.py\n",
86 | " '''\n",
87 | " mackey_glass(sample_len=1000, tau=17, seed = None, n_samples = 1) -> input\n",
88 | " Generate the Mackey Glass time-series. Parameters are:\n",
89 | " - sample_len: length of the time-series in timesteps. Default is 1000.\n",
90 | " - tau: delay of the MG - system. Commonly used values are tau=17 (mild \n",
91 | " chaos) and tau=30 (moderate chaos). Default is 17.\n",
92 | " - seed: to seed the random generator, can be used to generate the same\n",
93 | " timeseries at each invocation.\n",
94 | " - n_samples : number of samples to generate\n",
95 | " '''\n",
96 | " history_len = tau * delta_t \n",
97 | " # Initial conditions for the history of the system\n",
98 | " timeseries = 1.2\n",
99 | " \n",
100 | " if seed is not None:\n",
101 | " np.random.seed(seed)\n",
102 | "\n",
103 | " samples = []\n",
104 | "\n",
105 | " for _ in range(n_samples):\n",
106 | " history = collections.deque(1.2 * np.ones(history_len) + 0.2 * \\\n",
107 | " (np.random.rand(history_len) - 0.5))\n",
108 | " # Preallocate the array for the time-series\n",
109 | " inp = np.zeros((sample_len,1))\n",
110 | " \n",
111 | " for timestep in range(sample_len):\n",
112 | " for _ in range(delta_t):\n",
113 | " xtau = history.popleft()\n",
114 | " history.append(timeseries)\n",
115 | " timeseries = history[-1] + (0.2 * xtau / (1.0 + xtau ** 10) - \\\n",
116 | " 0.1 * history[-1]) / delta_t\n",
117 | " inp[timestep] = timeseries\n",
118 | " \n",
119 | " # Squash timeseries through tanh\n",
120 | " inp = np.tanh(inp - 1)\n",
121 | " samples.append(inp)\n",
122 | " return samples\n",
123 | "\n",
124 | "\n",
125 | "def generate_data(n_batches, length, split=0.5, seed=0,\n",
126 | " predict_length=15, tau=17, washout=100, delta_t=1,\n",
127 | " center=True):\n",
128 | " X = np.asarray(mackey_glass(\n",
129 | " sample_len=length+predict_length+washout, tau=tau,\n",
130 | " seed=seed, n_samples=n_batches))\n",
131 | " X = X[:, washout:, :]\n",
132 | " cutoff = int(split*n_batches)\n",
133 | " if center:\n",
134 | " X -= np.mean(X) # global mean over all batches, approx -0.066\n",
135 | " Y = X[:, predict_length:, :]\n",
136 | " X = X[:, :-predict_length, :]\n",
137 | " assert X.shape == Y.shape\n",
138 | " return ((X[:cutoff], Y[:cutoff]),\n",
139 | " (X[cutoff:], Y[cutoff:]))"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "metadata": {
146 | "ExecuteTime": {
147 | "end_time": "2020-12-28T14:39:38.387303Z",
148 | "start_time": "2020-12-28T14:39:36.507391Z"
149 | }
150 | },
151 | "outputs": [],
152 | "source": [
153 | "(train_X, train_Y), (test_X, test_Y) = generate_data(2, 5000)\n",
154 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n",
155 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
156 | "\n",
157 | "dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n",
158 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n",
159 | "\n",
160 | "print(train_X.shape, train_Y.shape, test_X.shape)"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "## Setup for Model"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {
174 | "ExecuteTime": {
175 | "end_time": "2021-01-31T16:17:16.816497Z",
176 | "start_time": "2021-01-31T16:17:16.806348Z"
177 | }
178 | },
179 | "outputs": [],
180 | "source": [
181 | "\n",
182 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n",
183 | " loss_buffer_size=800, batch_size=4, device='cuda',\n",
184 | " prog_bar=None, tau=17, pred_len=15):\n",
185 | " \n",
186 | " assert(loss_buffer_size%batch_size==0)\n",
187 | " \n",
188 | " losses = []\n",
189 | " last_test_perf = 0\n",
190 | " best_test_perf = -1\n",
191 | " \n",
192 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
193 | " model.train()\n",
194 | " data = data.to(device).transpose(1,0)\n",
195 | " target = target.to(device)\n",
196 | " optimizer.zero_grad()\n",
197 | " out = model(data)\n",
198 | " loss = loss_func(out,\n",
199 | " target)\n",
200 | "\n",
201 | " loss.backward()\n",
202 | " optimizer.step()\n",
203 | "\n",
204 | " losses.append(loss.detach().cpu().numpy())\n",
205 | " losses = losses[int(-loss_buffer_size/batch_size):]\n",
206 | " \n",
207 | " if ((batch_idx*batch_size)%loss_buffer_size == 0):\n",
208 | " loss_track = {}\n",
209 | " last_test_perf = test_model(model, 'cuda', test_loader, \n",
210 | " )\n",
211 | " loss_track['avg_loss'] = np.mean(losses)\n",
212 | " loss_track['last_test'] = last_test_perf\n",
213 | " loss_track['epoch'] = epoch\n",
214 | " loss_track['batch_idx'] = batch_idx\n",
215 | " loss_track['tau'] = tau\n",
216 | " loss_track['pred_len'] = pred_len\n",
217 | "\n",
218 | " with open(perf_file, 'a+') as fp:\n",
219 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
220 | " if fp.tell() == 0:\n",
221 | " csv_writer.writeheader()\n",
222 | " csv_writer.writerow(loss_track)\n",
223 | " fp.flush()\n",
224 | " if best_test_perf < last_test_perf:\n",
225 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
226 | " best_test_perf = last_test_perf\n",
227 | " if not (prog_bar is None):\n",
228 | " # Update progress_bar\n",
229 | " s = \"{}:{} Loss: {:.4f},valid: {:.4f}\"\n",
230 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n",
231 | " last_test_perf] \n",
232 | " s = s.format(*format_list)\n",
233 | " prog_bar.set_description(s)\n",
234 | " \n",
235 | "def test_model(model, device, test_loader):\n",
236 | " # Test the Model\n",
237 | " nrmsd = []\n",
238 | " with torch.no_grad():\n",
239 | " for x, y in test_loader:\n",
240 | " data = x.to(device).transpose(1,0)\n",
241 | " target = y.to(device)\n",
242 | " optimizer.zero_grad()\n",
243 | " out = model(data)\n",
244 | " nrmsd.append(nengolib.signal.nrmse(out.detach().cpu().numpy().flatten(), target=target.detach().cpu().numpy().flatten()))\n",
245 | " perf = np.array(nrmsd).mean()\n",
246 | " return perf"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "metadata": {
253 | "ExecuteTime": {
254 | "end_time": "2021-01-31T16:18:41.596917Z",
255 | "start_time": "2021-01-31T16:18:41.593831Z"
256 | }
257 | },
258 | "outputs": [],
259 | "source": [
260 | "class LSTM_Predictor(nn.Module):\n",
261 | " def __init__(self, out_features, lstm_params):\n",
262 | " super(LSTM_Predictor, self).__init__()\n",
263 | " self.lstm = nn.LSTM(**lstm_params)\n",
264 | " self.to_out = nn.Linear(lstm_params['hidden_size'], \n",
265 | " out_features)\n",
266 | " def forward(self, inp):\n",
267 | " x = self.lstm(inp)[0].transpose(1,0)\n",
268 | " x = torch.tanh(self.to_out(x))\n",
269 | " return x"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": null,
275 | "metadata": {
276 | "ExecuteTime": {
277 | "end_time": "2021-01-31T18:18:18.908819Z",
278 | "start_time": "2021-01-31T16:18:42.806246Z"
279 | },
280 | "scrolled": true
281 | },
282 | "outputs": [],
283 | "source": [
284 | "start_tau = 17\n",
285 | "start_pd = 15\n",
286 | "diffs = [1, 2, 3, 4, 5]\n",
287 | "for diff in diffs:\n",
288 | " (train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000, tau=start_tau*diff, \n",
289 | " predict_length=start_pd*diff)\n",
290 | " dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n",
291 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
292 | "\n",
293 | " dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n",
294 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n",
295 | "\n",
296 | " print(train_X.shape, train_Y.shape, test_X.shape)\n",
297 | "\n",
298 | " lstm_params = dict(input_size=1,\n",
299 | " hidden_size=25, \n",
300 | " num_layers=4)\n",
301 | " model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n",
302 | "\n",
303 | "\n",
304 | " optimizer = torch.optim.Adam(model.parameters())\n",
305 | " loss_func = nn.MSELoss()\n",
306 | " \n",
307 | " epochs = 1000\n",
308 | " batch_size = 32\n",
309 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
310 | " last_perf = 1000\n",
311 | " for e in progress_bar:\n",
312 | " train(model, ttype, dataset, dataset_valid, \n",
313 | " optimizer, loss_func, batch_size=batch_size, loss_buffer_size=64,\n",
314 | " epoch=e, perf_file=join('perf','mackeyglass_lstm_ratio_1.csv'),\n",
315 | " prog_bar=progress_bar, tau=start_tau*diff, pred_len=start_pd*diff)"
316 | ]
317 | }
318 | ],
319 | "metadata": {
320 | "kernelspec": {
321 | "display_name": "Python 3",
322 | "language": "python",
323 | "name": "python3"
324 | },
325 | "language_info": {
326 | "codemirror_mode": {
327 | "name": "ipython",
328 | "version": 3
329 | },
330 | "file_extension": ".py",
331 | "mimetype": "text/x-python",
332 | "name": "python",
333 | "nbconvert_exporter": "python",
334 | "pygments_lexer": "ipython3",
335 | "version": "3.6.10"
336 | },
337 | "toc": {
338 | "nav_menu": {},
339 | "number_sections": true,
340 | "sideBar": true,
341 | "skip_h1_title": false,
342 | "title_cell": "Table of Contents",
343 | "title_sidebar": "Contents",
344 | "toc_cell": false,
345 | "toc_position": {},
346 | "toc_section_display": true,
347 | "toc_window_display": false
348 | }
349 | },
350 | "nbformat": 4,
351 | "nbformat_minor": 4
352 | }
353 |
--------------------------------------------------------------------------------
/experiments/psMNIST/sMNIST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-03-15T13:46:07.365933Z",
9 | "start_time": "2021-03-15T13:46:07.144247Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-03-15T13:46:08.069258Z",
23 | "start_time": "2021-03-15T13:46:07.366877Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pyplot as plt\n",
29 | "\n",
30 | "import torch\n",
31 | "from torch.optim.lr_scheduler import StepLR\n",
32 | "from sklearn.metrics import confusion_matrix\n",
33 | "\n",
34 | "import numpy as np\n",
35 | "\n",
36 | "import torch.nn as nn\n",
37 | "\n",
38 | "from tqdm import tqdm_notebook\n",
39 | "\n",
40 | "from torchvision import transforms\n",
41 | "from torchvision import datasets\n",
42 | "from os.path import join\n",
43 | "\n",
44 | "from deepsith import DeepSITH\n",
45 | "\n",
46 | "from tqdm.notebook import tqdm\n",
47 | "\n",
48 | "import random\n",
49 | "\n",
50 | "from csv import DictWriter\n",
51 | "# if gpu is to be used\n",
52 | "use_cuda = torch.cuda.is_available()\n",
53 | "\n",
54 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
55 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
56 | "\n",
57 | "ttype =FloatTensor\n",
58 | "\n",
59 | "import seaborn as sn\n",
60 | "print(use_cuda)\n",
61 | "import pickle\n",
62 | "\n",
63 | "sn.set_context(\"poster\")"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {
70 | "ExecuteTime": {
71 | "end_time": "2021-03-15T13:46:08.073469Z",
72 | "start_time": "2021-03-15T13:46:08.070605Z"
73 | }
74 | },
75 | "outputs": [],
76 | "source": [
77 | "class DeepSITH_Classifier(nn.Module):\n",
78 | " def __init__(self, out_features, layer_params, dropout=.1):\n",
79 | " super(DeepSITH_Classifier, self).__init__()\n",
80 | " last_hidden = layer_params[-1]['hidden_size']\n",
81 | " self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)\n",
82 | " self.to_out = nn.Linear(last_hidden, out_features)\n",
83 | " def forward(self, inp):\n",
84 | " x = self.hs(inp)\n",
85 | " x = self.to_out(x)\n",
86 | " return x"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {
93 | "ExecuteTime": {
94 | "end_time": "2021-03-15T13:46:08.083443Z",
95 | "start_time": "2021-03-15T13:46:08.074275Z"
96 | }
97 | },
98 | "outputs": [],
99 | "source": [
100 | "import scipy.optimize as opt\n",
101 | "from deepsith import iSITH\n",
102 | "def min_fun(x, *args):\n",
103 | " ntau = int(x[0])\n",
104 | " k = int(x[1])\n",
105 | " if k < 4 or k>125:\n",
106 | " return np.inf\n",
107 | " tau_min = args[0]\n",
108 | " tau_max = args[1] \n",
109 | " ev = iSITH(tau_min=tau_min, tau_max=tau_max, buff_max=tau_max*5, k=k, ntau=ntau, dt=1, g=1.0) \n",
110 | " std_0 = ev.filters[:, 0, 0, :].detach().cpu().T.numpy()[::-1].sum(1)[int(tau_min):int(tau_max)].std()\n",
111 | " std_1 = ev.filters[:, 0, 0, :].detach().cpu().T.numpy()[::-1, ::2].sum(1)[int(tau_min):int(tau_max)].std() \n",
112 | " to_min = std_0/std_1\n",
113 | " return to_min"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
120 | "# Load Stimuli"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {
127 | "ExecuteTime": {
128 | "end_time": "2021-03-15T13:46:36.407379Z",
129 | "start_time": "2021-03-15T13:46:36.405404Z"
130 | }
131 | },
132 | "outputs": [],
133 | "source": [
134 | "norm = transforms.Normalize((.1307,), (.3081,), )"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "metadata": {
141 | "ExecuteTime": {
142 | "end_time": "2021-03-15T13:46:37.836490Z",
143 | "start_time": "2021-03-15T13:46:37.822290Z"
144 | }
145 | },
146 | "outputs": [],
147 | "source": [
148 | "batch_size = 64\n",
149 | "transform = transforms.Compose([transforms.ToTensor(),\n",
150 | " transforms.Normalize((.1307,), (.3081,))\n",
151 | " ])\n",
152 | "ds1 = datasets.MNIST('../data', train=True, download=True, transform=transform)\n",
153 | "ds2 = datasets.MNIST('../data', train=False, download=True, transform=transform)\n",
154 | "train_loader=torch.utils.data.DataLoader(ds1,batch_size=batch_size, \n",
155 | " num_workers=1, pin_memory=True, shuffle=True)\n",
156 | "test_loader=torch.utils.data.DataLoader(ds2, batch_size=batch_size, \n",
157 | " num_workers=1, pin_memory=True, shuffle=True)"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": null,
163 | "metadata": {
164 | "ExecuteTime": {
165 | "end_time": "2021-03-15T13:44:18.485182Z",
166 | "start_time": "2021-03-15T13:44:14.672022Z"
167 | }
168 | },
169 | "outputs": [],
170 | "source": [
171 | "test = next(iter(test_loader))[0]\n",
172 | "\n",
173 | "plt.imshow(test[0].reshape(-1).reshape(28,28))\n",
174 | "\n",
175 | "plt.colorbar()"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | "# Define test and train"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {
189 | "ExecuteTime": {
190 | "end_time": "2021-03-15T13:46:11.244757Z",
191 | "start_time": "2021-03-15T13:46:11.235002Z"
192 | }
193 | },
194 | "outputs": [],
195 | "source": [
196 | "\n",
197 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n",
198 | " loss_buffer_size=800, batch_size=4, device='cuda',\n",
199 | " prog_bar=None, last_test_perf=0):\n",
200 | " \n",
201 | " assert(loss_buffer_size%batch_size==0)\n",
202 | "\n",
203 | " \n",
204 | " losses = []\n",
205 | " perfs = []\n",
206 | " best_test_perf = -1\n",
207 | " \n",
208 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
209 | " model.train()\n",
210 | " data = data.to(device).view(data.shape[0],1,1,-1)\n",
211 | " target = target.to(device)\n",
212 | " optimizer.zero_grad()\n",
213 | " out = model(data)\n",
214 | " loss = loss_func(out[:, -1, :],\n",
215 | " target)\n",
216 | "\n",
217 | " loss.backward()\n",
218 | " optimizer.step()\n",
219 | "\n",
220 | " perfs.append((torch.argmax(out[:, -1, :], dim=-1) == \n",
221 | " target).sum().item())\n",
222 | " perfs = perfs[int(-loss_buffer_size/batch_size):]\n",
223 | " losses.append(loss.detach().cpu().numpy())\n",
224 | " losses = losses[int(-loss_buffer_size/batch_size):]\n",
225 | " if not (prog_bar is None):\n",
226 | " # Update progress_bar\n",
227 | " s = \"{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}\"\n",
228 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n",
229 | " np.sum(perfs)/((len(perfs))*batch_size), last_test_perf] \n",
230 | " s = s.format(*format_list)\n",
231 | " prog_bar.set_description(s)\n",
232 | " \n",
233 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n",
234 | " loss_track = {}\n",
235 | " #last_test_perf = test(model, 'cuda', test_loader, \n",
236 | " # batch_size=batch_size, \n",
237 | " # )\n",
238 | " loss_track['avg_loss'] = np.mean(losses)\n",
239 | " loss_track['last_test'] = last_test_perf\n",
240 | " loss_track['epoch'] = epoch\n",
241 | " loss_track['batch_idx'] = batch_idx\n",
242 | " loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)\n",
243 | " with open(perf_file, 'a+') as fp:\n",
244 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
245 | " if fp.tell() == 0:\n",
246 | " csv_writer.writeheader()\n",
247 | " csv_writer.writerow(loss_track)\n",
248 | " fp.flush()\n",
249 | " #if best_test_perf < last_test_perf:\n",
250 | " # torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
251 | " # best_test_perf = last_test_perf\n",
252 | "\n",
253 | " \n",
254 | "def test(model, device, test_loader, batch_size=4):\n",
255 | " model.eval()\n",
256 | " correct = 0\n",
257 | " count = 0\n",
258 | " with torch.no_grad():\n",
259 | " for data, target in test_loader:\n",
260 | " data = data.to(device).view(data.shape[0],1,1,-1)\n",
261 | " target = target.to(device)\n",
262 | " \n",
263 | " out = model(data)\n",
264 | " pred = out[:, -1].argmax(dim=-1, keepdim=True)\n",
265 | " correct += pred.eq(target.view_as(pred)).sum().item()\n",
266 | " count += 1\n",
267 | " return correct / len(test_loader.dataset)"
268 | ]
269 | },
270 | {
271 | "cell_type": "markdown",
272 | "metadata": {},
273 | "source": [
274 | "# Setup the model"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "metadata": {
281 | "ExecuteTime": {
282 | "end_time": "2021-03-15T13:46:18.898666Z",
283 | "start_time": "2021-03-15T13:46:15.331245Z"
284 | },
285 | "scrolled": true
286 | },
287 | "outputs": [],
288 | "source": [
289 | "g = 0.0\n",
290 | "sith_params1 = {\"in_features\":1, \n",
291 | " \"tau_min\":1, \"tau_max\":30.0, \"buff_max\":50,\n",
292 | " \"k\":125, 'dt':1,\n",
293 | " \"ntau\":20, 'g':g, \n",
294 | " \"ttype\":ttype, \"batch_norm\":True,\n",
295 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n",
296 | " }\n",
297 | "sith_params2 = {\"in_features\":sith_params1['hidden_size'], \n",
298 | " \"tau_min\":1, \"tau_max\":150.0, \"buff_max\":250,\n",
299 | " \"k\":61, 'dt':1,\n",
300 | " \"ntau\":20, 'g':g, \n",
301 | " \"ttype\":ttype, \"batch_norm\":True,\n",
302 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n",
303 | " }\n",
304 | "sith_params3 = {\"in_features\":sith_params2['hidden_size'], \n",
305 | " \"tau_min\":1, \"tau_max\":750.0, \"buff_max\":1500,\n",
306 | " \"k\":35, 'dt':1,\n",
307 | " \"ntau\":20, 'g':g, \n",
308 | " \"ttype\":ttype, \"batch_norm\":True,\n",
309 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n",
310 | " }\n",
311 | "\n",
312 | "layer_params = [sith_params1, sith_params2, sith_params3]\n",
313 | "\n",
314 | "\n",
315 | "\n",
316 | "model = DeepSITH_Classifier(10,\n",
317 | " layer_params=layer_params, \n",
318 | " dropout=0.2).cuda()\n",
319 | "\n",
320 | "tot_weights = 0\n",
321 | "for p in model.parameters():\n",
322 | " tot_weights += p.numel()\n",
323 | "print(\"Total Weights:\", tot_weights)\n",
324 | "print(model)"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": null,
330 | "metadata": {
331 | "ExecuteTime": {
332 | "end_time": "2021-03-16T07:38:12.591705Z",
333 | "start_time": "2021-03-15T13:46:45.607305Z"
334 | },
335 | "scrolled": true
336 | },
337 | "outputs": [],
338 | "source": [
339 | "epochs = 40\n",
340 | "loss_func = nn.CrossEntropyLoss()\n",
341 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)\n",
342 | "sched = StepLR(optimizer, step_size=int(epochs / 4), gamma=0.1)\n",
343 | "#sched = None\n",
344 | "perf_file = join('perf','smnist_deepsith_4layer_01.csv')\n",
345 | "test_perf = []\n",
346 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
347 | "best_test_perf = 0\n",
348 | "t_p = .0000\n",
349 | "for e in progress_bar:\n",
350 | " train(model, ttype, train_loader, test_loader, optimizer, loss_func, batch_size=batch_size,\n",
351 | " epoch=e, perf_file=perf_file,loss_buffer_size=64*32, \n",
352 | " prog_bar=progress_bar, last_test_perf=t_p)\n",
353 | " \n",
354 | " t_p = test(model, 'cuda', test_loader, \n",
355 | " batch_size=batch_size, \n",
356 | " )\n",
357 | " if t_p > best_test_perf:\n",
358 | " best_test_perf = t_p\n",
359 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
360 | " \n",
361 | " test_perf.append({\"epoch\":e,\n",
362 | " 'test':t_p})\n",
363 | " \n",
364 | " sched.step()"
365 | ]
366 | }
367 | ],
368 | "metadata": {
369 | "kernelspec": {
370 | "display_name": "Python 3",
371 | "language": "python",
372 | "name": "python3"
373 | },
374 | "language_info": {
375 | "codemirror_mode": {
376 | "name": "ipython",
377 | "version": 3
378 | },
379 | "file_extension": ".py",
380 | "mimetype": "text/x-python",
381 | "name": "python",
382 | "nbconvert_exporter": "python",
383 | "pygments_lexer": "ipython3",
384 | "version": "3.6.10"
385 | },
386 | "toc": {
387 | "nav_menu": {
388 | "height": "163px",
389 | "width": "250px"
390 | },
391 | "number_sections": true,
392 | "sideBar": true,
393 | "skip_h1_title": false,
394 | "title_cell": "Table of Contents",
395 | "title_sidebar": "Contents",
396 | "toc_cell": false,
397 | "toc_position": {},
398 | "toc_section_display": true,
399 | "toc_window_display": false
400 | }
401 | },
402 | "nbformat": 4,
403 | "nbformat_minor": 4
404 | }
405 |
--------------------------------------------------------------------------------
/experiments/MackeyGlass/MackeyGlass-LMU-increasing_tau.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-01-31T16:15:36.355232Z",
9 | "start_time": "2021-01-31T16:15:36.174356Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-01-31T16:15:37.469291Z",
23 | "start_time": "2021-01-31T16:15:36.530387Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import torch\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import torchvision\n",
31 | "import numpy as np\n",
32 | "import torch.nn as nn\n",
33 | "import torch.nn.functional as F\n",
34 | "from tqdm import tqdm_notebook\n",
35 | "import PIL\n",
36 | "from torch.nn.utils import weight_norm\n",
37 | "from torch.autograd import Variable\n",
38 | "import nengolib\n",
39 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
40 | "\n",
41 | "from os.path import join\n",
42 | "import scipy.special\n",
43 | "import pandas as pd\n",
44 | "import seaborn as sn\n",
45 | "import scipy\n",
46 | "from scipy.spatial.distance import euclidean\n",
47 | "from scipy.interpolate import interp1d\n",
48 | "from tqdm.notebook import tqdm\n",
49 | "import random\n",
50 | "from csv import DictWriter\n",
51 | "from lmu import LegendreMemoryUnit\n",
52 | "\n",
53 | "# if gpu is to be used\n",
54 | "use_cuda = torch.cuda.is_available()\n",
55 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
56 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n",
57 | "\n",
58 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
59 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
60 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
61 | "ttype = FloatTensor\n",
62 | "\n",
63 | "import seaborn as sns\n",
64 | "print(use_cuda)\n",
65 | "import pickle\n"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {
72 | "ExecuteTime": {
73 | "end_time": "2021-01-31T16:15:41.303102Z",
74 | "start_time": "2021-01-31T16:15:41.299727Z"
75 | }
76 | },
77 | "outputs": [],
78 | "source": [
79 | "class LMUModel(nn.Module):\n",
80 | " def __init__(self, n_out, layer_params):\n",
81 | " super(LMUModel, self).__init__()\n",
82 | " self.layers = nn.ModuleList([LegendreMemoryUnit(**layer_params[i])\n",
83 | " for i in range(len(layer_params))])\n",
84 | " self.dense = nn.Linear(layer_params[-1]['hidden_size'], n_out)\n",
85 | "\n",
86 | " \n",
87 | " def forward(self, x):\n",
88 | " for l in self.layers:\n",
89 | " x, _ = l(x) \n",
90 | " x = self.dense(x)\n",
91 | " return x"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {
98 | "ExecuteTime": {
99 | "end_time": "2021-01-31T16:15:43.334751Z",
100 | "start_time": "2021-01-31T16:15:41.545777Z"
101 | }
102 | },
103 | "outputs": [],
104 | "source": [
105 | "\n",
106 | "\n",
107 | "lmu_params = [dict(input_dim=1, hidden_size=49, order=4, theta=4),\n",
108 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n",
109 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n",
110 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n",
111 | " ]\n",
112 | "model = LMUModel(1, lmu_params).cuda()\n",
113 | "\n",
114 | "tot_weights = 0\n",
115 | "for p in model.parameters():\n",
116 | " tot_weights += p.numel()\n",
117 | "print(\"Total Weights:\", tot_weights)\n",
118 | "print(model)\n"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "metadata": {},
124 | "source": [
125 | "# Load Stimuli"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {
132 | "ExecuteTime": {
133 | "end_time": "2021-01-31T16:15:45.466420Z",
134 | "start_time": "2021-01-31T16:15:45.459513Z"
135 | }
136 | },
137 | "outputs": [],
138 | "source": [
139 | "import collections\n",
140 | "\n",
141 | "def mackey_glass(sample_len=1000, tau=17, delta_t=10, seed=None, n_samples=1):\n",
142 | " # Adapted from https://github.com/mila-iqia/summerschool2015/blob/master/rnn_tutorial/synthetic.py\n",
143 | " '''\n",
144 | " mackey_glass(sample_len=1000, tau=17, seed = None, n_samples = 1) -> input\n",
145 | " Generate the Mackey Glass time-series. Parameters are:\n",
146 | " - sample_len: length of the time-series in timesteps. Default is 1000.\n",
147 | " - tau: delay of the MG - system. Commonly used values are tau=17 (mild \n",
148 | " chaos) and tau=30 (moderate chaos). Default is 17.\n",
149 | " - seed: to seed the random generator, can be used to generate the same\n",
150 | " timeseries at each invocation.\n",
151 | " - n_samples : number of samples to generate\n",
152 | " '''\n",
153 | " history_len = tau * delta_t \n",
154 | " # Initial conditions for the history of the system\n",
155 | " timeseries = 1.2\n",
156 | " \n",
157 | " if seed is not None:\n",
158 | " np.random.seed(seed)\n",
159 | "\n",
160 | " samples = []\n",
161 | "\n",
162 | " for _ in range(n_samples):\n",
163 | " history = collections.deque(1.2 * np.ones(history_len) + 0.2 * \\\n",
164 | " (np.random.rand(history_len) - 0.5))\n",
165 | " # Preallocate the array for the time-series\n",
166 | " inp = np.zeros((sample_len,1))\n",
167 | " \n",
168 | " for timestep in range(sample_len):\n",
169 | " for _ in range(delta_t):\n",
170 | " xtau = history.popleft()\n",
171 | " history.append(timeseries)\n",
172 | " timeseries = history[-1] + (0.2 * xtau / (1.0 + xtau ** 10) - \\\n",
173 | " 0.1 * history[-1]) / delta_t\n",
174 | " inp[timestep] = timeseries\n",
175 | " \n",
176 | " # Squash timeseries through tanh\n",
177 | " inp = np.tanh(inp - 1)\n",
178 | " samples.append(inp)\n",
179 | " return samples\n",
180 | "\n",
181 | "\n",
182 | "def generate_data(n_batches, length, split=0.5, seed=0,\n",
183 | " predict_length=15, tau=17, washout=100, delta_t=1,\n",
184 | " center=True):\n",
185 | " X = np.asarray(mackey_glass(\n",
186 | " sample_len=length+predict_length+washout, tau=tau,\n",
187 | " seed=seed, n_samples=n_batches))\n",
188 | " X = X[:, washout:, :]\n",
189 | " cutoff = int(split*n_batches)\n",
190 | " if center:\n",
191 | " X -= np.mean(X) # global mean over all batches, approx -0.066\n",
192 | " Y = X[:, predict_length:, :]\n",
193 | " X = X[:, :-predict_length, :]\n",
194 | " assert X.shape == Y.shape\n",
195 | " return ((X[:cutoff], Y[:cutoff]),\n",
196 | " (X[cutoff:], Y[cutoff:]))"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "metadata": {
203 | "ExecuteTime": {
204 | "end_time": "2021-01-31T16:15:48.056808Z",
205 | "start_time": "2021-01-31T16:15:45.670627Z"
206 | }
207 | },
208 | "outputs": [],
209 | "source": [
210 | "(train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000)\n",
211 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n",
212 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
213 | "\n",
214 | "dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n",
215 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n",
216 | "\n",
217 | "print(train_X.shape, train_Y.shape, test_X.shape)"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {},
223 | "source": [
224 | "## Setup for Model"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": null,
230 | "metadata": {
231 | "ExecuteTime": {
232 | "end_time": "2021-01-31T16:16:09.137416Z",
233 | "start_time": "2021-01-31T16:16:09.127621Z"
234 | }
235 | },
236 | "outputs": [],
237 | "source": [
238 | "\n",
239 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n",
240 | " loss_buffer_size=800, batch_size=4, device='cuda',\n",
241 | " prog_bar=None, tau=17, pred_len=15):\n",
242 | " \n",
243 | " assert(loss_buffer_size%batch_size==0)\n",
244 | " \n",
245 | " losses = []\n",
246 | " last_test_perf = 0\n",
247 | " best_test_perf = 1000\n",
248 | " \n",
249 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
250 | " model.train()\n",
251 | " data = data.to(device)\n",
252 | " target = target.to(device)\n",
253 | " optimizer.zero_grad()\n",
254 | " out = model(data)\n",
255 | " loss = loss_func(out,\n",
256 | " target)\n",
257 | "\n",
258 | " loss.backward()\n",
259 | " optimizer.step()\n",
260 | "\n",
261 | " losses.append(loss.detach().cpu().numpy())\n",
262 | " losses = losses[int(-loss_buffer_size/batch_size):]\n",
263 | " \n",
264 | " if ((batch_idx*batch_size)%loss_buffer_size == 0):\n",
265 | " loss_track = {}\n",
266 | " last_test_perf = test_model(model, 'cuda', test_loader, \n",
267 | " )\n",
268 | " loss_track['avg_loss'] = np.mean(losses)\n",
269 | " loss_track['last_test'] = last_test_perf\n",
270 | " loss_track['epoch'] = epoch\n",
271 | " loss_track['batch_idx'] = batch_idx\n",
272 | " loss_track['tau'] = tau\n",
273 | " loss_track['pred_len'] = pred_len\n",
274 | " with open(perf_file, 'a+') as fp:\n",
275 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
276 | " if fp.tell() == 0:\n",
277 | " csv_writer.writeheader()\n",
278 | " csv_writer.writerow(loss_track)\n",
279 | " fp.flush()\n",
280 | " if best_test_perf > last_test_perf:\n",
281 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
282 | " best_test_perf = last_test_perf\n",
283 | " if not (prog_bar is None):\n",
284 | " # Update progress_bar\n",
285 | " s = \"{}:{} Loss: {:.4f},valid: {:.4f}\"\n",
286 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n",
287 | " last_test_perf] \n",
288 | " s = s.format(*format_list)\n",
289 | " prog_bar.set_description(s)\n",
290 | " \n",
291 | "def test_model(model, device, test_loader):\n",
292 | " # Test the Model\n",
293 | " nrmsd = []\n",
294 | " with torch.no_grad():\n",
295 | " for x, y in test_loader:\n",
296 | " data = x.to(device)\n",
297 | " target = y.to(device)\n",
298 | " optimizer.zero_grad()\n",
299 | " out = model(data)\n",
300 | " nrmsd.append(nengolib.signal.nrmse(out.detach().cpu().numpy().flatten(), \n",
301 | " target=target.detach().cpu().numpy().flatten()))\n",
302 | " perf = np.array(nrmsd).mean()\n",
303 | " return perf"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": null,
309 | "metadata": {
310 | "ExecuteTime": {
311 | "end_time": "2021-02-03T05:20:46.823977Z",
312 | "start_time": "2021-02-02T19:40:22.179973Z"
313 | }
314 | },
315 | "outputs": [],
316 | "source": [
317 | "start_tau = 17\n",
318 | "start_pd = 15\n",
319 | "diffs = [1,2,3,4,5]\n",
320 | "for diff in diffs:\n",
321 | " (train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000, tau=start_tau*diff, \n",
322 | " predict_length=start_pd*diff)\n",
323 | "\n",
324 | " dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n",
325 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
326 | "\n",
327 | " dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n",
328 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n",
329 | " lmu_params = [dict(input_dim=1, hidden_size=49, order=4, theta=4),\n",
330 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n",
331 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n",
332 | " dict(input_dim=49, hidden_size=49, order=4, theta=4),\n",
333 | " ]\n",
334 | " model = LMUModel(1, lmu_params).cuda()\n",
335 | " optimizer = torch.optim.Adam(model.parameters())\n",
336 | " loss_func = nn.MSELoss()\n",
337 | "\n",
338 | " epochs = 1000\n",
339 | " batch_size = 32\n",
340 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
341 | " last_perf = 1000\n",
342 | " for e in progress_bar:\n",
343 | " train(model, ttype, dataset, dataset_valid, \n",
344 | " optimizer, loss_func, batch_size=batch_size, loss_buffer_size=64,\n",
345 | " epoch=e, perf_file=join('perf','mackeyglass_lmu_ratio_3.csv'),\n",
346 | " prog_bar=progress_bar, tau=start_tau*diff, pred_len=start_pd*diff)"
347 | ]
348 | }
349 | ],
350 | "metadata": {
351 | "kernelspec": {
352 | "display_name": "Python 3",
353 | "language": "python",
354 | "name": "python3"
355 | },
356 | "language_info": {
357 | "codemirror_mode": {
358 | "name": "ipython",
359 | "version": 3
360 | },
361 | "file_extension": ".py",
362 | "mimetype": "text/x-python",
363 | "name": "python",
364 | "nbconvert_exporter": "python",
365 | "pygments_lexer": "ipython3",
366 | "version": "3.6.10"
367 | },
368 | "toc": {
369 | "nav_menu": {},
370 | "number_sections": true,
371 | "sideBar": true,
372 | "skip_h1_title": false,
373 | "title_cell": "Table of Contents",
374 | "title_sidebar": "Contents",
375 | "toc_cell": false,
376 | "toc_position": {},
377 | "toc_section_display": true,
378 | "toc_window_display": false
379 | }
380 | },
381 | "nbformat": 4,
382 | "nbformat_minor": 4
383 | }
384 |
--------------------------------------------------------------------------------
/experiments/AddingProblem/addingproblem-LSTM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2020-12-15T17:01:06.036477Z",
9 | "start_time": "2020-12-15T17:01:05.842293Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2020-12-15T17:01:06.655046Z",
23 | "start_time": "2020-12-15T17:01:06.090606Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pylab as plt\n",
29 | "import torch\n",
30 | "import numpy as np\n",
31 | "import seaborn as sn\n",
32 | "sn.set_context(\"poster\")\n",
33 | "import os\n",
34 | "import torch\n",
35 | "import torch.nn as nn\n",
36 | "import torch.nn.functional as F\n",
37 | "from torch.nn.utils import weight_norm\n",
38 | "from torchvision import transforms, datasets\n",
39 | "\n",
40 | "from deepsith import DeepSITH\n",
41 | "\n",
42 | "import numpy as np\n",
43 | "import scipy\n",
44 | "import scipy.stats as st\n",
45 | "import scipy.special\n",
46 | "import scipy.signal\n",
47 | "import scipy.interpolate\n",
48 | "\n",
49 | "import pandas as pd\n",
50 | "\n",
51 | "from os.path import join\n",
52 | "import random\n",
53 | "from csv import DictWriter\n",
54 | "\n",
55 | "from tqdm.notebook import tqdm\n",
56 | "import pickle\n",
57 | "# if gpu is to be used\n",
58 | "use_cuda = torch.cuda.is_available()\n",
59 | "print(use_cuda)\n",
60 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
61 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n",
62 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
63 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
64 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
65 | "ttype = FloatTensor"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {
72 | "ExecuteTime": {
73 | "end_time": "2020-12-15T17:01:06.661735Z",
74 | "start_time": "2020-12-15T17:01:06.656351Z"
75 | }
76 | },
77 | "outputs": [],
78 | "source": [
79 | "import torch\n",
80 | "import numpy as np\n",
81 | "\n",
82 | "def get_batch(batch_size, T, ttype):\n",
83 | " values = torch.rand(T, batch_size, requires_grad=False)\n",
84 | " indices = torch.zeros_like(values)\n",
85 | " half = int(T / 2)\n",
86 | " for i in range(batch_size):\n",
87 | " half_1 = np.random.randint(half)\n",
88 | " hals_2 = np.random.randint(half, T)\n",
89 | " indices[half_1, i] = 1\n",
90 | " indices[hals_2, i] = 1\n",
91 | "\n",
92 | " data = torch.stack((values, indices), dim=-1).type(ttype)\n",
93 | " targets = torch.mul(values, indices).sum(dim=0).type(ttype)\n",
94 | " return data, targets\n"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {
101 | "ExecuteTime": {
102 | "end_time": "2020-12-15T17:01:09.231546Z",
103 | "start_time": "2020-12-15T17:01:09.229034Z"
104 | }
105 | },
106 | "outputs": [],
107 | "source": [
108 | "torch.manual_seed(1111)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {
115 | "ExecuteTime": {
116 | "end_time": "2020-12-15T17:01:09.249505Z",
117 | "start_time": "2020-12-15T17:01:09.232698Z"
118 | }
119 | },
120 | "outputs": [],
121 | "source": [
122 | "def train(model, ttype, seq_length, optimizer, loss_func, \n",
123 | " epoch, perf_file, loss_buffer_size=20, batch_size=1, test_size=10,\n",
124 | " device='cuda', prog_bar=None):\n",
125 | " assert(loss_buffer_size%batch_size==0)\n",
126 | "\n",
127 | " losses = []\n",
128 | " perfs = []\n",
129 | " last_test_perf = 0\n",
130 | " for batch_idx in range(20000):\n",
131 | " model.train()\n",
132 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n",
133 | " \n",
134 | " target = target.unsqueeze(1)\n",
135 | " optimizer.zero_grad()\n",
136 | " out = model(sig)\n",
137 | " loss = loss_func(out[-1, :, :],\n",
138 | " target)\n",
139 | " \n",
140 | " loss.backward()\n",
141 | " optimizer.step()\n",
142 | "\n",
143 | " losses.append(loss.detach().cpu().numpy())\n",
144 | " losses = losses[-loss_buffer_size:]\n",
145 | " if not (prog_bar is None):\n",
146 | " # Update progress_bar\n",
147 | " s = \"{}:{} Loss: {:.8f}\"\n",
148 | " format_list = [e, int(batch_idx/(50/batch_size)), np.mean(losses)] \n",
149 | " s = s.format(*format_list)\n",
150 | " prog_bar.set_description(s)\n",
151 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n",
152 | " loss_track = {}\n",
153 | " #last_test_perf = test_norm(model, 'cuda', test_sig, test_class,\n",
154 | " # batch_size=test_size, \n",
155 | " # )\n",
156 | " loss_track['avg_loss'] = np.mean(losses)\n",
157 | " #loss_track['last_test'] = last_test_perf\n",
158 | " loss_track['epoch'] = epoch\n",
159 | " loss_track['batch_idx'] = batch_idx\n",
160 | " with open(perf_file, 'a+') as fp:\n",
161 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
162 | " if fp.tell() == 0:\n",
163 | " csv_writer.writeheader()\n",
164 | " csv_writer.writerow(loss_track)\n",
165 | " fp.flush()\n",
166 | "def test_norm(model, device, seq_length, loss_func, batch_size=100):\n",
167 | " model.eval()\n",
168 | " correct = 0\n",
169 | " count = 0\n",
170 | " with torch.no_grad():\n",
171 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n",
172 | " target = target.unsqueeze(1)\n",
173 | " out = model(sig)\n",
174 | " loss = loss_func(out[-1, :, :],\n",
175 | " target)\n",
176 | " return loss"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "metadata": {
183 | "ExecuteTime": {
184 | "end_time": "2020-12-15T17:01:09.258255Z",
185 | "start_time": "2020-12-15T17:01:09.250462Z"
186 | }
187 | },
188 | "outputs": [],
189 | "source": [
190 | "class LSTM_Predictor(nn.Module):\n",
191 | " def __init__(self, out_features, lstm_params):\n",
192 | " super(LSTM_Predictor, self).__init__()\n",
193 | " self.lstm = nn.LSTM(**lstm_params)\n",
194 | " self.to_out = nn.Linear(lstm_params['hidden_size'], \n",
195 | " out_features)\n",
196 | " def forward(self, inp):\n",
197 | " x = self.lstm(inp)[0]\n",
198 | " x = self.to_out(x)\n",
199 | " return x"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "metadata": {},
205 | "source": [
206 | "# T = 100"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": null,
212 | "metadata": {
213 | "ExecuteTime": {
214 | "end_time": "2020-12-15T17:01:23.033948Z",
215 | "start_time": "2020-12-15T17:01:23.026221Z"
216 | }
217 | },
218 | "outputs": [],
219 | "source": [
220 | "lstm_params = dict(input_size=2,\n",
221 | " hidden_size=128, \n",
222 | " num_layers=1)\n",
223 | "model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n",
224 | "\n",
225 | "tot_weights = 0\n",
226 | "for p in model.parameters():\n",
227 | " tot_weights += p.numel()\n",
228 | "print(\"Total Weights:\", tot_weights)\n",
229 | "print(model)"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {
236 | "ExecuteTime": {
237 | "end_time": "2020-12-15T17:03:23.646330Z",
238 | "start_time": "2020-12-15T17:01:30.304957Z"
239 | }
240 | },
241 | "outputs": [],
242 | "source": [
243 | "seq_length=100\n",
244 | "\n",
245 | "loss_func = nn.MSELoss()\n",
246 | "optimizer = torch.optim.Adam(model.parameters())\n",
247 | "epochs = 1\n",
248 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
249 | "for e in progress_bar:\n",
250 | " train(model, ttype, seq_length,\n",
251 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
252 | " epoch=e, perf_file=join('perf','adding100_lstm_1.csv'),\n",
253 | " prog_bar=progress_bar)"
254 | ]
255 | },
256 | {
257 | "cell_type": "markdown",
258 | "metadata": {},
259 | "source": [
260 | "# T = 500"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {
267 | "ExecuteTime": {
268 | "end_time": "2020-11-17T18:59:57.170037Z",
269 | "start_time": "2020-11-17T18:59:57.165816Z"
270 | }
271 | },
272 | "outputs": [],
273 | "source": [
274 | "lstm_params = dict(input_size=2,\n",
275 | " hidden_size=128, \n",
276 | " num_layers=1)\n",
277 | "model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n",
278 | "\n",
279 | "tot_weights = 0\n",
280 | "for p in model.parameters():\n",
281 | " tot_weights += p.numel()\n",
282 | "print(\"Total Weights:\", tot_weights)\n",
283 | "print(model)"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": null,
289 | "metadata": {
290 | "ExecuteTime": {
291 | "end_time": "2020-11-17T19:06:41.299334Z",
292 | "start_time": "2020-11-17T19:00:01.467672Z"
293 | }
294 | },
295 | "outputs": [],
296 | "source": [
297 | "seq_length=500\n",
298 | "\n",
299 | "loss_func = nn.MSELoss()\n",
300 | "optimizer = torch.optim.Adam(model.parameters())\n",
301 | "epochs = 1\n",
302 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
303 | "for e in progress_bar:\n",
304 | " train(model, ttype, seq_length,\n",
305 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
306 | " epoch=e, perf_file=join('perf','adding500_lstm_3.csv'),\n",
307 | " prog_bar=progress_bar)"
308 | ]
309 | },
310 | {
311 | "cell_type": "markdown",
312 | "metadata": {},
313 | "source": [
314 | "# T = 2000"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "metadata": {
321 | "ExecuteTime": {
322 | "end_time": "2020-11-17T19:09:55.459792Z",
323 | "start_time": "2020-11-17T19:09:55.454074Z"
324 | }
325 | },
326 | "outputs": [],
327 | "source": [
328 | "lstm_params = dict(input_size=2,\n",
329 | " hidden_size=128, \n",
330 | " num_layers=1)\n",
331 | "model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n",
332 | "\n",
333 | "tot_weights = 0\n",
334 | "for p in model.parameters():\n",
335 | " tot_weights += p.numel()\n",
336 | "print(\"Total Weights:\", tot_weights)\n",
337 | "print(model)"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "metadata": {
344 | "ExecuteTime": {
345 | "end_time": "2020-11-17T19:33:57.545928Z",
346 | "start_time": "2020-11-17T19:10:03.115405Z"
347 | }
348 | },
349 | "outputs": [],
350 | "source": [
351 | "seq_length=2000\n",
352 | "\n",
353 | "loss_func = nn.MSELoss()\n",
354 | "optimizer = torch.optim.Adam(model.parameters())\n",
355 | "epochs = 1\n",
356 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
357 | "for e in progress_bar:\n",
358 | " train(model, ttype, seq_length,\n",
359 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
360 | " epoch=e, perf_file=join('perf','adding2000_lstm_2.csv'),\n",
361 | " prog_bar=progress_bar)"
362 | ]
363 | },
364 | {
365 | "cell_type": "markdown",
366 | "metadata": {},
367 | "source": [
368 | "# T = 5000"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": null,
374 | "metadata": {
375 | "ExecuteTime": {
376 | "end_time": "2020-11-17T19:33:57.551416Z",
377 | "start_time": "2020-11-17T19:33:57.546926Z"
378 | }
379 | },
380 | "outputs": [],
381 | "source": [
382 | "lstm_params = dict(input_size=2,\n",
383 | " hidden_size=128, \n",
384 | " num_layers=1)\n",
385 | "model = LSTM_Predictor(1, lstm_params=lstm_params).cuda()\n",
386 | "\n",
387 | "tot_weights = 0\n",
388 | "for p in model.parameters():\n",
389 | " tot_weights += p.numel()\n",
390 | "print(\"Total Weights:\", tot_weights)\n",
391 | "print(model)"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": null,
397 | "metadata": {
398 | "ExecuteTime": {
399 | "end_time": "2020-11-17T20:41:33.252877Z",
400 | "start_time": "2020-11-17T19:33:57.552380Z"
401 | }
402 | },
403 | "outputs": [],
404 | "source": [
405 | "seq_length=5000\n",
406 | "\n",
407 | "loss_func = nn.MSELoss()\n",
408 | "optimizer = torch.optim.Adam(model.parameters())\n",
409 | "epochs = 1\n",
410 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
411 | "for e in progress_bar:\n",
412 | " train(model, ttype, seq_length,\n",
413 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
414 | " epoch=e, perf_file=join('perf','adding5000_lstm_1.csv'),\n",
415 | " prog_bar=progress_bar)"
416 | ]
417 | },
418 | {
419 | "cell_type": "code",
420 | "execution_count": null,
421 | "metadata": {},
422 | "outputs": [],
423 | "source": []
424 | }
425 | ],
426 | "metadata": {
427 | "kernelspec": {
428 | "display_name": "Python 3",
429 | "language": "python",
430 | "name": "python3"
431 | },
432 | "language_info": {
433 | "codemirror_mode": {
434 | "name": "ipython",
435 | "version": 3
436 | },
437 | "file_extension": ".py",
438 | "mimetype": "text/x-python",
439 | "name": "python",
440 | "nbconvert_exporter": "python",
441 | "pygments_lexer": "ipython3",
442 | "version": "3.6.10"
443 | },
444 | "toc": {
445 | "nav_menu": {
446 | "height": "141px",
447 | "width": "160px"
448 | },
449 | "number_sections": true,
450 | "sideBar": true,
451 | "skip_h1_title": false,
452 | "title_cell": "Table of Contents",
453 | "title_sidebar": "Contents",
454 | "toc_cell": false,
455 | "toc_position": {},
456 | "toc_section_display": true,
457 | "toc_window_display": false
458 | }
459 | },
460 | "nbformat": 4,
461 | "nbformat_minor": 4
462 | }
463 |
--------------------------------------------------------------------------------
/experiments/AddingProblem/addingproblem-LMU.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-03-11T17:02:45.247984Z",
9 | "start_time": "2021-03-11T17:02:45.076329Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-03-11T17:02:46.318135Z",
23 | "start_time": "2021-03-11T17:02:45.266959Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pylab as plt\n",
29 | "import torch\n",
30 | "import numpy as np\n",
31 | "import seaborn as sn\n",
32 | "sn.set_context(\"poster\")\n",
33 | "import os\n",
34 | "import torch\n",
35 | "import torch.nn as nn\n",
36 | "import torch.nn.functional as F\n",
37 | "from torch.nn.utils import weight_norm\n",
38 | "from torchvision import transforms, datasets\n",
39 | "\n",
40 | "#from lmu import LegendreMemoryUnit\n",
41 | "from LMU import LegendreMemoryUnit\n",
42 | "import numpy as np\n",
43 | "import scipy\n",
44 | "import scipy.stats as st\n",
45 | "import scipy.special\n",
46 | "import scipy.signal\n",
47 | "import scipy.interpolate\n",
48 | "\n",
49 | "import pandas as pd\n",
50 | "\n",
51 | "from os.path import join\n",
52 | "import random\n",
53 | "from csv import DictWriter\n",
54 | "\n",
55 | "from tqdm.notebook import tqdm\n",
56 | "import pickle\n",
57 | "# if gpu is to be used\n",
58 | "use_cuda = torch.cuda.is_available()\n",
59 | "print(use_cuda)\n",
60 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
61 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n",
62 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
63 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
64 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
65 | "ttype = FloatTensor"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {
72 | "ExecuteTime": {
73 | "end_time": "2021-03-11T17:02:46.322341Z",
74 | "start_time": "2021-03-11T17:02:46.319117Z"
75 | }
76 | },
77 | "outputs": [],
78 | "source": [
79 | "import torch\n",
80 | "import numpy as np\n",
81 | "\n",
82 | "def get_batch(batch_size, T, ttype):\n",
83 | " values = torch.rand(T, batch_size, requires_grad=False)\n",
84 | " indices = torch.zeros_like(values)\n",
85 | " half = int(T / 2)\n",
86 | " for i in range(batch_size):\n",
87 | " half_1 = np.random.randint(half)\n",
88 | " hals_2 = np.random.randint(half, T)\n",
89 | " indices[half_1, i] = 1\n",
90 | " indices[hals_2, i] = 1\n",
91 | "\n",
92 | " data = torch.stack((values, indices), dim=-1).type(ttype)\n",
93 | " targets = torch.mul(values, indices).sum(dim=0).type(ttype)\n",
94 | " return data, targets\n"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {
101 | "ExecuteTime": {
102 | "end_time": "2021-03-11T17:02:51.926506Z",
103 | "start_time": "2021-03-11T17:02:51.923167Z"
104 | }
105 | },
106 | "outputs": [],
107 | "source": [
108 | "torch.manual_seed(1111)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {
115 | "ExecuteTime": {
116 | "end_time": "2021-03-11T17:02:52.144538Z",
117 | "start_time": "2021-03-11T17:02:52.137566Z"
118 | }
119 | },
120 | "outputs": [],
121 | "source": [
122 | "def train(model, ttype, seq_length, optimizer, loss_func, \n",
123 | " epoch, perf_file, loss_buffer_size=20, batch_size=1, test_size=10,\n",
124 | " device='cuda', prog_bar=None):\n",
125 | " assert(loss_buffer_size%batch_size==0)\n",
126 | "\n",
127 | " losses = []\n",
128 | " perfs = []\n",
129 | " last_test_perf = 0\n",
130 | " for batch_idx in range(20000):\n",
131 | " model.train()\n",
132 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n",
133 | " \n",
134 | " target = target.unsqueeze(1)\n",
135 | " optimizer.zero_grad()\n",
136 | " out = model(sig.transpose(1,0))\n",
137 | " loss = loss_func(out[:, -1],\n",
138 | " target)\n",
139 | " \n",
140 | " loss.backward()\n",
141 | " optimizer.step()\n",
142 | "\n",
143 | " losses.append(loss.detach().cpu().numpy())\n",
144 | " losses = losses[-loss_buffer_size:]\n",
145 | " if not (prog_bar is None):\n",
146 | " # Update progress_bar\n",
147 | " s = \"{}:{} Loss: {:.8f}\"\n",
148 | " format_list = [e, int(batch_idx/(50/batch_size)), np.mean(losses)] \n",
149 | " s = s.format(*format_list)\n",
150 | " prog_bar.set_description(s)\n",
151 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n",
152 | " loss_track = {}\n",
153 | " #last_test_perf = test_norm(model, 'cuda', test_sig, test_class,\n",
154 | " # batch_size=test_size, \n",
155 | " # )\n",
156 | " loss_track['avg_loss'] = np.mean(losses)\n",
157 | " #loss_track['last_test'] = last_test_perf\n",
158 | " loss_track['epoch'] = epoch\n",
159 | " loss_track['batch_idx'] = batch_idx\n",
160 | " with open(perf_file, 'a+') as fp:\n",
161 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
162 | " if fp.tell() == 0:\n",
163 | " csv_writer.writeheader()\n",
164 | " csv_writer.writerow(loss_track)\n",
165 | " fp.flush()\n",
166 | "def test_norm(model, device, seq_length, loss_func, batch_size=100):\n",
167 | " model.eval()\n",
168 | " correct = 0\n",
169 | " count = 0\n",
170 | " with torch.no_grad():\n",
171 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n",
172 | " target = target.unsqueeze(1)\n",
173 | " out = model(sig.transpose(1,0))\n",
174 | " loss = loss_func(out[:, -1],\n",
175 | " target)\n",
176 | " return loss"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "metadata": {
183 | "ExecuteTime": {
184 | "end_time": "2021-03-11T17:02:52.334043Z",
185 | "start_time": "2021-03-11T17:02:52.330865Z"
186 | }
187 | },
188 | "outputs": [],
189 | "source": [
190 | "class LMUModel(nn.Module):\n",
191 | " def __init__(self, n_out, layer_params):\n",
192 | " super(LMUModel, self).__init__()\n",
193 | " self.layers = nn.ModuleList([LegendreMemoryUnit(**layer_params[i])\n",
194 | " for i in range(len(layer_params))])\n",
195 | " self.dense = nn.Linear(layer_params[-1]['units'], n_out)\n",
196 | "\n",
197 | " \n",
198 | " def forward(self, x):\n",
199 | " for l in self.layers:\n",
200 | " x, _ = l(x) \n",
201 | " x = self.dense(x)\n",
202 | " return x"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "# T = 100"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {
216 | "ExecuteTime": {
217 | "end_time": "2021-03-11T17:31:48.855245Z",
218 | "start_time": "2021-03-11T17:31:48.375043Z"
219 | }
220 | },
221 | "outputs": [],
222 | "source": [
223 | "seq_length=100\n",
224 | "\n",
225 | "lmu_params = [dict(input_dim=2, units=25, order=1000, theta=5000),\n",
226 | " \n",
227 | " ]\n",
228 | "model = LMUModel(1, lmu_params).cuda()\n",
229 | "\n",
230 | "tot_weights = 0\n",
231 | "for p in model.parameters():\n",
232 | " tot_weights += p.numel()\n",
233 | "print(\"Total Weights:\", tot_weights)\n",
234 | "print(model)\n"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "metadata": {
241 | "ExecuteTime": {
242 | "end_time": "2021-03-11T17:11:05.011736Z",
243 | "start_time": "2021-03-11T17:03:10.989965Z"
244 | }
245 | },
246 | "outputs": [],
247 | "source": [
248 | "\n",
249 | "loss_func = nn.MSELoss()\n",
250 | "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)\n",
251 | "epochs = 1\n",
252 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
253 | "for e in progress_bar:\n",
254 | " train(model, ttype, seq_length,\n",
255 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
256 | " epoch=e, perf_file=join('perf','adding100_lmu_03_weird.csv'),\n",
257 | " prog_bar=progress_bar)"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {},
263 | "source": [
264 | "# T = 500"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "metadata": {
271 | "ExecuteTime": {
272 | "end_time": "2021-03-11T18:50:03.354758Z",
273 | "start_time": "2021-03-11T18:50:02.981515Z"
274 | }
275 | },
276 | "outputs": [],
277 | "source": [
278 | "seq_length=500\n",
279 | "\n",
280 | "lmu_params = [dict(input_dim=2, units=25, order=1000, theta=100),\n",
281 | " \n",
282 | " ]\n",
283 | "model = LMUModel(1, lmu_params).cuda()\n",
284 | "\n",
285 | "tot_weights = 0\n",
286 | "for p in model.parameters():\n",
287 | " tot_weights += p.numel()\n",
288 | "print(\"Total Weights:\", tot_weights)\n",
289 | "print(model)\n"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "metadata": {
296 | "ExecuteTime": {
297 | "end_time": "2021-03-11T20:16:17.551864Z",
298 | "start_time": "2021-03-11T18:50:08.232460Z"
299 | }
300 | },
301 | "outputs": [],
302 | "source": [
303 | "\n",
304 | "loss_func = nn.MSELoss()\n",
305 | "optimizer = torch.optim.AdamW(model.parameters())\n",
306 | "epochs = 1\n",
307 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
308 | "for e in progress_bar:\n",
309 | " train(model, ttype, seq_length,\n",
310 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
311 | " epoch=e, perf_file=join('perf','adding500_lmu_04_weird.csv'),\n",
312 | " prog_bar=progress_bar)"
313 | ]
314 | },
315 | {
316 | "cell_type": "markdown",
317 | "metadata": {},
318 | "source": [
319 | "# T = 2000"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "metadata": {
326 | "ExecuteTime": {
327 | "end_time": "2021-02-22T20:50:34.258707Z",
328 | "start_time": "2021-02-22T20:50:34.234459Z"
329 | }
330 | },
331 | "outputs": [],
332 | "source": [
333 | "seq_length=2000\n",
334 | "\n",
335 | "lmu_params = [dict(input_dim=2, units=25, order=100, theta=5000),\n",
336 | " \n",
337 | " ]\n",
338 | "model = LMUModel(1, lmu_params).cuda()\n",
339 | "\n",
340 | "tot_weights = 0\n",
341 | "for p in model.parameters():\n",
342 | " tot_weights += p.numel()\n",
343 | "print(\"Total Weights:\", tot_weights)\n",
344 | "print(model)"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": null,
350 | "metadata": {
351 | "ExecuteTime": {
352 | "end_time": "2021-02-23T02:23:42.992553Z",
353 | "start_time": "2021-02-22T20:50:34.259745Z"
354 | }
355 | },
356 | "outputs": [],
357 | "source": [
358 | "\n",
359 | "loss_func = nn.MSELoss()\n",
360 | "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)\n",
361 | "epochs = 1\n",
362 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
363 | "for e in progress_bar:\n",
364 | " train(model, ttype, seq_length,\n",
365 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
366 | " epoch=e, perf_file=join('perf','adding2000_lmu_01_weird.csv'),\n",
367 | " prog_bar=progress_bar)"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "metadata": {},
373 | "source": [
374 | "# T = 5000"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": null,
380 | "metadata": {
381 | "ExecuteTime": {
382 | "end_time": "2021-02-23T02:23:42.998412Z",
383 | "start_time": "2021-02-23T02:23:42.993559Z"
384 | }
385 | },
386 | "outputs": [],
387 | "source": [
388 | "seq_length=5000\n",
389 | "\n",
390 | "lmu_params = [dict(input_dim=2, units=25, order=100, theta=5000),\n",
391 | " \n",
392 | " ]\n",
393 | "model = LMUModel(1, lmu_params).cuda()\n",
394 | "\n",
395 | "tot_weights = 0\n",
396 | "for p in model.parameters():\n",
397 | " tot_weights += p.numel()\n",
398 | "print(\"Total Weights:\", tot_weights)\n",
399 | "print(model)"
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "execution_count": null,
405 | "metadata": {
406 | "ExecuteTime": {
407 | "end_time": "2021-02-23T02:58:44.760408Z",
408 | "start_time": "2021-02-23T02:23:42.999179Z"
409 | }
410 | },
411 | "outputs": [],
412 | "source": [
413 | "\n",
414 | "loss_func = nn.MSELoss()\n",
415 | "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)\n",
416 | "epochs = 1\n",
417 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
418 | "for e in progress_bar:\n",
419 | " train(model, ttype, seq_length,\n",
420 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
421 | " epoch=e, perf_file=join('perf','adding5000_lmu_01_weird.csv'),\n",
422 | " prog_bar=progress_bar)"
423 | ]
424 | }
425 | ],
426 | "metadata": {
427 | "kernelspec": {
428 | "display_name": "Python 3",
429 | "language": "python",
430 | "name": "python3"
431 | },
432 | "language_info": {
433 | "codemirror_mode": {
434 | "name": "ipython",
435 | "version": 3
436 | },
437 | "file_extension": ".py",
438 | "mimetype": "text/x-python",
439 | "name": "python",
440 | "nbconvert_exporter": "python",
441 | "pygments_lexer": "ipython3",
442 | "version": "3.6.10"
443 | },
444 | "toc": {
445 | "nav_menu": {
446 | "height": "141px",
447 | "width": "160px"
448 | },
449 | "number_sections": true,
450 | "sideBar": true,
451 | "skip_h1_title": false,
452 | "title_cell": "Table of Contents",
453 | "title_sidebar": "Contents",
454 | "toc_cell": false,
455 | "toc_position": {},
456 | "toc_section_display": true,
457 | "toc_window_display": false
458 | }
459 | },
460 | "nbformat": 4,
461 | "nbformat_minor": 4
462 | }
463 |
--------------------------------------------------------------------------------
/experiments/MackeyGlass/MackeyGlass-DeepSITH-increasing_tau.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-04-16T15:10:34.280995Z",
9 | "start_time": "2021-04-16T15:10:34.099264Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline\n"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-04-16T15:10:34.999978Z",
23 | "start_time": "2021-04-16T15:10:34.333348Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import torch\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import torchvision\n",
31 | "import nengolib\n",
32 | "import numpy as np\n",
33 | "import torch.nn as nn\n",
34 | "import torch.nn.functional as F\n",
35 | "from tqdm import tqdm_notebook\n",
36 | "from deepsith import DeepSITH\n",
37 | "import PIL\n",
38 | "from torch.nn.utils import weight_norm\n",
39 | "\n",
40 | "from os.path import join\n",
41 | "import scipy.special\n",
42 | "import pandas as pd\n",
43 | "import seaborn as sn\n",
44 | "import scipy\n",
45 | "from scipy.spatial.distance import euclidean\n",
46 | "from scipy.interpolate import interp1d\n",
47 | "from tqdm.notebook import tqdm\n",
48 | "import random\n",
49 | "from csv import DictWriter\n",
50 | "# if gpu is to be used\n",
51 | "use_cuda = torch.cuda.is_available()\n",
52 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
53 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n",
54 | "\n",
55 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
56 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
57 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
58 | "ttype = FloatTensor\n",
59 | "\n",
60 | "import seaborn as sns\n",
61 | "print(use_cuda)\n",
62 | "import pickle\n"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "# Load Stimuli"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {
76 | "ExecuteTime": {
77 | "end_time": "2021-04-16T15:10:36.859945Z",
78 | "start_time": "2021-04-16T15:10:36.852930Z"
79 | }
80 | },
81 | "outputs": [],
82 | "source": [
83 | "import collections\n",
84 | "\n",
85 | "def mackey_glass(sample_len=1000, tau=17, delta_t=10, seed=None, n_samples=1):\n",
86 | " # Adapted from https://github.com/mila-iqia/summerschool2015/blob/master/rnn_tutorial/synthetic.py\n",
87 | " '''\n",
88 | " mackey_glass(sample_len=1000, tau=17, seed = None, n_samples = 1) -> input\n",
89 | " Generate the Mackey Glass time-series. Parameters are:\n",
90 | " - sample_len: length of the time-series in timesteps. Default is 1000.\n",
91 | " - tau: delay of the MG - system. Commonly used values are tau=17 (mild \n",
92 | " chaos) and tau=30 (moderate chaos). Default is 17.\n",
93 | " - seed: to seed the random generator, can be used to generate the same\n",
94 | " timeseries at each invocation.\n",
95 | " - n_samples : number of samples to generate\n",
96 | " '''\n",
97 | " history_len = tau * delta_t \n",
98 | " # Initial conditions for the history of the system\n",
99 | " timeseries = 1.2\n",
100 | " \n",
101 | " if seed is not None:\n",
102 | " np.random.seed(seed)\n",
103 | "\n",
104 | " samples = []\n",
105 | "\n",
106 | " for _ in range(n_samples):\n",
107 | " history = collections.deque(1.2 * np.ones(history_len) + 0.2 * \\\n",
108 | " (np.random.rand(history_len) - 0.5))\n",
109 | " # Preallocate the array for the time-series\n",
110 | " inp = np.zeros((sample_len,1))\n",
111 | " \n",
112 | " for timestep in range(sample_len):\n",
113 | " for _ in range(delta_t):\n",
114 | " xtau = history.popleft()\n",
115 | " history.append(timeseries)\n",
116 | " timeseries = history[-1] + (0.2 * xtau / (1.0 + xtau ** 10) - \\\n",
117 | " 0.1 * history[-1]) / delta_t\n",
118 | " inp[timestep] = timeseries\n",
119 | " \n",
120 | " # Squash timeseries through tanh\n",
121 | " inp = np.tanh(inp - 1)\n",
122 | " samples.append(inp)\n",
123 | " return samples\n",
124 | "\n",
125 | "\n",
126 | "def generate_data(n_batches, length, split=0.5, seed=0,\n",
127 | " predict_length=15, tau=17, washout=100, delta_t=1,\n",
128 | " center=True):\n",
129 | " X = np.asarray(mackey_glass(\n",
130 | " sample_len=length+predict_length+washout, tau=tau,\n",
131 | " seed=seed, n_samples=n_batches))\n",
132 | " X = X[:, washout:, :]\n",
133 | " cutoff = int(split*n_batches)\n",
134 | " if center:\n",
135 | " X -= np.mean(X) # global mean over all batches, approx -0.066\n",
136 | " Y = X[:, predict_length:, :]\n",
137 | " X = X[:, :-predict_length, :]\n",
138 | " assert X.shape == Y.shape\n",
139 | " return ((X[:cutoff], Y[:cutoff]),\n",
140 | " (X[cutoff:], Y[cutoff:]))"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "metadata": {
147 | "ExecuteTime": {
148 | "end_time": "2021-04-16T15:10:38.927028Z",
149 | "start_time": "2021-04-16T15:10:37.026625Z"
150 | }
151 | },
152 | "outputs": [],
153 | "source": [
154 | "(train_X, train_Y), (test_X, test_Y) = generate_data(2, 5000)\n",
155 | "\n",
156 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n",
157 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
158 | "\n",
159 | "dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n",
160 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "## Setup for Model"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {
174 | "ExecuteTime": {
175 | "end_time": "2021-04-16T15:10:39.071373Z",
176 | "start_time": "2021-04-16T15:10:39.063007Z"
177 | }
178 | },
179 | "outputs": [],
180 | "source": [
181 | "\n",
182 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n",
183 | " loss_buffer_size=800, batch_size=4, device='cuda',\n",
184 | " prog_bar=None, tau=17, pred_len=15):\n",
185 | " \n",
186 | " assert(loss_buffer_size%batch_size==0)\n",
187 | " \n",
188 | " losses = []\n",
189 | " last_test_perf = 0\n",
190 | " best_test_perf = -1\n",
191 | " \n",
192 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
193 | " model.train()\n",
194 | " data = data.to(device).view(data.shape[0],1,1,-1)\n",
195 | " target = target.to(device)\n",
196 | " optimizer.zero_grad()\n",
197 | " out = model(data)\n",
198 | " loss = loss_func(out,\n",
199 | " target)\n",
200 | "\n",
201 | " loss.backward()\n",
202 | " optimizer.step()\n",
203 | "\n",
204 | " losses.append(loss.detach().cpu().numpy())\n",
205 | " losses = losses[int(-loss_buffer_size/batch_size):]\n",
206 | " \n",
207 | " if ((batch_idx*batch_size)%loss_buffer_size == 0):\n",
208 | " loss_track = {}\n",
209 | " last_test_perf = test_model(model, 'cuda', test_loader, \n",
210 | " )\n",
211 | " loss_track['avg_loss'] = np.mean(losses)\n",
212 | " loss_track['last_test'] = last_test_perf\n",
213 | " loss_track['epoch'] = epoch\n",
214 | " loss_track['batch_idx'] = batch_idx\n",
215 | " loss_track['tau'] = tau\n",
216 | " loss_track['pred_len'] = pred_len\n",
217 | " with open(perf_file, 'a+') as fp:\n",
218 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
219 | " if fp.tell() == 0:\n",
220 | " csv_writer.writeheader()\n",
221 | " csv_writer.writerow(loss_track)\n",
222 | " fp.flush()\n",
223 | " if best_test_perf > last_test_perf:\n",
224 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
225 | " best_test_perf = last_test_perf\n",
226 | " if not (prog_bar is None):\n",
227 | " # Update progress_bar\n",
228 | " s = \"{}:{} Loss: {:.4f}, valid: {:.4f}\"\n",
229 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n",
230 | " last_test_perf] \n",
231 | " s = s.format(*format_list)\n",
232 | " prog_bar.set_description(s)\n",
233 | " \n",
234 | "def test_model(model, device, test_loader):\n",
235 | " # Test the Model\n",
236 | " nrmsd = []\n",
237 | " with torch.no_grad():\n",
238 | " for x, y in test_loader:\n",
239 | " data = x.to(device).view(x.shape[0],1,1,-1)\n",
240 | " target = y.to(device)\n",
241 | " optimizer.zero_grad()\n",
242 | " out = model(data)\n",
243 | " nrmsd.append(nengolib.signal.nrmse(out.detach().cpu().numpy().flatten(), \n",
244 | " target=target.detach().cpu().numpy().flatten()))\n",
245 | "\n",
246 | " perf = np.array(nrmsd).mean()\n",
247 | " return perf"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "metadata": {
254 | "ExecuteTime": {
255 | "end_time": "2021-04-16T15:10:39.230440Z",
256 | "start_time": "2021-04-16T15:10:39.227260Z"
257 | }
258 | },
259 | "outputs": [],
260 | "source": [
261 | "\n",
262 | " \n",
263 | "class DeepSITH_Tracker(nn.Module):\n",
264 | " def __init__(self, out, layer_params, dropout=.5):\n",
265 | " super(DeepSITH_Tracker, self).__init__()\n",
266 | " last_hidden = layer_params[-1]['hidden_size']\n",
267 | " self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)\n",
268 | " self.to_out = nn.Linear(last_hidden, out)\n",
269 | " def forward(self, inp):\n",
270 | " x = self.hs(inp)\n",
271 | " #x = torch.tanh(self.to_out(x))\n",
272 | " x = self.to_out(x)\n",
273 | " return x"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": null,
279 | "metadata": {
280 | "ExecuteTime": {
281 | "end_time": "2021-04-16T15:11:06.128309Z",
282 | "start_time": "2021-04-16T15:10:59.478750Z"
283 | }
284 | },
285 | "outputs": [],
286 | "source": [
287 | "start_tau = 17\n",
288 | "start_pd = 15\n",
289 | "diffs = [1, 2, 3, 4, 5]\n",
290 | "for diff in diffs:\n",
291 | " print(start_tau*diff, start_pd*diff)\n",
292 | " (train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000, \n",
293 | " tau=start_tau*diff, \n",
294 | " predict_length=start_pd*diff)\n",
295 | "\n",
296 | " dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n",
297 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
298 | "\n",
299 | " dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n",
300 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n",
301 | " \n",
302 | " sith_params1 = {\"in_features\":1, \n",
303 | " \"tau_min\":1, \"tau_max\":25.0, \n",
304 | " \"k\":15, 'dt':1,\n",
305 | " \"ntau\":8, 'g':0.0, \n",
306 | " \"ttype\":ttype, 'batch_norm':False,\n",
307 | " \"hidden_size\":25, \"act_func\":nn.ReLU()}\n",
308 | " sith_params2 = {\"in_features\":sith_params1['hidden_size'], \n",
309 | " \"tau_min\":1, \"tau_max\":50.0, \n",
310 | " \"k\":8, 'dt':1,\n",
311 | " \"ntau\":8, 'g':0.0, \n",
312 | " \"ttype\":ttype, 'batch_norm':False,\n",
313 | " \"hidden_size\":25, \"act_func\":nn.ReLU()}\n",
314 | " sith_params3 = {\"in_features\":sith_params2['hidden_size'], \n",
315 | " \"tau_min\":1, \"tau_max\":150.0, 'buff_max':600, \n",
316 | " \"k\":4, 'dt':1,\n",
317 | " \"ntau\":8, 'g':0.0, \n",
318 | " \"ttype\":ttype, 'batch_norm':False,\n",
319 | " \"hidden_size\":25, \"act_func\":nn.ReLU()}\n",
320 | " layer_params = [sith_params1, sith_params2, sith_params3]\n",
321 | "\n",
322 | " model = DeepSITH_Tracker(out=1,\n",
323 | " layer_params=layer_params, \n",
324 | " dropout=0.).cuda()\n",
325 | " optimizer = torch.optim.Adam(model.parameters(), lr=.004)\n",
326 | " loss_func = nn.MSELoss()\n",
327 | "\n",
328 | " tot = 0\n",
329 | " for p in model.parameters():\n",
330 | " tot += p.numel()\n",
331 | " print(\"tot_weights\", tot)\n",
332 | " #print(model)\n",
333 | " \n",
334 | " epochs = 1000\n",
335 | " batch_size = 32\n",
336 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
337 | " last_perf = 1000\n",
338 | " for e in progress_bar:\n",
339 | " train(model, ttype, dataset, dataset_valid, \n",
340 | " optimizer, loss_func, batch_size=batch_size, loss_buffer_size=64,\n",
341 | " epoch=e, perf_file=join('perf','mackeyglass_deepsith_ratio_1.csv'),\n",
342 | " prog_bar=progress_bar, tau=start_tau*diff, pred_len=start_pd*diff)"
343 | ]
344 | }
345 | ],
346 | "metadata": {
347 | "kernelspec": {
348 | "display_name": "Python 3",
349 | "language": "python",
350 | "name": "python3"
351 | },
352 | "language_info": {
353 | "codemirror_mode": {
354 | "name": "ipython",
355 | "version": 3
356 | },
357 | "file_extension": ".py",
358 | "mimetype": "text/x-python",
359 | "name": "python",
360 | "nbconvert_exporter": "python",
361 | "pygments_lexer": "ipython3",
362 | "version": "3.6.10"
363 | },
364 | "toc": {
365 | "nav_menu": {},
366 | "number_sections": true,
367 | "sideBar": true,
368 | "skip_h1_title": false,
369 | "title_cell": "Table of Contents",
370 | "title_sidebar": "Contents",
371 | "toc_cell": false,
372 | "toc_position": {},
373 | "toc_section_display": true,
374 | "toc_window_display": false
375 | }
376 | },
377 | "nbformat": 4,
378 | "nbformat_minor": 4
379 | }
380 |
--------------------------------------------------------------------------------
/experiments/MackeyGlass/MackeyGlass-coRNN-increasing_tau.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-01-31T16:13:20.250852Z",
9 | "start_time": "2021-01-31T16:13:20.076923Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-01-31T16:13:20.997573Z",
23 | "start_time": "2021-01-31T16:13:20.311754Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import torch\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import torchvision\n",
31 | "import numpy as np\n",
32 | "import torch.nn as nn\n",
33 | "import torch.nn.functional as F\n",
34 | "from tqdm import tqdm_notebook\n",
35 | "import PIL\n",
36 | "from torch.nn.utils import weight_norm\n",
37 | "from torch.autograd import Variable\n",
38 | "\n",
39 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
40 | "\n",
41 | "from os.path import join\n",
42 | "import scipy.special\n",
43 | "import pandas as pd\n",
44 | "import seaborn as sn\n",
45 | "import scipy\n",
46 | "from scipy.spatial.distance import euclidean\n",
47 | "from scipy.interpolate import interp1d\n",
48 | "from tqdm.notebook import tqdm\n",
49 | "import random\n",
50 | "from csv import DictWriter\n",
51 | "import nengolib\n",
52 | "# if gpu is to be used\n",
53 | "use_cuda = torch.cuda.is_available()\n",
54 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
55 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n",
56 | "\n",
57 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
58 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
59 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
60 | "ttype = FloatTensor\n",
61 | "\n",
62 | "import seaborn as sns\n",
63 | "print(use_cuda)\n",
64 | "import pickle\n"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {
71 | "ExecuteTime": {
72 | "end_time": "2021-01-31T16:13:21.000738Z",
73 | "start_time": "2021-01-31T16:13:20.998684Z"
74 | }
75 | },
76 | "outputs": [],
77 | "source": [
78 | "sn.set_context(\"poster\")"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {
85 | "ExecuteTime": {
86 | "end_time": "2021-01-31T16:13:21.016467Z",
87 | "start_time": "2021-01-31T16:13:21.001752Z"
88 | }
89 | },
90 | "outputs": [],
91 | "source": [
92 | "\n",
93 | "class coRNNCell(nn.Module):\n",
94 | " def __init__(self, n_inp, n_hid, dt, gamma, epsilon):\n",
95 | " super(coRNNCell, self).__init__()\n",
96 | " self.dt = dt\n",
97 | " self.gamma = gamma\n",
98 | " self.epsilon = epsilon\n",
99 | " self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)\n",
100 | "\n",
101 | " def forward(self,x,hy,hz):\n",
102 | " hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy),1)))\n",
103 | " - self.gamma * hy - self.epsilon * hz)\n",
104 | " hy = hy + self.dt * hz\n",
105 | "\n",
106 | " return hy, hz\n",
107 | "\n",
108 | "class coRNN(nn.Module):\n",
109 | " def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):\n",
110 | " super(coRNN, self).__init__()\n",
111 | " self.n_hid = n_hid\n",
112 | " self.cell = coRNNCell(n_inp,n_hid,dt,gamma,epsilon)\n",
113 | " self.readout = nn.Linear(n_hid, n_out)\n",
114 | "\n",
115 | " def forward(self, x):\n",
116 | " outputs = []\n",
117 | " ## initialize hidden states\n",
118 | " hy = Variable(torch.zeros(x.size(1),self.n_hid)).to(device)\n",
119 | " hz = Variable(torch.zeros(x.size(1),self.n_hid)).to(device)\n",
120 | "\n",
121 | " for t in range(x.size(0)):\n",
122 | " hy, hz = self.cell(x[t],hy,hz)\n",
123 | " outputs.append(hy)\n",
124 | " outputs = torch.stack(outputs)\n",
125 | " \n",
126 | " output = self.readout(outputs)\n",
127 | "\n",
128 | " return output"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "# Load Stimuli"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {
142 | "ExecuteTime": {
143 | "end_time": "2021-01-31T16:13:25.543469Z",
144 | "start_time": "2021-01-31T16:13:25.536756Z"
145 | }
146 | },
147 | "outputs": [],
148 | "source": [
149 | "import collections\n",
150 | "\n",
151 | "def mackey_glass(sample_len=1000, tau=17, delta_t=10, seed=None, n_samples=1):\n",
152 | " # Adapted from https://github.com/mila-iqia/summerschool2015/blob/master/rnn_tutorial/synthetic.py\n",
153 | " '''\n",
154 | " mackey_glass(sample_len=1000, tau=17, seed = None, n_samples = 1) -> input\n",
155 | " Generate the Mackey Glass time-series. Parameters are:\n",
156 | " - sample_len: length of the time-series in timesteps. Default is 1000.\n",
157 | " - tau: delay of the MG - system. Commonly used values are tau=17 (mild \n",
158 | " chaos) and tau=30 (moderate chaos). Default is 17.\n",
159 | " - seed: to seed the random generator, can be used to generate the same\n",
160 | " timeseries at each invocation.\n",
161 | " - n_samples : number of samples to generate\n",
162 | " '''\n",
163 | " history_len = tau * delta_t \n",
164 | " # Initial conditions for the history of the system\n",
165 | " timeseries = 1.2\n",
166 | " \n",
167 | " if seed is not None:\n",
168 | " np.random.seed(seed)\n",
169 | "\n",
170 | " samples = []\n",
171 | "\n",
172 | " for _ in range(n_samples):\n",
173 | " history = collections.deque(1.2 * np.ones(history_len) + 0.2 * \\\n",
174 | " (np.random.rand(history_len) - 0.5))\n",
175 | " # Preallocate the array for the time-series\n",
176 | " inp = np.zeros((sample_len,1))\n",
177 | " \n",
178 | " for timestep in range(sample_len):\n",
179 | " for _ in range(delta_t):\n",
180 | " xtau = history.popleft()\n",
181 | " history.append(timeseries)\n",
182 | " timeseries = history[-1] + (0.2 * xtau / (1.0 + xtau ** 10) - \\\n",
183 | " 0.1 * history[-1]) / delta_t\n",
184 | " inp[timestep] = timeseries\n",
185 | " \n",
186 | " # Squash timeseries through tanh\n",
187 | " inp = np.tanh(inp - 1)\n",
188 | " samples.append(inp)\n",
189 | " return samples\n",
190 | "\n",
191 | "\n",
192 | "def generate_data(n_batches, length, split=0.5, seed=0,\n",
193 | " predict_length=15, tau=17, washout=100, delta_t=1,\n",
194 | " center=True):\n",
195 | " X = np.asarray(mackey_glass(\n",
196 | " sample_len=length+predict_length+washout, tau=tau,\n",
197 | " seed=seed, n_samples=n_batches))\n",
198 | " X = X[:, washout:, :]\n",
199 | " cutoff = int(split*n_batches)\n",
200 | " if center:\n",
201 | " X -= np.mean(X) # global mean over all batches, approx -0.066\n",
202 | " Y = X[:, predict_length:, :]\n",
203 | " X = X[:, :-predict_length, :]\n",
204 | " assert X.shape == Y.shape\n",
205 | " return ((X[:cutoff], Y[:cutoff]),\n",
206 | " (X[cutoff:], Y[cutoff:]))"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": null,
212 | "metadata": {
213 | "ExecuteTime": {
214 | "end_time": "2020-11-23T17:56:09.308282Z",
215 | "start_time": "2020-11-23T17:56:00.723134Z"
216 | }
217 | },
218 | "outputs": [],
219 | "source": [
220 | "(train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000)\n",
221 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n",
222 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
223 | "\n",
224 | "dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n",
225 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n",
226 | "\n",
227 | "print(train_X.shape, train_Y.shape, test_X.shape)"
228 | ]
229 | },
230 | {
231 | "cell_type": "markdown",
232 | "metadata": {},
233 | "source": [
234 | "## Setup for Model"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "metadata": {
241 | "ExecuteTime": {
242 | "end_time": "2021-01-31T16:13:46.359075Z",
243 | "start_time": "2021-01-31T16:13:46.346791Z"
244 | }
245 | },
246 | "outputs": [],
247 | "source": [
248 | "\n",
249 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n",
250 | " loss_buffer_size=800, batch_size=4, device='cuda',\n",
251 | " prog_bar=None, tau=17, pred_len=15):\n",
252 | " \n",
253 | " assert(loss_buffer_size%batch_size==0)\n",
254 | " \n",
255 | " losses = []\n",
256 | " last_test_perf = 0\n",
257 | " best_test_perf = 1000\n",
258 | " \n",
259 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
260 | " model.train()\n",
261 | " data = data.to(device).transpose(1,0)\n",
262 | " target = target.to(device)\n",
263 | " optimizer.zero_grad()\n",
264 | " out = model(data).transpose(1,0)\n",
265 | " loss = loss_func(out,\n",
266 | " target)\n",
267 | "\n",
268 | " loss.backward()\n",
269 | " optimizer.step()\n",
270 | "\n",
271 | " losses.append(loss.detach().cpu().numpy())\n",
272 | " losses = losses[int(-loss_buffer_size/batch_size):]\n",
273 | " \n",
274 | " if ((batch_idx*batch_size)%loss_buffer_size == 0):\n",
275 | " loss_track = {}\n",
276 | " last_test_perf = test_model(model, 'cuda', test_loader, \n",
277 | " )\n",
278 | " loss_track['avg_loss'] = np.mean(losses)\n",
279 | " loss_track['last_test'] = last_test_perf\n",
280 | " loss_track['epoch'] = epoch\n",
281 | " loss_track['batch_idx'] = batch_idx\n",
282 | " loss_track['tau'] = tau\n",
283 | " loss_track['pred_len'] = pred_len\n",
284 | " with open(perf_file, 'a+') as fp:\n",
285 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
286 | " if fp.tell() == 0:\n",
287 | " csv_writer.writeheader()\n",
288 | " csv_writer.writerow(loss_track)\n",
289 | " fp.flush()\n",
290 | " if best_test_perf > last_test_perf:\n",
291 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
292 | " best_test_perf = last_test_perf\n",
293 | " if not (prog_bar is None):\n",
294 | " # Update progress_bar\n",
295 | " s = \"{}:{} Loss: {:.4f},valid: {:.4f}\"\n",
296 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n",
297 | " last_test_perf] \n",
298 | " s = s.format(*format_list)\n",
299 | " prog_bar.set_description(s)\n",
300 | " \n",
301 | "def test_model(model, device, test_loader):\n",
302 | " # Test the Model\n",
303 | " nrmsd = []\n",
304 | " with torch.no_grad():\n",
305 | " for x, y in test_loader:\n",
306 | " data = x.to(device).transpose(1,0)\n",
307 | " target = y.to(device)\n",
308 | " optimizer.zero_grad()\n",
309 | " out = model(data).transpose(1,0)\n",
310 | " nrmsd.append(nengolib.signal.nrmse(out.detach().cpu().numpy().flatten(), \n",
311 | " target=target.detach().cpu().numpy().flatten()))\n",
312 | "\n",
313 | " perf = np.array(nrmsd).mean()\n",
314 | " return perf"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "metadata": {
321 | "ExecuteTime": {
322 | "end_time": "2021-01-31T22:27:47.613782Z",
323 | "start_time": "2021-01-31T16:15:15.395163Z"
324 | }
325 | },
326 | "outputs": [],
327 | "source": [
328 | "start_tau = 17\n",
329 | "start_pd = 15\n",
330 | "diffs = [1, 2, 3, 4, 5]\n",
331 | "for diff in diffs:\n",
332 | " (train_X, train_Y), (test_X, test_Y) = generate_data(128, 5000, tau=start_tau*diff, \n",
333 | " predict_length=start_pd*diff)#tau)\n",
334 | " dataset = torch.utils.data.TensorDataset(torch.Tensor(train_X).cuda(), torch.Tensor(train_Y).cuda())\n",
335 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
336 | "\n",
337 | " dataset_valid = torch.utils.data.TensorDataset(torch.Tensor(test_X).cuda(), torch.Tensor(test_Y).cuda())\n",
338 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=64, shuffle=False)\n",
339 | " \n",
340 | " cornn_params = dict(n_inp=1,\n",
341 | " n_hid=128, \n",
342 | " n_out=1,\n",
343 | " dt=1.6e-2,\n",
344 | " gamma=94.5,\n",
345 | " epsilon=9.5)\n",
346 | " model = coRNN(**cornn_params).cuda()\n",
347 | " \n",
348 | " optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)\n",
349 | " loss_func = nn.MSELoss()\n",
350 | "\n",
351 | " epochs = 1000\n",
352 | " batch_size = 32\n",
353 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
354 | " last_perf = 1000\n",
355 | " for e in progress_bar:\n",
356 | " train(model, ttype, dataset, dataset_valid, \n",
357 | " optimizer, loss_func, batch_size=batch_size, loss_buffer_size=64,\n",
358 | " epoch=e, perf_file=join('perf','mackeyglass_coRNN_ratio_1.csv'),\n",
359 | " prog_bar=progress_bar, tau=start_tau*diff, pred_len=start_pd*diff)"
360 | ]
361 | }
362 | ],
363 | "metadata": {
364 | "kernelspec": {
365 | "display_name": "Python 3",
366 | "language": "python",
367 | "name": "python3"
368 | },
369 | "language_info": {
370 | "codemirror_mode": {
371 | "name": "ipython",
372 | "version": 3
373 | },
374 | "file_extension": ".py",
375 | "mimetype": "text/x-python",
376 | "name": "python",
377 | "nbconvert_exporter": "python",
378 | "pygments_lexer": "ipython3",
379 | "version": "3.6.10"
380 | },
381 | "toc": {
382 | "nav_menu": {},
383 | "number_sections": true,
384 | "sideBar": true,
385 | "skip_h1_title": false,
386 | "title_cell": "Table of Contents",
387 | "title_sidebar": "Contents",
388 | "toc_cell": false,
389 | "toc_position": {},
390 | "toc_section_display": true,
391 | "toc_window_display": false
392 | }
393 | },
394 | "nbformat": 4,
395 | "nbformat_minor": 4
396 | }
397 |
--------------------------------------------------------------------------------
/experiments/AddingProblem/addingproblem-coRNN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-02-22T16:53:09.063623Z",
9 | "start_time": "2021-02-22T16:53:08.886901Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-02-22T16:53:09.671516Z",
23 | "start_time": "2021-02-22T16:53:09.100488Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pylab as plt\n",
29 | "import torch\n",
30 | "import numpy as np\n",
31 | "import seaborn as sn\n",
32 | "sn.set_context(\"poster\")\n",
33 | "import os\n",
34 | "import torch\n",
35 | "import torch.nn as nn\n",
36 | "import torch.nn.functional as F\n",
37 | "from torch.nn.utils import weight_norm\n",
38 | "from torchvision import transforms, datasets\n",
39 | "\n",
40 | "from cornn import coRNN\n",
41 | "import numpy as np\n",
42 | "import scipy\n",
43 | "import scipy.stats as st\n",
44 | "import scipy.special\n",
45 | "import scipy.signal\n",
46 | "import scipy.interpolate\n",
47 | "\n",
48 | "import pandas as pd\n",
49 | "\n",
50 | "from os.path import join\n",
51 | "import random\n",
52 | "from csv import DictWriter\n",
53 | "\n",
54 | "from tqdm.notebook import tqdm\n",
55 | "import pickle\n",
56 | "# if gpu is to be used\n",
57 | "use_cuda = torch.cuda.is_available()\n",
58 | "print(use_cuda)\n",
59 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
60 | "DoubleTensor = torch.cuda.DoubleTensor if use_cuda else torch.DoubleTensor\n",
61 | "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
62 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
63 | "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
64 | "ttype = FloatTensor"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {
71 | "ExecuteTime": {
72 | "end_time": "2021-02-22T16:53:09.677199Z",
73 | "start_time": "2021-02-22T16:53:09.672636Z"
74 | }
75 | },
76 | "outputs": [],
77 | "source": [
78 | "import torch\n",
79 | "import numpy as np\n",
80 | "\n",
81 | "def get_batch(batch_size, T, ttype):\n",
82 | " values = torch.rand(T, batch_size, requires_grad=False)\n",
83 | " indices = torch.zeros_like(values)\n",
84 | " half = int(T / 2)\n",
85 | " for i in range(batch_size):\n",
86 | " half_1 = np.random.randint(half)\n",
87 | " hals_2 = np.random.randint(half, T)\n",
88 | " indices[half_1, i] = 1\n",
89 | " indices[hals_2, i] = 1\n",
90 | "\n",
91 | " data = torch.stack((values, indices), dim=-1).type(ttype)\n",
92 | " targets = torch.mul(values, indices).sum(dim=0).type(ttype)\n",
93 | " return data, targets\n"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {
100 | "ExecuteTime": {
101 | "end_time": "2021-02-22T16:53:13.510567Z",
102 | "start_time": "2021-02-22T16:53:13.508399Z"
103 | }
104 | },
105 | "outputs": [],
106 | "source": [
107 | "torch.manual_seed(1111)"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {
114 | "ExecuteTime": {
115 | "end_time": "2021-02-22T16:53:13.523219Z",
116 | "start_time": "2021-02-22T16:53:13.511392Z"
117 | }
118 | },
119 | "outputs": [],
120 | "source": [
121 | "def train(model, ttype, seq_length, optimizer, loss_func, \n",
122 | " epoch, perf_file, loss_buffer_size=20, batch_size=1, test_size=10,\n",
123 | " device='cuda', prog_bar=None):\n",
124 | " assert(loss_buffer_size%batch_size==0)\n",
125 | "\n",
126 | " losses = []\n",
127 | " perfs = []\n",
128 | " last_test_perf = 0\n",
129 | " for batch_idx in range(20000):\n",
130 | " model.train()\n",
131 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n",
132 | " \n",
133 | " target = target.unsqueeze(1)\n",
134 | " optimizer.zero_grad()\n",
135 | " out = model(sig)\n",
136 | " loss = loss_func(out,\n",
137 | " target)\n",
138 | " \n",
139 | " loss.backward()\n",
140 | " optimizer.step()\n",
141 | "\n",
142 | " losses.append(loss.detach().cpu().numpy())\n",
143 | " losses = losses[-loss_buffer_size:]\n",
144 | " if not (prog_bar is None):\n",
145 | " # Update progress_bar\n",
146 | " s = \"{}:{} Loss: {:.8f}\"\n",
147 | " format_list = [e, int(batch_idx/(50/batch_size)), np.mean(losses)] \n",
148 | " s = s.format(*format_list)\n",
149 | " prog_bar.set_description(s)\n",
150 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n",
151 | " loss_track = {}\n",
152 | " #last_test_perf = test_norm(model, 'cuda', test_sig, test_class,\n",
153 | " # batch_size=test_size, \n",
154 | " # )\n",
155 | " loss_track['avg_loss'] = np.mean(losses)\n",
156 | " #loss_track['last_test'] = last_test_perf\n",
157 | " loss_track['epoch'] = epoch\n",
158 | " loss_track['batch_idx'] = batch_idx\n",
159 | " with open(perf_file, 'a+') as fp:\n",
160 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
161 | " if fp.tell() == 0:\n",
162 | " csv_writer.writeheader()\n",
163 | " csv_writer.writerow(loss_track)\n",
164 | " fp.flush()\n",
165 | "def test_norm(model, device, seq_length, loss_func, batch_size=100):\n",
166 | " model.eval()\n",
167 | " correct = 0\n",
168 | " count = 0\n",
169 | " with torch.no_grad():\n",
170 | " sig, target = get_batch(batch_size, seq_length, ttype=ttype)\n",
171 | " target = target.unsqueeze(1)\n",
172 | " out = model(sig)\n",
173 | " loss = loss_func(out,\n",
174 | " target)\n",
175 | " return loss"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {
182 | "ExecuteTime": {
183 | "end_time": "2021-02-22T16:53:13.536036Z",
184 | "start_time": "2021-02-22T16:53:13.524035Z"
185 | }
186 | },
187 | "outputs": [],
188 | "source": [
189 | "from torch import nn\n",
190 | "import torch\n",
191 | "from torch.autograd import Variable\n",
192 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
193 | "\n",
194 | "class coRNNCell(nn.Module):\n",
195 | " def __init__(self, n_inp, n_hid, dt, gamma, epsilon):\n",
196 | " super(coRNNCell, self).__init__()\n",
197 | " self.dt = dt\n",
198 | " self.gamma = gamma\n",
199 | " self.epsilon = epsilon\n",
200 | " self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)\n",
201 | "\n",
202 | " def forward(self,x,hy,hz):\n",
203 | " hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy),1)))\n",
204 | " - self.gamma * hy - self.epsilon * hz)\n",
205 | " hy = hy + self.dt * hz\n",
206 | "\n",
207 | " return hy, hz\n",
208 | "\n",
209 | "class coRNN(nn.Module):\n",
210 | " def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):\n",
211 | " super(coRNN, self).__init__()\n",
212 | " self.n_hid = n_hid\n",
213 | " self.cell = coRNNCell(n_inp,n_hid,dt,gamma,epsilon)\n",
214 | " self.readout = nn.Linear(n_hid, n_out)\n",
215 | "\n",
216 | " def forward(self, x):\n",
217 | " ## initialize hidden states\n",
218 | " hy = Variable(torch.zeros(x.size(1),self.n_hid)).to(device)\n",
219 | " hz = Variable(torch.zeros(x.size(1),self.n_hid)).to(device)\n",
220 | "\n",
221 | " for t in range(x.size(0)):\n",
222 | " hy, hz = self.cell(x[t],hy,hz)\n",
223 | " output = self.readout(hy)\n",
224 | "\n",
225 | " return output"
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "metadata": {},
231 | "source": [
232 | "# T = 100"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "metadata": {
239 | "ExecuteTime": {
240 | "end_time": "2021-02-22T16:53:16.576467Z",
241 | "start_time": "2021-02-22T16:53:16.564088Z"
242 | }
243 | },
244 | "outputs": [],
245 | "source": [
246 | "cornn_params = dict(n_inp=2,\n",
247 | " n_hid=128, \n",
248 | " n_out=1,\n",
249 | " dt=1.6e-2,\n",
250 | " gamma=94.5,\n",
251 | " epsilon=9.5)\n",
252 | "model = coRNN(**cornn_params).cuda()\n",
253 | "\n",
254 | "tot_weights = 0\n",
255 | "for p in model.parameters():\n",
256 | " tot_weights += p.numel()\n",
257 | "print(\"Total Weights:\", tot_weights)\n",
258 | "print(model)"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {
265 | "ExecuteTime": {
266 | "start_time": "2021-02-22T16:53:19.447Z"
267 | }
268 | },
269 | "outputs": [],
270 | "source": [
271 | "seq_length=100\n",
272 | "\n",
273 | "loss_func = nn.MSELoss()\n",
274 | "optimizer = torch.optim.AdamW(model.parameters())\n",
275 | "epochs = 1\n",
276 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
277 | "for e in progress_bar:\n",
278 | " train(model, ttype, seq_length,\n",
279 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
280 | " epoch=e, perf_file=join('perf','adding100_cornn_big.csv'),\n",
281 | " prog_bar=progress_bar)"
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {},
287 | "source": [
288 | "# T = 500"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {
295 | "ExecuteTime": {
296 | "end_time": "2020-12-15T17:16:54.804717Z",
297 | "start_time": "2020-12-15T17:16:54.799148Z"
298 | }
299 | },
300 | "outputs": [],
301 | "source": [
302 | "cornn_params = dict(n_inp=2,\n",
303 | " n_hid=128, \n",
304 | " n_out=1,\n",
305 | " dt=6e-2,\n",
306 | " gamma=66,\n",
307 | " epsilon=15)\n",
308 | "model = coRNN(**cornn_params).cuda()\n",
309 | "\n",
310 | "tot_weights = 0\n",
311 | "for p in model.parameters():\n",
312 | " tot_weights += p.numel()\n",
313 | "print(\"Total Weights:\", tot_weights)\n",
314 | "print(model)"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "metadata": {
321 | "ExecuteTime": {
322 | "start_time": "2020-12-15T17:17:01.879Z"
323 | }
324 | },
325 | "outputs": [],
326 | "source": [
327 | "seq_length=500\n",
328 | "\n",
329 | "loss_func = nn.MSELoss()\n",
330 | "optimizer = torch.optim.Adam(model.parameters(), lr=3e-2)\n",
331 | "epochs = 1\n",
332 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
333 | "for e in progress_bar:\n",
334 | " train(model, ttype, seq_length,\n",
335 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
336 | " epoch=e, perf_file=join('perf','adding500_cornn_6.csv'),\n",
337 | " prog_bar=progress_bar)"
338 | ]
339 | },
340 | {
341 | "cell_type": "markdown",
342 | "metadata": {},
343 | "source": [
344 | "# T = 2000"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": null,
350 | "metadata": {
351 | "ExecuteTime": {
352 | "end_time": "2020-12-02T05:15:54.017640Z",
353 | "start_time": "2020-12-02T05:15:54.014072Z"
354 | }
355 | },
356 | "outputs": [],
357 | "source": [
358 | "cornn_params = dict(n_inp=2,\n",
359 | " n_hid=128, \n",
360 | " n_out=1,\n",
361 | " dt=3e-2,\n",
362 | " gamma=80,\n",
363 | " epsilon=12)\n",
364 | "model = coRNN(**cornn_params).cuda()\n",
365 | "\n",
366 | "tot_weights = 0\n",
367 | "for p in model.parameters():\n",
368 | " tot_weights += p.numel()\n",
369 | "print(\"Total Weights:\", tot_weights)\n",
370 | "print(model)"
371 | ]
372 | },
373 | {
374 | "cell_type": "code",
375 | "execution_count": null,
376 | "metadata": {
377 | "ExecuteTime": {
378 | "end_time": "2020-12-02T09:28:58.690751Z",
379 | "start_time": "2020-12-02T05:15:54.018817Z"
380 | }
381 | },
382 | "outputs": [],
383 | "source": [
384 | "seq_length=2000\n",
385 | "\n",
386 | "loss_func = nn.MSELoss()\n",
387 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)\n",
388 | "epochs = 1\n",
389 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
390 | "for e in progress_bar:\n",
391 | " train(model, ttype, seq_length,\n",
392 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
393 | " epoch=e, perf_file=join('perf','adding2000_cornn_1.csv'),\n",
394 | " prog_bar=progress_bar)"
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "metadata": {},
400 | "source": [
401 | "# T = 5000"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {
408 | "ExecuteTime": {
409 | "end_time": "2020-12-01T18:39:48.793500Z",
410 | "start_time": "2020-12-01T18:39:48.775403Z"
411 | }
412 | },
413 | "outputs": [],
414 | "source": [
415 | "cornn_params = dict(n_inp=2,\n",
416 | " n_hid=128, \n",
417 | " n_out=1,\n",
418 | " dt=1.6e-2,\n",
419 | " gamma=94.5,\n",
420 | " epsilon=9.5)\n",
421 | "model = coRNN(**cornn_params).cuda()\n",
422 | "\n",
423 | "tot_weights = 0\n",
424 | "for p in model.parameters():\n",
425 | " tot_weights += p.numel()\n",
426 | "print(\"Total Weights:\", tot_weights)\n",
427 | "print(model)"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": null,
433 | "metadata": {
434 | "ExecuteTime": {
435 | "end_time": "2020-12-02T05:15:54.012718Z",
436 | "start_time": "2020-12-01T18:39:52.788917Z"
437 | }
438 | },
439 | "outputs": [],
440 | "source": [
441 | "seq_length=5000\n",
442 | "\n",
443 | "loss_func = nn.MSELoss()\n",
444 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)\n",
445 | "epochs = 1\n",
446 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
447 | "for e in progress_bar:\n",
448 | " train(model, ttype, seq_length,\n",
449 | " optimizer, loss_func, batch_size=50, loss_buffer_size=100,\n",
450 | " epoch=e, perf_file=join('perf','adding5000_cornn_1.csv'),\n",
451 | " prog_bar=progress_bar)"
452 | ]
453 | }
454 | ],
455 | "metadata": {
456 | "kernelspec": {
457 | "display_name": "Python 3",
458 | "language": "python",
459 | "name": "python3"
460 | },
461 | "language_info": {
462 | "codemirror_mode": {
463 | "name": "ipython",
464 | "version": 3
465 | },
466 | "file_extension": ".py",
467 | "mimetype": "text/x-python",
468 | "name": "python",
469 | "nbconvert_exporter": "python",
470 | "pygments_lexer": "ipython3",
471 | "version": "3.6.10"
472 | },
473 | "toc": {
474 | "nav_menu": {
475 | "height": "141px",
476 | "width": "160px"
477 | },
478 | "number_sections": true,
479 | "sideBar": true,
480 | "skip_h1_title": false,
481 | "title_cell": "Table of Contents",
482 | "title_sidebar": "Contents",
483 | "toc_cell": false,
484 | "toc_position": {},
485 | "toc_section_display": true,
486 | "toc_window_display": false
487 | }
488 | },
489 | "nbformat": 4,
490 | "nbformat_minor": 4
491 | }
492 |
--------------------------------------------------------------------------------
/experiments/psMNIST/psMNIST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2020-12-03T16:57:39.184101Z",
9 | "start_time": "2020-12-03T16:57:38.997252Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline\n"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2020-12-05T18:29:47.565802Z",
23 | "start_time": "2020-12-05T18:29:47.560378Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pyplot as plt\n",
29 | "from torch.optim.lr_scheduler import StepLR\n",
30 | "import torch\n",
31 | "from sklearn.metrics import confusion_matrix\n",
32 | "import pandas as pd\n",
33 | "import numpy as np\n",
34 | "\n",
35 | "import torch.nn as nn\n",
36 | "\n",
37 | "from tqdm import tqdm_notebook\n",
38 | "\n",
39 | "from torchvision import transforms\n",
40 | "from torchvision import datasets\n",
41 | "from os.path import join\n",
42 | "\n",
43 | "from deepsith import DeepSITH\n",
44 | "\n",
45 | "from tqdm.notebook import tqdm\n",
46 | "\n",
47 | "import random\n",
48 | "\n",
49 | "from csv import DictWriter\n",
50 | "# if gpu is to be used\n",
51 | "use_cuda = torch.cuda.is_available()\n",
52 | "\n",
53 | "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
54 | "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
55 | "\n",
56 | "ttype =FloatTensor\n",
57 | "\n",
58 | "import seaborn as sn\n",
59 | "print(use_cuda)\n",
60 | "import pickle\n",
61 | "\n",
62 | "sn.set_context(\"poster\")"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {
69 | "ExecuteTime": {
70 | "end_time": "2020-12-03T16:57:40.061632Z",
71 | "start_time": "2020-12-03T16:57:40.058700Z"
72 | }
73 | },
74 | "outputs": [],
75 | "source": [
76 | "class DeepSITH_Classifier(nn.Module):\n",
77 | " def __init__(self, out_features, layer_params, dropout=.1):\n",
78 | " super(DeepSITH_Classifier, self).__init__()\n",
79 | " last_hidden = layer_params[-1]['hidden_size']\n",
80 | " self.hs = DeepSITH(layer_params=layer_params, dropout=dropout)\n",
81 | " self.to_out = nn.Linear(last_hidden, out_features)\n",
82 | " def forward(self, inp):\n",
83 | " x = self.hs(inp)\n",
84 | " x = self.to_out(x)\n",
85 | " return x"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "# Load Stimuli"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {
99 | "ExecuteTime": {
100 | "end_time": "2020-12-03T16:57:41.468289Z",
101 | "start_time": "2020-12-03T16:57:41.465864Z"
102 | }
103 | },
104 | "outputs": [],
105 | "source": [
106 | "# Same seed and supposed Permutation as the coRNN paper\n",
107 | "torch.manual_seed(12008)\n",
108 | "permute = torch.randperm(784)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {
115 | "ExecuteTime": {
116 | "end_time": "2020-12-03T16:57:42.535384Z",
117 | "start_time": "2020-12-03T16:57:42.527542Z"
118 | }
119 | },
120 | "outputs": [],
121 | "source": [
122 | "print(permute)"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {
129 | "ExecuteTime": {
130 | "end_time": "2020-12-03T16:57:45.892964Z",
131 | "start_time": "2020-12-03T16:57:45.890797Z"
132 | }
133 | },
134 | "outputs": [],
135 | "source": [
136 | "norm = transforms.Normalize((.1307,), (.3081,), )"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {
143 | "ExecuteTime": {
144 | "end_time": "2020-12-03T16:57:46.269650Z",
145 | "start_time": "2020-12-03T16:57:46.255928Z"
146 | }
147 | },
148 | "outputs": [],
149 | "source": [
150 | "batch_size = 64\n",
151 | "transform = transforms.Compose([transforms.ToTensor(),\n",
152 | " transforms.Normalize((.1307,), (.3081,))\n",
153 | " ])\n",
154 | "ds1 = datasets.MNIST('../data', train=True, download=True, transform=transform)\n",
155 | "ds2 = datasets.MNIST('../data', train=False, download=True, transform=transform)\n",
156 | "train_loader=torch.utils.data.DataLoader(ds1,batch_size=batch_size, \n",
157 | " num_workers=1, pin_memory=True, shuffle=True)\n",
158 | "test_loader=torch.utils.data.DataLoader(ds2, batch_size=batch_size, \n",
159 | " num_workers=1, pin_memory=True, shuffle=True)"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {
166 | "ExecuteTime": {
167 | "end_time": "2020-12-03T16:57:51.113709Z",
168 | "start_time": "2020-12-03T16:57:47.283835Z"
169 | }
170 | },
171 | "outputs": [],
172 | "source": [
173 | "testi = next(iter(test_loader))[0]\n",
174 | "\n",
175 | "plt.imshow(testi[0].reshape(-1)[permute].reshape(28,28))\n",
176 | "\n",
177 | "plt.colorbar()"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "ExecuteTime": {
185 | "end_time": "2020-12-03T16:57:51.164838Z",
186 | "start_time": "2020-12-03T16:57:51.162192Z"
187 | }
188 | },
189 | "outputs": [],
190 | "source": [
191 | "testi.shape"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {},
197 | "source": [
198 | "# Define test and train"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "metadata": {
205 | "ExecuteTime": {
206 | "end_time": "2020-12-04T15:09:47.998604Z",
207 | "start_time": "2020-12-04T15:09:47.988228Z"
208 | }
209 | },
210 | "outputs": [],
211 | "source": [
212 | "\n",
213 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n",
214 | " permute=None, loss_buffer_size=800, batch_size=4, device='cuda',\n",
215 | " prog_bar=None):\n",
216 | " \n",
217 | " assert(loss_buffer_size%batch_size==0)\n",
218 | " if permute is None:\n",
219 | " permute = torch.LongTensor(list(range(784)))\n",
220 | " \n",
221 | " losses = []\n",
222 | " perfs = []\n",
223 | " last_test_perf = 0\n",
224 | " best_test_perf = -1\n",
225 | " \n",
226 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
227 | " model.train()\n",
228 | " data = data.to(device).view(data.shape[0],1,1,-1)\n",
229 | " target = target.to(device)\n",
230 | " optimizer.zero_grad()\n",
231 | " out = model(data[:, :, :, permute])\n",
232 | " loss = loss_func(out[:, -1, :],\n",
233 | " target)\n",
234 | "\n",
235 | " loss.backward()\n",
236 | " optimizer.step()\n",
237 | "\n",
238 | " perfs.append((torch.argmax(out[:, -1, :], dim=-1) == \n",
239 | " target).sum().item())\n",
240 | " perfs = perfs[int(-loss_buffer_size/batch_size):]\n",
241 | " losses.append(loss.detach().cpu().numpy())\n",
242 | " losses = losses[int(-loss_buffer_size/batch_size):]\n",
243 | " if not (prog_bar is None):\n",
244 | " # Update progress_bar\n",
245 | " s = \"{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}\"\n",
246 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n",
247 | " np.sum(perfs)/((len(perfs))*batch_size), last_test_perf] \n",
248 | " s = s.format(*format_list)\n",
249 | " prog_bar.set_description(s)\n",
250 | " \n",
251 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n",
252 | " loss_track = {}\n",
253 | " # last_test_perf = test(model, 'cuda', test_loader, \n",
254 | " # batch_size=batch_size, \n",
255 | " # permute=permute)\n",
256 | " loss_track['avg_loss'] = np.mean(losses)\n",
257 | " #loss_track['last_test'] = last_test_perf\n",
258 | " loss_track['epoch'] = epoch\n",
259 | " loss_track['batch_idx'] = batch_idx\n",
260 | " loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)\n",
261 | " with open(perf_file, 'a+') as fp:\n",
262 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
263 | " if fp.tell() == 0:\n",
264 | " csv_writer.writeheader()\n",
265 | " csv_writer.writerow(loss_track)\n",
266 | " fp.flush()\n",
267 | " #if best_test_perf < last_test_perf:\n",
268 | " # torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
269 | " # best_test_perf = last_test_perf\n",
270 | "\n",
271 | " \n",
272 | "def test(model, device, test_loader, batch_size=4, permute=None):\n",
273 | " if permute is None:\n",
274 | " permute = torch.LongTensor(list(range(784)))\n",
275 | " \n",
276 | " model.eval()\n",
277 | " correct = 0\n",
278 | " count = 0\n",
279 | " with torch.no_grad():\n",
280 | " for data, target in test_loader:\n",
281 | " data = data.to(device).view(data.shape[0],1,1,-1)\n",
282 | " target = target.to(device)\n",
283 | " \n",
284 | " out = model(data[:,:,:, permute])\n",
285 | " pred = out[:, -1].argmax(dim=-1, keepdim=True)\n",
286 | " correct += pred.eq(target.view_as(pred)).sum().item()\n",
287 | " count += 1\n",
288 | " return correct / len(test_loader.dataset)"
289 | ]
290 | },
291 | {
292 | "cell_type": "markdown",
293 | "metadata": {},
294 | "source": [
295 | "# Setup the model"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": null,
301 | "metadata": {
302 | "ExecuteTime": {
303 | "end_time": "2020-12-04T15:09:48.936759Z",
304 | "start_time": "2020-12-04T15:09:48.926780Z"
305 | },
306 | "scrolled": true
307 | },
308 | "outputs": [],
309 | "source": [
310 | "g = 0.0\n",
311 | "sith_params1 = {\"in_features\":1, \n",
312 | " \"tau_min\":1, \"tau_max\":30.0, \"buff_max\":50,\n",
313 | " \"k\":125, 'dt':1,\n",
314 | " \"ntau\":20, 'g':g, \n",
315 | " \"ttype\":ttype, \"batch_norm\":True,\n",
316 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n",
317 | " }\n",
318 | "sith_params2 = {\"in_features\":sith_params1['hidden_size'], \n",
319 | " \"tau_min\":1, \"tau_max\":150.0, \"buff_max\":250,\n",
320 | " \"k\":61, 'dt':1,\n",
321 | " \"ntau\":20, 'g':g, \n",
322 | " \"ttype\":ttype, \"batch_norm\":True,\n",
323 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n",
324 | " }\n",
325 | "sith_params3 = {\"in_features\":sith_params2['hidden_size'], \n",
326 | " \"tau_min\":1, \"tau_max\":750.0, \"buff_max\":1500,\n",
327 | " \"k\":35, 'dt':1,\n",
328 | " \"ntau\":20, 'g':g, \n",
329 | " \"ttype\":ttype, \"batch_norm\":True,\n",
330 | " \"hidden_size\":60, \"act_func\":nn.ReLU()\n",
331 | " }\n",
332 | "\n",
333 | "\n",
334 | "\n",
335 | "layer_params = [sith_params1, sith_params2, sith_params3]\n",
336 | "\n",
337 | "\n",
338 | "model = DeepSITH_Classifier(10,\n",
339 | " layer_params=layer_params, \n",
340 | " dropout=0.2).cuda()\n",
341 | "\n",
342 | "tot_weights = 0\n",
343 | "for p in model.parameters():\n",
344 | " tot_weights += p.numel()\n",
345 | "print(\"Total Weights:\", tot_weights)\n",
346 | "print(model)"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "metadata": {
353 | "ExecuteTime": {
354 | "end_time": "2020-12-05T00:17:49.489853Z",
355 | "start_time": "2020-12-04T15:09:54.978362Z"
356 | },
357 | "scrolled": true
358 | },
359 | "outputs": [],
360 | "source": [
361 | "epochs = 90\n",
362 | "loss_func = nn.CrossEntropyLoss()\n",
363 | "optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)\n",
364 | "sched = StepLR(optimizer, step_size=30, gamma=0.1)\n",
365 | "progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
366 | "test_perf = []\n",
367 | "for e in progress_bar:\n",
368 | " train(model, ttype, train_loader, test_loader, optimizer, loss_func, batch_size=batch_size,\n",
369 | " epoch=e, perf_file=join('perf','pmnist_deepsith_78.csv'),loss_buffer_size=64*32, \n",
370 | " prog_bar=progress_bar, permute=permute)\n",
371 | " last_test_perf = test(model, 'cuda', test_loader, \n",
372 | " batch_size=batch_size, \n",
373 | " permute=permute)\n",
374 | " test_perf.append({\"epoch\":e,\n",
375 | " 'test':last_test_perf})\n",
376 | " sched.step()"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "metadata": {
382 | "ExecuteTime": {
383 | "end_time": "2020-11-02T01:31:22.277851Z",
384 | "start_time": "2020-11-02T01:31:22.275272Z"
385 | }
386 | },
387 | "source": [
388 | "# Find Errors"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": null,
394 | "metadata": {
395 | "ExecuteTime": {
396 | "end_time": "2020-12-05T00:26:42.879054Z",
397 | "start_time": "2020-12-05T00:26:29.019237Z"
398 | }
399 | },
400 | "outputs": [],
401 | "source": [
402 | "test(model, 'cuda', test_loader, \n",
403 | " batch_size=batch_size, \n",
404 | " permute=permute)"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": null,
410 | "metadata": {
411 | "ExecuteTime": {
412 | "end_time": "2020-11-10T14:51:23.948941Z",
413 | "start_time": "2020-11-10T14:51:23.944246Z"
414 | }
415 | },
416 | "outputs": [],
417 | "source": [
418 | "def conf_mat_gen(model, device, test_loader, batch_size=4, permute=None):\n",
419 | " if permute is None:\n",
420 | " permute = torch.LongTensor(list(range(784)))\n",
421 | " evals = {'pred':[],\n",
422 | " 'actual':[]}\n",
423 | " model.eval()\n",
424 | " correct = 0\n",
425 | " count = 0\n",
426 | " with torch.no_grad():\n",
427 | " for data, target in test_loader:\n",
428 | " data = data.to(device).view(data.shape[0],1,1,-1)\n",
429 | " target = target.to(device)\n",
430 | " for x in target:\n",
431 | " evals['actual'].append(x.detach().cpu().numpy())\n",
432 | " out = model(data[:,:,:, permute])\n",
433 | " for x in out[:, -1].argmax(dim=-1, keepdim=True):\n",
434 | " evals['pred'].append(x.detach().cpu().numpy())\n",
435 | " return evals"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": null,
441 | "metadata": {
442 | "ExecuteTime": {
443 | "end_time": "2020-11-10T14:51:36.918179Z",
444 | "start_time": "2020-11-10T14:51:25.179622Z"
445 | }
446 | },
447 | "outputs": [],
448 | "source": [
449 | "evals = conf_mat_gen(model, 'cuda', test_loader, batch_size=4, permute=permute)"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": null,
455 | "metadata": {
456 | "ExecuteTime": {
457 | "end_time": "2020-11-10T14:53:19.261434Z",
458 | "start_time": "2020-11-10T14:53:18.926905Z"
459 | }
460 | },
461 | "outputs": [],
462 | "source": [
463 | "fig = plt.figure(figsize=(10,11))\n",
464 | "plt.imshow(confusion_matrix(np.array(evals['pred'])[:, 0], \n",
465 | " np.array(evals['actual'])), cmap='coolwarm')\n",
466 | "plt.colorbar()\n",
467 | "plt.xticks(list(range(10)));\n",
468 | "plt.yticks(list(range(10)));\n",
469 | "plt.savefig(join('figs', 'sMNIST_LoLa_conf'), dpi=200, bboxinches='tight')"
470 | ]
471 | }
472 | ],
473 | "metadata": {
474 | "kernelspec": {
475 | "display_name": "Python 3",
476 | "language": "python",
477 | "name": "python3"
478 | },
479 | "language_info": {
480 | "codemirror_mode": {
481 | "name": "ipython",
482 | "version": 3
483 | },
484 | "file_extension": ".py",
485 | "mimetype": "text/x-python",
486 | "name": "python",
487 | "nbconvert_exporter": "python",
488 | "pygments_lexer": "ipython3",
489 | "version": "3.6.10"
490 | },
491 | "toc": {
492 | "nav_menu": {},
493 | "number_sections": true,
494 | "sideBar": true,
495 | "skip_h1_title": false,
496 | "title_cell": "Table of Contents",
497 | "title_sidebar": "Contents",
498 | "toc_cell": false,
499 | "toc_position": {},
500 | "toc_section_display": true,
501 | "toc_window_display": false
502 | }
503 | },
504 | "nbformat": 4,
505 | "nbformat_minor": 4
506 | }
507 |
--------------------------------------------------------------------------------
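
The "Find Errors" cells above gather per-example predictions with `conf_mat_gen` and render them as a confusion matrix. A minimal, self-contained sketch of the same analysis (assuming a trained `model` and `test_loader` as built earlier in the notebook, and scikit-learn's `confusion_matrix`; `collect_predictions` is a hypothetical helper, not part of the repo):

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def collect_predictions(model, loader, device='cuda'):
    """Return (y_true, y_pred) arrays for every example in `loader`."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device).view(data.shape[0], 1, 1, -1)
            # For psMNIST, index the pixel permutation first: model(data[:, :, :, permute])
            out = model(data)                 # (batch, time, n_classes)
            pred = out[:, -1].argmax(dim=-1)  # classify at the final time step
            y_true.append(target.cpu().numpy())
            y_pred.append(pred.cpu().numpy())
    return np.concatenate(y_true), np.concatenate(y_pred)

# y_true, y_pred = collect_predictions(model, test_loader)
# cm = confusion_matrix(y_true, y_pred)  # rows: true digit, cols: predicted digit
# plt.imshow(cm, cmap='coolwarm'); plt.colorbar()
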
/experiments/Hateful8/hateful8-LSTM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2020-12-18T15:21:53.650363Z",
9 | "start_time": "2020-12-18T15:21:53.466128Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2020-12-18T15:21:54.221984Z",
23 | "start_time": "2020-12-18T15:21:53.678357Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pyplot as plt\n",
29 | "import torch\n",
30 | "from torch import nn as nn\n",
31 | "from math import factorial\n",
32 | "import random\n",
33 | "import torch.nn.functional as F\n",
34 | "import numpy as np\n",
35 | "import seaborn as sn\n",
36 | "import pandas as pd\n",
37 | "import os \n",
38 | "from os.path import join\n",
39 | "import glob\n",
40 | "from math import factorial\n",
41 | "ttype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
42 | "ctype = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor\n",
43 | "print(ttype)\n",
44 | "from torch.nn.utils import weight_norm\n",
45 | "\n",
46 | "from tqdm.notebook import tqdm\n",
47 | "import pickle\n",
48 | "sn.set_context(\"poster\")\n",
49 | "import itertools\n",
50 | "from csv import DictWriter\n",
51 | "import matplotlib.pylab as plt\n",
52 | "import csv\n",
53 | "import numpy as np\n",
54 | "import pandas as pd\n",
55 | "import os\n",
56 | "import seaborn as sn\n",
57 | "sn.set_context('talk')"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {
64 | "ExecuteTime": {
65 | "end_time": "2020-12-18T15:24:24.946369Z",
66 | "start_time": "2020-12-18T15:24:24.822539Z"
67 | }
68 | },
69 | "outputs": [],
70 | "source": [
71 | "def generate_noise(maxn=18):\n",
72 | " \"\"\"Generates dot and dash based noise.\"\"\"\n",
73 | " \n",
74 | " threes = np.random.randint(int(.5*maxn), int(.75*maxn))\n",
75 | " ones = (maxn - threes) * 2\n",
76 | " noise = list(itertools.repeat([1,1,1,0], threes))\n",
77 | " noise[:int(len(noise)/3)] = list(itertools.repeat([0,0], int(len(noise)/3)))\n",
78 | " ones = ones + int(len(noise)/3)\n",
79 | " noise.extend(list(itertools.repeat([1,0], ones)))\n",
80 | " random.shuffle(noise)\n",
81 | " noise = np.concatenate(noise)\n",
82 | " return noise\n",
83 | "noise = generate_noise()\n",
84 | "print(noise.shape)\n",
85 | "plt.plot(noise)"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {
92 | "ExecuteTime": {
93 | "end_time": "2020-12-18T15:24:32.691623Z",
94 | "start_time": "2020-12-18T15:24:30.795867Z"
95 | }
96 | },
97 | "outputs": [],
98 | "source": [
99 | "sig_lets = [\"A\",\"B\",\"C\",\"D\",\"E\",\"F\",\"G\",\"H\",]\n",
100 | "\n",
101 | "signals = ttype([[0,1,1,1,0,1,1,1,0,1,0,1,0,1,0,0,0],\n",
102 | " [0,1,1,1,0,1,0,1,1,1,0,1,0,1,0,0,0],\n",
103 | " [0,1,1,1,0,1,0,1,0,1,1,1,0,1,0,0,0],\n",
104 | " [0,1,1,1,0,1,0,1,0,1,0,1,1,1,0,0,0],\n",
105 | " [0,1,0,1,1,1,0,1,1,1,0,1,1,1,0,0,0],\n",
106 | " [0,1,1,1,0,1,0,1,1,1,0,1,1,1,0,0,0],\n",
107 | " [0,1,1,1,0,1,1,1,0,1,0,1,1,1,0,0,0],\n",
108 | " [0,1,1,1,0,1,1,1,0,1,1,1,0,1,0,0,0]]\n",
109 | " ).view(8, 1, 1, -1)\n",
110 | "#signals = ms\n",
111 | "key2id = {k:i for i, k in enumerate(sig_lets)}\n",
112 | "\n",
113 | "print(key2id)\n"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {
120 | "ExecuteTime": {
121 | "end_time": "2020-12-11T18:54:00.513324Z",
122 | "start_time": "2020-12-11T18:54:00.368662Z"
123 | }
124 | },
125 | "outputs": [],
126 | "source": [
127 | "torch.manual_seed(12345)\n",
128 | "np.random.seed(12345)\n",
129 | "training_samples = 32\n",
130 | "\n",
131 | "training_signals = []\n",
132 | "training_class = []\n",
133 | "\n",
134 | "for i, sig in enumerate(signals):\n",
135 | " temp_signals = []\n",
136 | " temp_class = []\n",
137 | " for x in range(training_samples):\n",
138 | " noise = ttype(generate_noise())\n",
139 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
140 | " while(any([(temp == c_).all() for c_ in temp_signals])):\n",
141 | " noise = ttype(generate_noise())\n",
142 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
143 | " temp_signals.append(temp)\n",
144 | " temp_class.append(i)\n",
145 | " training_signals.extend(temp_signals)\n",
146 | " training_class.extend(temp_class)\n",
147 | "\n",
148 | "batch_rand = torch.randperm(training_samples*signals.shape[0]) \n",
149 | "training_signals = torch.cat(training_signals).cuda().unsqueeze(-1)[batch_rand]\n",
150 | "training_class = ctype(training_class).cuda().unsqueeze(-1)[batch_rand]\n",
151 | "\n",
152 | "dataset = torch.utils.data.TensorDataset(training_signals, training_class)\n",
153 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
154 | "\n"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "ExecuteTime": {
162 | "end_time": "2020-12-11T18:54:02.761255Z",
163 | "start_time": "2020-12-11T18:54:02.121822Z"
164 | }
165 | },
166 | "outputs": [],
167 | "source": [
168 | "testing_samples = 10\n",
169 | "testing_signals = []\n",
170 | "testing_class = []\n",
171 | "\n",
172 | "for i, sig in enumerate(signals):\n",
173 | " temp_signals = []\n",
174 | " temp_class = []\n",
175 | " for x in range(testing_samples):\n",
176 | " noise = ttype(generate_noise())\n",
177 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
178 | " while(any([(temp == c_).all() for c_ in temp_signals]) or any([(temp == c_).all() for c_ in training_signals])):\n",
179 | " noise = ttype(generate_noise())\n",
180 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
181 | " temp_signals.append(temp)\n",
182 | " temp_class.append(i)\n",
183 | " testing_signals.extend(temp_signals)\n",
184 | " testing_class.extend(temp_class)\n",
185 | "batch_rand = torch.randperm(testing_samples*signals.shape[0])\n",
186 | "\n",
187 | "testing_signals = torch.cat(testing_signals).cuda().unsqueeze(-1)[batch_rand]\n",
188 | "testing_class = ctype(testing_class).cuda().unsqueeze(-1)[batch_rand]\n",
189 | "\n",
190 | "\n",
191 | "dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)\n",
192 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)\n"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {
199 | "ExecuteTime": {
200 | "end_time": "2020-12-18T15:24:33.846618Z",
201 | "start_time": "2020-12-18T15:24:33.842381Z"
202 | }
203 | },
204 | "outputs": [],
205 | "source": [
206 | "class LSTM_Predictor(nn.Module):\n",
207 | " def __init__(self, out_features, lstm_params):\n",
208 | " super(LSTM_Predictor, self).__init__()\n",
209 | " self.lstm = nn.LSTM(**lstm_params)\n",
210 | " self.to_out = nn.Linear(lstm_params['hidden_size'], \n",
211 | " out_features)\n",
212 | " def forward(self, inp):\n",
213 | " x = self.lstm(inp)[0].transpose(1,0)\n",
214 | " x = torch.tanh(self.to_out(x))\n",
215 | " return x"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": null,
221 | "metadata": {
222 | "ExecuteTime": {
223 | "end_time": "2020-12-18T15:24:48.479338Z",
224 | "start_time": "2020-12-18T15:24:48.463241Z"
225 | }
226 | },
227 | "outputs": [],
228 | "source": [
229 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n",
230 | " permute=None, loss_buffer_size=64, batch_size=4, device='cuda',\n",
231 | " prog_bar=None, maxn=6):\n",
232 | " \n",
233 | " assert(loss_buffer_size%batch_size==0)\n",
234 | " \n",
235 | " losses = []\n",
236 | " perfs = []\n",
237 | " last_test_perf = 0\n",
238 | " best_test_perf = -1\n",
239 | " \n",
240 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
241 | " model.train()\n",
242 | " data = data.to(device).transpose(1,0)\n",
243 | " target = target.to(device)\n",
244 | " optimizer.zero_grad()\n",
245 | " out = model(data)\n",
246 | " loss = loss_func(out[:, -1, :],\n",
247 | " target[:, 0])\n",
248 | " \n",
249 | " loss.backward()\n",
250 | " optimizer.step()\n",
251 | "\n",
252 | " perfs.append((torch.argmax(out[:, -1, :], dim=-1) == \n",
253 | " target[:, 0]).sum().item())\n",
254 | " perfs = perfs[int(-loss_buffer_size/batch_size):]\n",
255 | " losses.append(loss.detach().cpu().numpy())\n",
256 | " losses = losses[int(-loss_buffer_size/batch_size):]\n",
257 | " if not (prog_bar is None):\n",
258 | " # Update progress_bar\n",
259 | " s = \"{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}\"\n",
260 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n",
261 | " np.sum(perfs)/((len(perfs))*batch_size), last_test_perf] \n",
262 | " s = s.format(*format_list)\n",
263 | " prog_bar.set_description(s)\n",
264 | " \n",
265 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n",
266 | " loss_track = {}\n",
267 | " last_test_perf = test(model, 'cuda', test_loader, \n",
268 | " batch_size=batch_size, \n",
269 | " permute=permute)\n",
270 | " loss_track['avg_loss'] = np.mean(losses)\n",
271 | " loss_track['last_test'] = last_test_perf\n",
272 | " loss_track['epoch'] = epoch\n",
273 | " loss_track['maxn'] = maxn\n",
274 | " loss_track['batch_idx'] = batch_idx\n",
275 | " loss_track['pres_num'] = batch_idx*batch_size + epoch*len(train_loader.dataset)\n",
276 | " loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)\n",
277 | " with open(perf_file, 'a+') as fp:\n",
278 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
279 | " if fp.tell() == 0:\n",
280 | " csv_writer.writeheader()\n",
281 | " csv_writer.writerow(loss_track)\n",
282 | " fp.flush()\n",
283 | " #if best_test_perf < last_test_perf:\n",
284 | " # torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
285 | " # best_test_perf = last_test_perf\n",
286 | "\n",
287 | " \n",
288 | "def test(model, device, test_loader, batch_size=4, permute=None):\n",
289 | " model.eval()\n",
290 | " correct = 0\n",
291 | " count = 0\n",
292 | " with torch.no_grad():\n",
293 | " for data, target in test_loader:\n",
294 | " data = data.to(device).transpose(1,0)\n",
295 | " target = target.to(device)\n",
296 | " \n",
297 | " out = model(data)\n",
298 | " pred = out[:, -1, :].argmax(dim=-1, keepdim=True)\n",
299 | " \n",
300 | " correct += pred.eq(target.view_as(pred)).sum().item()\n",
301 | " count += 1\n",
302 | " return correct / len(test_loader.dataset)"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": null,
308 | "metadata": {},
309 | "outputs": [],
310 | "source": []
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "metadata": {},
315 | "source": [
316 | "# Training and testing"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "metadata": {
323 | "ExecuteTime": {
324 | "end_time": "2020-12-11T18:58:50.330501Z",
325 | "start_time": "2020-12-11T18:58:50.327920Z"
326 | }
327 | },
328 | "outputs": [],
329 | "source": [
330 | "# You likely don't need this to be this long, but just in case.\n",
331 | "epochs = 1000\n",
332 | "\n",
333 | "# Just for visualizing average loss through time. \n",
334 | "loss_buffer_size = 100"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "metadata": {
341 | "ExecuteTime": {
342 | "end_time": "2020-12-18T16:14:26.814579Z",
343 | "start_time": "2020-12-18T16:13:00.732980Z"
344 | },
345 | "scrolled": true
346 | },
347 | "outputs": [],
348 | "source": [
349 | "test_noise_lengths = [6,7,9,13,21,37]\n",
350 | "for maxn in test_noise_lengths:\n",
351 | " torch.manual_seed(12345)\n",
352 | " np.random.seed(12345)\n",
353 | " training_samples = 32\n",
354 | "\n",
355 | " training_signals = []\n",
356 | " training_class = []\n",
357 | "\n",
358 | " for i, sig in enumerate(signals):\n",
359 | " temp_signals = []\n",
360 | " temp_class = []\n",
361 | " for x in range(training_samples):\n",
362 | " noise = ttype(generate_noise(maxn))\n",
363 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
364 | " while(any([(temp == c_).all() for c_ in temp_signals])):\n",
365 | " noise = ttype(generate_noise(maxn))\n",
366 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
367 | " temp_signals.append(temp)\n",
368 | " temp_class.append(i)\n",
369 | " training_signals.extend(temp_signals)\n",
370 | " training_class.extend(temp_class)\n",
371 | "\n",
372 | " batch_rand = torch.randperm(training_samples*signals.shape[0]) \n",
373 | " training_signals = torch.cat(training_signals).cuda().unsqueeze(-1)[batch_rand]\n",
374 | " training_class = ctype(training_class).cuda().unsqueeze(-1)[batch_rand]\n",
375 | "\n",
376 | " dataset = torch.utils.data.TensorDataset(training_signals, training_class)\n",
377 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
378 | " testing_samples = 10\n",
379 | " testing_signals = []\n",
380 | " testing_class = []\n",
381 | "\n",
382 | " for i, sig in enumerate(signals):\n",
383 | " temp_signals = []\n",
384 | " temp_class = []\n",
385 | " for x in range(testing_samples):\n",
386 | " noise = ttype(generate_noise(maxn))\n",
387 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
388 | " while(any([(temp == c_).all() for c_ in temp_signals]) or any([(temp == c_).all() for c_ in training_signals])):\n",
389 | " noise = ttype(generate_noise(maxn))\n",
390 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
391 | " temp_signals.append(temp)\n",
392 | " temp_class.append(i)\n",
393 | " testing_signals.extend(temp_signals)\n",
394 | " testing_class.extend(temp_class)\n",
395 | " batch_rand = torch.randperm(testing_samples*signals.shape[0])\n",
396 | "\n",
397 | " testing_signals = torch.cat(testing_signals).cuda().unsqueeze(-1)[batch_rand]\n",
398 | " testing_class = ctype(testing_class).cuda().unsqueeze(-1)[batch_rand]\n",
399 | "\n",
400 | "\n",
401 | " dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)\n",
402 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)\n",
403 | " lstm_params = dict(input_size=1,\n",
404 | " hidden_size=38, \n",
405 | " num_layers=3)\n",
406 | " model = LSTM_Predictor(8, lstm_params=lstm_params).cuda()\n",
407 | "\n",
408 | " tot_weights = 0\n",
409 | " for p in model.parameters():\n",
410 | " tot_weights += p.numel()\n",
411 | " print(\"Total Weights:\", tot_weights)\n",
412 | " print(model)\n",
413 | " loss_func = torch.nn.CrossEntropyLoss()\n",
414 | " optimizer = torch.optim.Adam(model.parameters())\n",
415 | " epochs = 400\n",
416 | " batch_size = 32\n",
417 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
418 | " for e in progress_bar:\n",
419 | " train(model, ttype, dataset, dataset_valid, \n",
420 | " optimizer, loss_func, batch_size=batch_size,\n",
421 | " epoch=e, perf_file=join('perf','h8_lstm_length_0.csv'),\n",
422 | " prog_bar=progress_bar, maxn=maxn)"
423 | ]
424 | }
425 | ],
426 | "metadata": {
427 | "kernelspec": {
428 | "display_name": "Python 3",
429 | "language": "python",
430 | "name": "python3"
431 | },
432 | "language_info": {
433 | "codemirror_mode": {
434 | "name": "ipython",
435 | "version": 3
436 | },
437 | "file_extension": ".py",
438 | "mimetype": "text/x-python",
439 | "name": "python",
440 | "nbconvert_exporter": "python",
441 | "pygments_lexer": "ipython3",
442 | "version": "3.6.10"
443 | },
444 | "toc": {
445 | "nav_menu": {},
446 | "number_sections": true,
447 | "sideBar": true,
448 | "skip_h1_title": false,
449 | "title_cell": "Table of Contents",
450 | "title_sidebar": "Contents",
451 | "toc_cell": false,
452 | "toc_position": {},
453 | "toc_section_display": true,
454 | "toc_window_display": false
455 | }
456 | },
457 | "nbformat": 4,
458 | "nbformat_minor": 4
459 | }
460 |
--------------------------------------------------------------------------------
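
A note on tensor shapes in the LSTM notebook above: `nn.LSTM` is left at its default `batch_first=False`, so `train` transposes each (batch, time, feature) batch to (time, batch, feature) before the forward pass, and `LSTM_Predictor.forward` transposes back so the readout and loss can index `out[:, -1, :]` as (batch, time, classes). A quick shape check under those defaults (class reproduced from the notebook for self-containment; the 53-step length is illustrative, standing in for a 17-step signal plus variable-length noise):

import torch
from torch import nn

class LSTM_Predictor(nn.Module):
    def __init__(self, out_features, lstm_params):
        super().__init__()
        self.lstm = nn.LSTM(**lstm_params)     # batch_first=False by default
        self.to_out = nn.Linear(lstm_params['hidden_size'], out_features)

    def forward(self, inp):                    # inp: (seq_len, batch, input_size)
        x = self.lstm(inp)[0].transpose(1, 0)  # -> (batch, seq_len, hidden_size)
        return torch.tanh(self.to_out(x))      # -> (batch, seq_len, out_features)

model = LSTM_Predictor(8, dict(input_size=1, hidden_size=38, num_layers=3))
batch = torch.zeros(53, 32, 1)                 # (seq_len, batch, features)
print(model(batch).shape)                      # torch.Size([32, 53, 8])
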
/experiments/Hateful8/hateful8-LMU.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2021-05-12T18:04:53.607992Z",
9 | "start_time": "2021-05-12T18:04:53.437658Z"
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "%matplotlib inline\n"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2021-05-12T18:04:54.763726Z",
23 | "start_time": "2021-05-12T18:04:53.682215Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pyplot as plt\n",
29 | "import torch\n",
30 | "from torch import nn as nn\n",
31 | "from math import factorial\n",
32 | "import random\n",
33 | "import torch.nn.functional as F\n",
34 | "import numpy as np\n",
35 | "import seaborn as sn\n",
36 | "import pandas as pd\n",
37 | "import os \n",
38 | "from os.path import join\n",
39 | "import glob\n",
40 | "from math import factorial\n",
41 | "ttype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
42 | "ctype = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor\n",
43 | "print(ttype)\n",
44 | "from torch.nn.utils import weight_norm\n",
45 | "from lmu import LegendreMemoryUnit\n",
46 | "\n",
47 | "from tqdm.notebook import tqdm\n",
48 | "import pickle\n",
49 | "sn.set_context(\"poster\")\n",
50 | "import itertools\n",
51 | "from csv import DictWriter\n",
52 | "import matplotlib.pylab as plt\n",
53 | "import csv\n",
54 | "import numpy as np\n",
55 | "import pandas as pd\n",
56 | "import os\n",
57 | "import seaborn as sn\n",
58 | "sn.set_context('talk')"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {
65 | "ExecuteTime": {
66 | "end_time": "2021-05-12T18:04:54.848136Z",
67 | "start_time": "2021-05-12T18:04:54.764814Z"
68 | }
69 | },
70 | "outputs": [],
71 | "source": [
72 | "def generate_noise(maxn=18):\n",
73 | " \"\"\"Generates dot and dash based noise.\"\"\"\n",
74 | " \n",
75 | " threes = np.random.randint(int(.5*maxn), int(.75*maxn))\n",
76 | " ones = (maxn - threes) * 2\n",
77 | " noise = list(itertools.repeat([1,1,1,0], threes))\n",
78 | " noise[:int(len(noise)/3)] = list(itertools.repeat([0,0], int(len(noise)/3)))\n",
79 | " ones = ones + int(len(noise)/3)\n",
80 | " noise.extend(list(itertools.repeat([1,0], ones)))\n",
81 | " random.shuffle(noise)\n",
82 | " noise = np.concatenate(noise)\n",
83 | " return noise\n",
84 | "noise = generate_noise()\n",
85 | "print(noise.shape)\n",
86 | "plt.plot(noise)\n"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {
93 | "ExecuteTime": {
94 | "end_time": "2021-05-12T18:04:55.827538Z",
95 | "start_time": "2021-05-12T18:04:54.849239Z"
96 | }
97 | },
98 | "outputs": [],
99 | "source": [
100 | "sig_lets = [\"A\",\"B\",\"C\",\"D\",\"E\",\"F\",\"G\",\"H\",]\n",
101 | "\n",
102 | "signals = ttype([[0,1,1,1,0,1,1,1,0,1,0,1,0,1,0,0,0],\n",
103 | " [0,1,1,1,0,1,0,1,1,1,0,1,0,1,0,0,0],\n",
104 | " [0,1,1,1,0,1,0,1,0,1,1,1,0,1,0,0,0],\n",
105 | " [0,1,1,1,0,1,0,1,0,1,0,1,1,1,0,0,0],\n",
106 | " \n",
107 | " [0,1,0,1,1,1,0,1,1,1,0,1,1,1,0,0,0],\n",
108 | " [0,1,1,1,0,1,0,1,1,1,0,1,1,1,0,0,0],\n",
109 | " [0,1,1,1,0,1,1,1,0,1,0,1,1,1,0,0,0],\n",
110 | " [0,1,1,1,0,1,1,1,0,1,1,1,0,1,0,0,0],\n",
111 | "\n",
112 | " ]\n",
113 | " ).view(8, 1, 1, -1)\n",
114 | "\n",
115 | "plt.imshow(signals[:,0,0,:].detach().cpu())\n"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {
122 | "ExecuteTime": {
123 | "end_time": "2021-05-12T18:04:55.987629Z",
124 | "start_time": "2021-05-12T18:04:55.828588Z"
125 | }
126 | },
127 | "outputs": [],
128 | "source": [
129 | "torch.manual_seed(12345)\n",
130 | "np.random.seed(12345)\n",
131 | "training_samples = 32\n",
132 | "\n",
133 | "training_signals = []\n",
134 | "training_class = []\n",
135 | "\n",
136 | "for i, sig in enumerate(signals):\n",
137 | " temp_signals = []\n",
138 | " temp_class = []\n",
139 | " for x in range(training_samples):\n",
140 | " noise = ttype(generate_noise())\n",
141 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
142 | " while(any([(temp == c_).all() for c_ in temp_signals])):\n",
143 | " noise = ttype(generate_noise())\n",
144 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
145 | " temp_signals.append(temp)\n",
146 | " temp_class.append(i)\n",
147 | " training_signals.extend(temp_signals)\n",
148 | " training_class.extend(temp_class)\n",
149 | "\n",
150 | "batch_rand = torch.randperm(training_samples*signals.shape[0]) \n",
151 | "training_signals = torch.cat(training_signals).cuda().unsqueeze(-1)[batch_rand]\n",
152 | "training_class = ctype(training_class).cuda().unsqueeze(-1)[batch_rand]\n",
153 | "\n",
154 | "dataset = torch.utils.data.TensorDataset(training_signals, training_class)\n",
155 | "dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
156 | "\n"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {
163 | "ExecuteTime": {
164 | "end_time": "2021-05-12T18:04:56.934134Z",
165 | "start_time": "2021-05-12T18:04:56.388410Z"
166 | }
167 | },
168 | "outputs": [],
169 | "source": [
170 | "testing_samples = 10\n",
171 | "testing_signals = []\n",
172 | "testing_class = []\n",
173 | "\n",
174 | "for i, sig in enumerate(signals):\n",
175 | " temp_signals = []\n",
176 | " temp_class = []\n",
177 | " for x in range(testing_samples):\n",
178 | " noise = ttype(generate_noise())\n",
179 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
180 | " while(any([(temp == c_).all() for c_ in temp_signals]) or any([(temp == c_).all() for c_ in training_signals])):\n",
181 | " noise = ttype(generate_noise())\n",
182 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
183 | " temp_signals.append(temp)\n",
184 | " temp_class.append(i)\n",
185 | " testing_signals.extend(temp_signals)\n",
186 | " testing_class.extend(temp_class)\n",
187 | "batch_rand = torch.randperm(testing_samples*signals.shape[0])\n",
188 | "\n",
189 | "testing_signals = torch.cat(testing_signals).cuda().unsqueeze(-1)[batch_rand]\n",
190 | "testing_class = ctype(testing_class).cuda().unsqueeze(-1)[batch_rand]\n",
191 | "\n",
192 | "\n",
193 | "dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)\n",
194 | "dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)\n"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {
201 | "ExecuteTime": {
202 | "end_time": "2021-05-12T18:04:57.496406Z",
203 | "start_time": "2021-05-12T18:04:57.489480Z"
204 | }
205 | },
206 | "outputs": [],
207 | "source": [
208 | "class LMUModel(nn.Module):\n",
209 | " def __init__(self, n_out, layer_params):\n",
210 | " super(LMUModel, self).__init__()\n",
211 | " self.layers = nn.ModuleList([LegendreMemoryUnit(**layer_params[i])\n",
212 | " for i in range(len(layer_params))])\n",
213 | " self.dense = nn.Linear(layer_params[-1]['hidden_size'], n_out)\n",
214 | "\n",
215 | " \n",
216 | " def forward(self, x):\n",
217 | " for l in self.layers:\n",
218 | " x, _ = l(x) \n",
219 | " x = self.dense(x)\n",
220 | " return x"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {
227 | "ExecuteTime": {
228 | "end_time": "2021-05-12T18:04:57.965310Z",
229 | "start_time": "2021-05-12T18:04:57.950810Z"
230 | }
231 | },
232 | "outputs": [],
233 | "source": [
234 | "def train(model, ttype, train_loader, test_loader, optimizer, loss_func, epoch, perf_file,\n",
235 | " permute=None, loss_buffer_size=64, batch_size=4, device='cuda',\n",
236 | " prog_bar=None, maxn=6):\n",
237 | " \n",
238 | " assert(loss_buffer_size%batch_size==0)\n",
239 | " \n",
240 | " losses = []\n",
241 | " perfs = []\n",
242 | " last_test_perf = 0\n",
243 | " best_test_perf = -1\n",
244 | " \n",
245 | " for batch_idx, (data, target) in enumerate(train_loader):\n",
246 | " model.train()\n",
247 | " data = data.to(device)\n",
248 | " target = target.to(device)\n",
249 | " optimizer.zero_grad()\n",
250 | " out = model(data)\n",
251 | " loss = loss_func(out[:,-1],\n",
252 | " target[:, 0])\n",
253 | " \n",
254 | " loss.backward()\n",
255 | " optimizer.step()\n",
256 | "\n",
257 | " perfs.append((torch.argmax(out[:,-1], dim=-1) == \n",
258 | " target[:, 0]).sum().item())\n",
259 | " perfs = perfs[int(-loss_buffer_size/batch_size):]\n",
260 | " losses.append(loss.detach().cpu().numpy())\n",
261 | " losses = losses[int(-loss_buffer_size/batch_size):]\n",
262 | " if not (prog_bar is None):\n",
263 | " # Update progress_bar\n",
264 | " s = \"{}:{} Loss: {:.4f}, perf: {:.4f}, valid: {:.4f}\"\n",
265 | " format_list = [e,batch_idx*batch_size, np.mean(losses), \n",
266 | " np.sum(perfs)/((len(perfs))*batch_size), last_test_perf] \n",
267 | " s = s.format(*format_list)\n",
268 | " prog_bar.set_description(s)\n",
269 | " \n",
270 | " if ((batch_idx*batch_size)%loss_buffer_size == 0) & (batch_idx != 0):\n",
271 | " loss_track = {}\n",
272 | " last_test_perf = test(model, 'cuda', test_loader, \n",
273 | " batch_size=batch_size, \n",
274 | " permute=permute)\n",
275 | " loss_track['avg_loss'] = np.mean(losses)\n",
276 | " loss_track['last_test'] = last_test_perf\n",
277 | " loss_track['epoch'] = epoch\n",
278 | " loss_track['maxn'] = maxn\n",
279 | " loss_track['batch_idx'] = batch_idx\n",
280 | " loss_track['pres_num'] = batch_idx*batch_size + epoch*len(train_loader.dataset)\n",
281 | " loss_track['train_perf']= np.sum(perfs)/((len(perfs))*batch_size)\n",
282 | " with open(perf_file, 'a+') as fp:\n",
283 | " csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))\n",
284 | " if fp.tell() == 0:\n",
285 | " csv_writer.writeheader()\n",
286 | " csv_writer.writerow(loss_track)\n",
287 | " fp.flush()\n",
288 | " if best_test_perf < last_test_perf:\n",
289 | " torch.save(model.state_dict(), perf_file[:-4]+\".pt\")\n",
290 | " best_test_perf = last_test_perf\n",
291 | "\n",
292 | " \n",
293 | "def test(model, device, test_loader, batch_size=4, permute=None):\n",
294 | " model.eval()\n",
295 | " correct = 0\n",
296 | " count = 0\n",
297 | " with torch.no_grad():\n",
298 | " for data, target in test_loader:\n",
299 | " data = data.to(device)\n",
300 | " target = target.to(device)\n",
301 | " \n",
302 | " out = model(data)\n",
303 | " pred = out[:,-1].argmax(dim=-1, keepdim=True)\n",
304 | " correct += pred.eq(target.view_as(pred)).sum().item()\n",
305 | " count += 1\n",
306 | " return correct / len(test_loader.dataset)"
307 | ]
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "metadata": {},
312 | "source": [
313 | "# Training and testing"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "metadata": {
320 | "ExecuteTime": {
321 | "end_time": "2021-05-12T18:05:02.299619Z",
322 | "start_time": "2021-05-12T18:05:02.293735Z"
323 | }
324 | },
325 | "outputs": [],
326 | "source": [
327 | "# You likely don't need this to be this long, but just in case.\n",
328 | "epochs = 400\n",
329 | "\n",
330 | "# Just for visualizing average loss through time. \n",
331 | "loss_buffer_size = 100"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": null,
337 | "metadata": {
338 | "ExecuteTime": {
339 | "end_time": "2021-05-12T18:05:07.312643Z",
340 | "start_time": "2021-05-12T18:05:03.479007Z"
341 | },
342 | "scrolled": true
343 | },
344 | "outputs": [],
345 | "source": [
346 | "test_noise_lengths = [6,7,9,13,21,37]\n",
347 | "for maxn in test_noise_lengths:\n",
348 | " torch.manual_seed(12345)\n",
349 | " np.random.seed(12345)\n",
350 | " training_samples = 32\n",
351 | "\n",
352 | " training_signals = []\n",
353 | " training_class = []\n",
354 | "\n",
355 | " for i, sig in enumerate(signals):\n",
356 | " temp_signals = []\n",
357 | " temp_class = []\n",
358 | " for x in range(training_samples):\n",
359 | " noise = ttype(generate_noise(maxn))\n",
360 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
361 | " while(any([(temp == c_).all() for c_ in temp_signals])):\n",
362 | " noise = ttype(generate_noise(maxn))\n",
363 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
364 | " temp_signals.append(temp)\n",
365 | " temp_class.append(i)\n",
366 | " training_signals.extend(temp_signals)\n",
367 | " training_class.extend(temp_class)\n",
368 | "\n",
369 | " batch_rand = torch.randperm(training_samples*signals.shape[0]) \n",
370 | " training_signals = torch.cat(training_signals).cuda().unsqueeze(-1)[batch_rand]\n",
371 | " training_class = ctype(training_class).cuda().unsqueeze(-1)[batch_rand]\n",
372 | "\n",
373 | " dataset = torch.utils.data.TensorDataset(training_signals, training_class)\n",
374 | " dataset = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)\n",
375 | " testing_samples = 10\n",
376 | " testing_signals = []\n",
377 | " testing_class = []\n",
378 | "\n",
379 | " for i, sig in enumerate(signals):\n",
380 | " temp_signals = []\n",
381 | " temp_class = []\n",
382 | " for x in range(testing_samples):\n",
383 | " noise = ttype(generate_noise(maxn))\n",
384 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
385 | " while(any([(temp == c_).all() for c_ in temp_signals]) or any([(temp == c_).all() for c_ in training_signals])):\n",
386 | " noise = ttype(generate_noise(maxn))\n",
387 | " temp = torch.cat([sig[0,0], noise]).unsqueeze(0)\n",
388 | " temp_signals.append(temp)\n",
389 | " temp_class.append(i)\n",
390 | " testing_signals.extend(temp_signals)\n",
391 | " testing_class.extend(temp_class)\n",
392 | " batch_rand = torch.randperm(testing_samples*signals.shape[0])\n",
393 | "\n",
394 | " testing_signals = torch.cat(testing_signals).cuda().unsqueeze(-1)[batch_rand]\n",
395 | " testing_class = ctype(testing_class).cuda().unsqueeze(-1)[batch_rand]\n",
396 | "\n",
397 | "\n",
398 | " dataset_valid = torch.utils.data.TensorDataset(testing_signals, testing_class)\n",
399 | " dataset_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=False)\n",
400 | "\n",
401 | " hz=125\n",
402 | "\n",
403 | " lmu_params = [dict(input_dim=1, hidden_size=hz, order=40, theta=temp.shape[-1]),\n",
404 | " #dict(input_dim=hz, hidden_size=hz, order=4, theta=4),\n",
405 | " #dict(input_dim=hz, hidden_size=hz, order=4, theta=4),\n",
406 | " ]\n",
407 | " model = LMUModel(8, lmu_params).cuda()\n",
408 | "\n",
409 | " tot_weights = 0\n",
410 | " for p in model.parameters():\n",
411 | " tot_weights += p.numel()\n",
412 | " print(\"Total Weights:\", tot_weights)\n",
413 | " print(model)\n",
414 | " loss_func = torch.nn.CrossEntropyLoss()\n",
415 | " optimizer = torch.optim.Adam(model.parameters())\n",
416 | " epochs = 400\n",
417 | " batch_size = 32\n",
418 | " progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')\n",
419 | " for e in progress_bar:\n",
420 | " train(model, ttype, dataset, dataset_valid, \n",
421 | " optimizer, loss_func, batch_size=batch_size,\n",
422 | " epoch=e, perf_file=join('perf','h8_LMU_length_6.csv'),\n",
423 | " prog_bar=progress_bar, maxn=maxn)\n",
424 | " "
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": null,
430 | "metadata": {},
431 | "outputs": [],
432 | "source": []
433 | }
434 | ],
435 | "metadata": {
436 | "kernelspec": {
437 | "display_name": "Python 3",
438 | "language": "python",
439 | "name": "python3"
440 | },
441 | "language_info": {
442 | "codemirror_mode": {
443 | "name": "ipython",
444 | "version": 3
445 | },
446 | "file_extension": ".py",
447 | "mimetype": "text/x-python",
448 | "name": "python",
449 | "nbconvert_exporter": "python",
450 | "pygments_lexer": "ipython3",
451 | "version": "3.6.10"
452 | },
453 | "toc": {
454 | "nav_menu": {},
455 | "number_sections": true,
456 | "sideBar": true,
457 | "skip_h1_title": false,
458 | "title_cell": "Table of Contents",
459 | "title_sidebar": "Contents",
460 | "toc_cell": false,
461 | "toc_position": {},
462 | "toc_section_display": true,
463 | "toc_window_display": false
464 | }
465 | },
466 | "nbformat": 4,
467 | "nbformat_minor": 4
468 | }
469 |
--------------------------------------------------------------------------------
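
One detail worth flagging in the LMU configuration above: `theta=temp.shape[-1]` reuses the last sequence built in the data-generation loop, so the Legendre memory window spans the full signal-plus-noise length for the current `maxn`. A hedged usage sketch follows, assuming the local `lmu.LegendreMemoryUnit` takes `(input_dim, hidden_size, order, theta)`, consumes batch-first input, and returns an `(output, state)` pair, as the notebook's `LMUModel` implies; the sequence length is illustrative:

import torch
from torch import nn
from lmu import LegendreMemoryUnit  # local module in experiments/Hateful8

class LMUModel(nn.Module):
    """Stack of LMU layers with a linear readout, as in the notebook."""
    def __init__(self, n_out, layer_params):
        super().__init__()
        self.layers = nn.ModuleList([LegendreMemoryUnit(**p) for p in layer_params])
        self.dense = nn.Linear(layer_params[-1]['hidden_size'], n_out)

    def forward(self, x):                      # x: (batch, seq_len, input_dim)
        for layer in self.layers:
            x, _ = layer(x)                    # keep outputs, drop the state
        return self.dense(x)                   # -> (batch, seq_len, n_out)

seq_len = 53                                   # 17-step signal + noise (illustrative)
model = LMUModel(8, [dict(input_dim=1, hidden_size=125, order=40, theta=seq_len)])
x = torch.zeros(4, seq_len, 1)                 # one small batch
logits = model(x)[:, -1]                       # decision at the last time step
print(logits.shape)                            # torch.Size([4, 8])
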