├── Bipartite_small_world_network.pdf
├── LICENSE
├── README.md
├── setup.py
├── sparselinear
│   ├── __init__.py
│   ├── activationsparsity.py
│   └── sparselinear.py
└── tutorials
    └── SparseLinearDemo.ipynb
/Bipartite_small_world_network.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyeon95y/SparseLinear/35b4a9cf843ecc9ce560b37615f369fdc6aad4cf/Bipartite_small_world_network.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Rain Neuromorphics Inc.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SparseLinear
2 |
3 | SparseLinear is a PyTorch package that allows a user to create extremely wide and sparse linear layers efficiently. A sparsely connected network is a network where each node is connected to a fraction of available nodes. This differs from a fully connected network, where each node in one layer is connected to every node in the next layer.
4 |
5 | The provided layer, along with the dynamic activation sparsity module, is compatible with backpropagation. The sparse linear layer is initialized with a target sparsity, supports unstructured sparsity, and allows dynamic growth and pruning. We achieve this by building a linear layer on top of [PyTorch Sparse](https://github.com/rusty1s/pytorch_sparse), which provides optimized sparse matrix operations with autograd support in PyTorch.
6 |
7 | ## Table of Contents
8 |
9 | - [More about SparseLinear](#intro)
10 | - [More about Dynamic Activation](#kwin)
11 | - [Installation](#install)
12 | - [Getting Started](#usage)
13 |
14 | ## More about SparseLinear
15 | The default arguments initialize a sparse linear layer with random connections; the layer applies a linear transformation to the incoming data.
16 |
17 | #### Parameters
18 |
19 | - **in_features** - size of each input sample
20 | - **out_features** - size of each output sample
21 | - **bias** - If set to ``False``, the layer will not learn an additive bias. Default: ``True``
22 | - **sparsity** - sparsity of weight matrix. Default: `0.9`
23 | - **connectivity** - user-defined sparsity matrix. Default: `None`
24 | - **small_world** - boolean flag to generate small-world sparsity. Default: ``False``
25 | - **dynamic** - boolean flag to dynamically change the network structure. Default: ``False``
26 | - **deltaT** - frequency for growing and pruning update step. Default: `6000`
27 | - **Tend** - stopping time for growing and pruning algorithm update step. Default: `150000`
28 | - **alpha** - f-decay parameter for cosine updates. Default: `0.1`
29 | - **max_size** - maximum number of entries allowed before chunking occurs for small-world network generation and dynamic connections. Default: `1e8`
30 |
31 | #### Shape
32 |
33 | - Input: `(N, *, H_{in})` where `*` means any number of additional dimensions and `H_{in} = in_features`
34 | - Output: `(N, *, H_{out})` where all but the last dimension are the same shape as the input and `H_{out} = out_features`
35 |
36 | #### Variables
37 |
38 | - **~SparseLinear.weight** - the learnable weights of the module of shape `(out_features, in_features)`. The values are initialized from `U(-sqrt(k), sqrt(k))`, where `k = 1/in_features`
39 | - **~SparseLinear.bias** - the learnable bias of the module of shape `(out_features)`. If `bias` is ``True``, the values are initialized from `U(-sqrt(k), sqrt(k))` where `k = 1/in_features`
40 | 
41 | #### Examples:
42 |
43 | ```python
44 | >>> m = sl.SparseLinear(20, 30)
45 | >>> input = torch.randn(128, 20)
46 | >>> output = m(input)
47 | >>> print(output.size())
48 | torch.Size([128, 30])
49 | ```
50 |
51 | The following customizations can also be made using the appropriate arguments:
52 |
53 | #### User-defined Sparsity
54 |
55 | One can choose to add self-defined static sparsity. The `connectivity` flag accepts a (2, nnz) LongTensor that represents the rows and columns of nonzero elements in the layer.
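
For example, here is a minimal sketch (layer sizes and connection pattern chosen purely for illustration) that connects each output neuron to one randomly chosen input neuron:

```python
import torch
import sparselinear as sl

in_features, out_features = 20, 30

# First row: output (row) indices; second row: input (column) indices -> shape (2, nnz)
rows = torch.arange(out_features)                      # one connection per output neuron
cols = torch.randint(0, in_features, (out_features,))  # each to a random input neuron
connections = torch.stack((rows, cols))

layer = sl.SparseLinear(in_features, out_features, connectivity=connections)
output = layer(torch.randn(128, in_features))
```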
56 |
57 | #### Small-world Sparsity
58 |
59 | The default static sparsity is random. Set `small_world` to `True` to instead use small-world sparsity (see [here](https://en.wikipedia.org/wiki/Small-world_network)), which makes connections distance-dependent to ensure small-world behavior.
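
A minimal sketch (sizes are illustrative; the layer must be dense enough that the number of non-zero entries exceeds `min(in_features, out_features)`):

```python
>>> m = sl.SparseLinear(1024, 1024, sparsity=0.9, small_world=True)
>>> output = m(torch.randn(128, 1024))
```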
60 |
61 | #### Dynamic Growing and Pruning Algorithm
62 |
63 | This feature lets the user grow and prune connections during training, starting from a sparse configuration. The implementation is based on the [Rigging the Lottery](https://arxiv.org/pdf/1911.11134.pdf) algorithm. Set `dynamic` to `True` to dynamically alter the layer connections while training.
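
A minimal sketch with illustrative (non-default) settings, where `deltaT` sets how often connections are pruned and regrown and `Tend` sets when those updates stop:

```python
>>> m = sl.SparseLinear(784, 2000, dynamic=True, deltaT=100, Tend=5000)
>>> # connections are pruned and regrown every 100 training steps until step 5000
```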
64 |
65 | ## Dynamic Activation Sparsity
66 |
67 | In addition, we provide a Dynamic Activation Sparsity module to utilize principled, per-layer activation sparsity. The algorithm implementation is based on the [K-Winners strategy](https://arxiv.org/pdf/1903.11257.pdf).
68 |
69 | #### Parameters
70 |
71 | - **alpha** - constant used in updating duty-cycle. Default: `0.1`
72 | - **beta** - boosting factor for neurons not activated in the previous duty cycle. Default: `1.5`
73 | - **act_sparsity** - fraction of the input used in calculating K for K-Winners strategy. Default: `0.65`
74 |
75 | #### Shape
76 |
77 | - Input: `(N, *)` where `*` means any number of additional dimensions
78 | - Output: `(N, *)`, same shape as the input
79 |
80 | #### Examples:
81 |
82 | ```python
83 | >>> x = asy.ActivationSparsity()
84 | >>> input = torch.randn(3,10)
85 | >>> output = x(input)
86 | ```
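
The module can also be placed after a linear or sparse linear layer inside a model. A minimal sketch (layer sizes are illustrative; note that the package `__init__` also exposes `ActivationSparsity`, so `sl.ActivationSparsity` works as well):

```python
>>> model = torch.nn.Sequential(
...     sl.SparseLinear(784, 2000),
...     sl.ActivationSparsity(act_sparsity=0.65),
...     sl.SparseLinear(2000, 10),
... )
```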
87 |
88 | ## Installation
89 |
90 | - Follow the installation instructions to install the PyTorch Sparse package from [here](https://github.com/rusty1s/pytorch_sparse).
91 | - Then run ```pip install sparselinear```
92 |
93 | ## Getting Started
94 |
95 | We provide a Jupyter notebook in [this](https://github.com/rain-neuromorphics/SparseLinear/blob/master/tutorials/SparseLinearDemo.ipynb) repository that demonstrates the basic functionalities of the sparse linear layer. We also show steps to train various models using the additional features of this package.
96 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup, find_packages
4 |
5 | with open("README.md", "r") as fh:
6 | long_description = fh.read()
7 |
8 | install_requires = ['numpy', 'torch']
9 | setup(name='sparselinear',
10 | version='0.0.5',
11 | description='Pytorch extension library for creating sparse linear layers',
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | author='Rain Neuromorphics',
15 | author_email='ross@rain-neuromorphics.com',
16 | url='https://github.com/rain-neuromorphics/SparseLinear',
17 | keywords=['pytorch', 'sparse', 'linear'],
18 | license='MIT',
19 | install_requires=install_requires,
20 | packages=find_packages(),
21 | python_requires='>=3.6',
22 | )
23 |
--------------------------------------------------------------------------------
/sparselinear/__init__.py:
--------------------------------------------------------------------------------
1 | from .sparselinear import SparseLinear
2 | from .activationsparsity import ActivationSparsity
3 |
4 |
--------------------------------------------------------------------------------
/sparselinear/activationsparsity.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | class ActivationSparsity(nn.Module):
6 | """Applies activation sparsity to the last dimension of input using K-winners strategy
7 |
8 | Args:
9 | alpha (float): constant used in updating duty-cycle
10 | Default: 0.1
11 | beta (float): boosting factor for neurons not activated in the previous duty cycle
12 | Default: 1.5
13 | act_sparsity (float): fraction of the input used in calculating K for K-Winners strategy
14 | Default: 0.65
15 |
16 | Shape:
17 |         - Input: :math:`(N, *)` where `*` means any number of additional dimensions
18 | - Output: :math:`(N, *)`, same shape as the input
19 |
20 | Examples::
21 |
22 |         >>> x = asy.ActivationSparsity()
23 | >>> input = torch.randn(3,10)
24 | >>> output = x(input)
25 | """
26 | def __init__(self, alpha=0.1, beta=1.5, act_sparsity=0.65):
27 | super(ActivationSparsity, self).__init__()
28 | self.alpha = alpha
29 | self.beta = beta
30 | self.act_sparsity = act_sparsity
31 | self.duty_cycle = None
32 |
33 | def updateDC(self, inputs, duty_cycle):
34 | duty_cycle = (1 - self.alpha) * duty_cycle + self.alpha * (inputs.gt(0).sum(dim=0,dtype=torch.float))
35 | return duty_cycle
36 |
37 | def forward(self, inputs):
38 | in_features = inputs.shape[-1]
39 | out_shape=list(inputs.shape)
40 | inputs = inputs.reshape(inputs.shape[0],-1)
41 |
42 | device = inputs.device
43 |
44 | if self.duty_cycle is None:
45 | self.duty_cycle = torch.zeros(in_features, requires_grad=True).to(device)
46 |
47 | k = math.floor((1-self.act_sparsity) * in_features)
48 | with torch.no_grad():
49 |
50 | target = k / inputs.shape[-1]
51 | boost_coefficient = torch.exp(self.beta * (target - self.duty_cycle))
52 | boosted_input = inputs * boost_coefficient
53 |
54 | # Get top k values
55 | values, indices = boosted_input.topk( k, dim=-1, sorted=False)
56 | row_indices = torch.arange(inputs.shape[0]).repeat_interleave(k).view(-1,k)
57 |
58 | outputs = torch.zeros_like(inputs).to(device)
59 | outputs = outputs.index_put((row_indices, indices), inputs[row_indices, indices], accumulate=False)
60 |
61 | if self.training:
62 | with torch.no_grad():
63 | self.duty_cycle = self.updateDC(outputs, self.duty_cycle)
64 |
65 | return outputs.view(out_shape)
66 |
67 | def extra_repr(self):
68 | return 'act_sparsity={}, alpha={}, beta={}, duty_cycle={}'.format(
69 | self.act_sparsity, self.alpha, self.beta, self.duty_cycle
70 | )
--------------------------------------------------------------------------------
/sparselinear/sparselinear.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | def small_world_chunker(inputs, outputs, nnz):
10 |     """Utility function for small-world initialization as presented in the Bipartite_small_world_network write-up"""
11 | pair_distance = inputs.view(-1, 1) - outputs
12 | arg = torch.abs(pair_distance) + 1.0
13 |
14 |     # Bisection search for lambda so that the total connection probability P.sum() approximately matches nnz
15 | L, U = 1e-5, 5.0
16 | lamb = 1.0 # initial guess
17 | itr = 1
18 | error_threshold = 10.0
19 | max_itr = 1000
20 | P = arg ** (-lamb)
21 | P_sum = P.sum()
22 | error = abs(P_sum - nnz)
23 |
24 | while error > error_threshold:
25 | assert (
26 | itr <= max_itr
27 | ), "No solution found; please try different network sizes and sparsity levels"
28 | if P_sum < nnz:
29 | U = lamb
30 | lamb = (lamb + L) / 2.0
31 | elif P_sum > nnz:
32 | L = lamb
33 | lamb = (lamb + U) / 2.0
34 |
35 | P = arg ** (-lamb)
36 | P_sum = P.sum()
37 | error = abs(P_sum - nnz)
38 | itr += 1
39 | return P
40 |
41 |
42 | class GrowConnections(torch.autograd.Function):
43 |     """Custom PyTorch autograd function to handle growing connections"""
44 |
45 | @staticmethod
46 | def forward(ctx, inputs, weights, k, indices, features, max_size):
47 | out_features, in_features = features
48 | output_shape = list(inputs.shape)
49 | output_shape[-1] = out_features
50 | if len(output_shape) == 1:
51 | inputs = inputs.view(1, -1)
52 | inputs = inputs.flatten(end_dim=-2)
53 |
54 | # output = torch_sparse.spmm(indices, weights, out_features, in_features, inputs.t()).t()
55 | target = torch.sparse.FloatTensor(
56 | indices, weights, torch.Size([out_features, in_features]),
57 | ).to_dense()
58 | output = torch.mm(target, inputs.t()).t()
59 |
60 | ctx.save_for_backward(inputs, weights, indices)
61 | ctx.in1 = k
62 | ctx.in2 = out_features
63 | ctx.in3 = in_features
64 | ctx.in4 = max_size
65 |
66 | return output
67 |
68 | @staticmethod
69 | def backward(ctx, grad_output):
70 | inputs, weights, indices = ctx.saved_tensors
71 | k = ctx.in1
72 | out_features = ctx.in2
73 | in_features = ctx.in3
74 | max_size = ctx.in4
75 |
76 | device = grad_output.device
77 | p_index = torch.LongTensor([1, 0])
78 | new_indices = torch.zeros_like(indices).to(device=device)
79 | new_indices[p_index] = indices
80 |
81 | # grad_input = torch_sparse.spmm(new_indices, weights, in_features, out_features, grad_output.t()).t()
82 | target = torch.sparse.FloatTensor(
83 | new_indices, weights, torch.Size([in_features, out_features]),
84 | ).to_dense()
85 |         grad_input = torch.mm(target, grad_output.t()).t()
86 |
87 | if in_features * out_features <= max_size:
88 | grad_weights = torch.matmul(inputs.t(), grad_output)
89 | grad_weights = torch.abs(grad_weights.t())
90 | mask = torch.ones_like(grad_weights)
91 | mask[indices[0], indices[1]] = 0
92 |
93 | masked_weights = mask * grad_weights
94 | _, lm_indices = torch.topk(masked_weights.reshape(-1), k, sorted=False)
95 | row = lm_indices.floor_divide(in_features)
96 | col = lm_indices.fmod(in_features)
97 | else:
98 | tk = None
99 | m = max_size / in_features
100 | chunks = math.ceil(out_features / m)
101 |
102 | for item in range(chunks):
103 | if item != chunks - 1:
104 | sliced_input = inputs.t()[item * m : (item + 1) * m, :]
105 | grad_m = torch.matmul(sliced_input, grad_output).t()
106 | grad_m_abs = torch.abs(grad_m)
107 | topk_values, topk_indices = torch.topk(
108 | grad_m_abs.view(-1), k, sorted=False,
109 | )
110 | else:
111 | grad_m = torch.matmul(inputs.t()[item * m :, :], grad_output).t()
112 | grad_m_abs = torch.abs(grad_m)
113 | topk_values, topk_indices = torch.topk(
114 | grad_m_abs.view(-1), k, sorted=False,
115 | )
116 |
117 | row = (
118 | topk_indices.floor_divide(in_features)
119 | + torch.ones_like(topk_indices) * item * m
120 | )
121 | col = topk_indices.fmod(in_features)
122 | indices = torch.stack((row, col))
123 |
124 | if tk is None:
125 | tk = torch.cat((topk_values, indices), dim=0)
126 | else:
127 | topk_values_prev = tk[0]
128 | concat_values = torch.cat(
129 | (topk_values_prev, topk_values), dim=1,
130 | ).view(-1)
131 | topk_values_2k, topk_indices_2k = torch.topk(
132 | concat_values, k, sorted=False,
133 | )
134 |
135 | # Get the topk indices from the combination of two indices
136 | topk_prev = topk_indices_2k[topk_indices_2k < k]
137 | topk_values_indices = tk[:][topk_prev]
138 |
139 | topk_curr = topk_indices_2k[topk_indices_2k >= k]
140 | topk_curr = topk_curr % k
141 |
142 | curr_indices = indices[:][topk_curr]
143 | curr_values = topk_values[topk_curr]
144 |
145 | indices_values = torch.cat((curr_indices, curr_values), dim=0)
146 | tk = torch.cat((topk_values_indices, indices_values), dim=1)
147 | row = tk[1]
148 | col = tk[2]
149 |
150 | new_indices = torch.stack((row, col))
151 | x = torch.cat((indices[:, :-k], new_indices), dim=1)
152 |
153 | if indices.shape[1] > x.shape[1]:
154 | diff = indices.shape[1] - x.shape[1]
155 | new_entries = torch.zeros((2, diff), dtype=torch.long).to(device=device)
156 | x = torch.cat((x, new_entries), dim=1)
157 |
158 | indices.copy_(x)
159 |
160 | return grad_input, None, None, None, None, None
161 |
162 |
163 | class SparseLinear(nn.Module):
164 | """Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
165 |
166 | Args:
167 | in_features: size of each input sample
168 | out_features: size of each output sample
169 | bias: If set to ``False``, the layer will not learn an additive bias.
170 | Default: ``True``
171 | sparsity: sparsity of weight matrix
172 | Default: 0.9
173 | connectivity: user defined sparsity matrix
174 | Default: None
175 | small_world: boolean flag to generate small world sparsity
176 | Default: ``False``
177 | dynamic: boolean flag to dynamically change the network structure
178 | Default: ``False``
179 | deltaT (int): frequency for growing and pruning update step
180 | Default: 6000
181 | Tend (int): stopping time for growing and pruning algorithm update step
182 | Default: 150000
183 | alpha (float): f-decay parameter for cosine updates
184 | Default: 0.1
185 |         max_size (int): maximum number of entries allowed before chunking occurs
186 | Default: 1e8
187 |
188 | Shape:
189 | - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
190 | additional dimensions and :math:`H_{in} = \text{in\_features}`
191 | - Output: :math:`(N, *, H_{out})` where all but the last dimension
192 | are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
193 |
194 | Attributes:
195 | weight: the learnable weights of the module of shape
196 | :math:`(\text{out\_features}, \text{in\_features})`. The values are
197 | initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
198 | :math:`k = \frac{1}{\text{in\_features}}`
199 | bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
200 | If :attr:`bias` is ``True``, the values are initialized from
201 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
202 | :math:`k = \frac{1}{\text{in\_features}}`
203 |
204 | Examples:
205 |
206 |         >>> m = sl.SparseLinear(20, 30)
207 | >>> input = torch.randn(128, 20)
208 | >>> output = m(input)
209 | >>> print(output.size())
210 | torch.Size([128, 30])
211 | """
212 |
213 | def __init__(
214 | self,
215 | in_features,
216 | out_features,
217 | bias=True,
218 | sparsity=0.9,
219 | connectivity=None,
220 | small_world=False,
221 | dynamic=False,
222 | deltaT=6000,
223 | Tend=150000,
224 | alpha=0.1,
225 | max_size=1e8,
226 | ):
227 | assert in_features < 2 ** 31 and out_features < 2 ** 31 and sparsity < 1.0
228 | assert (
229 | connectivity is None or not small_world
230 | ), "Cannot specify connectivity along with small world sparsity"
231 | if connectivity is not None:
232 | assert isinstance(connectivity, torch.LongTensor) or isinstance(
233 | connectivity, torch.cuda.LongTensor,
234 | ), "Connectivity must be a Long Tensor"
235 | assert (
236 | connectivity.shape[0] == 2 and connectivity.shape[1] > 0
237 | ), "Input shape for connectivity should be (2,nnz)"
238 | assert (
239 | connectivity.shape[1] <= in_features * out_features
240 | ), "Nnz can't be bigger than the weight matrix"
241 | super(SparseLinear, self).__init__()
242 | self.in_features = in_features
243 | self.out_features = out_features
244 | self.connectivity = connectivity
245 | self.small_world = small_world
246 | self.dynamic = dynamic
247 | self.max_size = max_size
248 |
249 | # Generate and coalesce indices : Faster to coalesce on GPU
250 | coalesce_device = (
251 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
252 | )
253 |
254 | if not small_world:
255 | if connectivity is None:
256 | self.sparsity = sparsity
257 | nnz = round((1.0 - sparsity) * in_features * out_features)
258 | if in_features * out_features <= 10 ** 8:
259 | indices = np.random.choice(
260 | in_features * out_features, nnz, replace=False,
261 | )
262 | indices = torch.as_tensor(indices, device=coalesce_device)
263 | row_ind = indices.floor_divide(in_features)
264 | col_ind = indices.fmod(in_features)
265 | else:
266 | warnings.warn(
267 | "Matrix too large to sample non-zero indices without replacement, sparsity will be approximate",
268 | RuntimeWarning,
269 | )
270 | row_ind = torch.randint(
271 | 0, out_features, (nnz,), device=coalesce_device,
272 | )
273 | col_ind = torch.randint(
274 | 0, in_features, (nnz,), device=coalesce_device,
275 | )
276 | indices = torch.stack((row_ind, col_ind))
277 | else:
278 | # User defined sparsity
279 | nnz = connectivity.shape[1]
280 | self.sparsity = 1.0 - nnz / (out_features * in_features)
281 | connectivity = connectivity.to(device=coalesce_device)
282 | indices = connectivity
283 | else:
284 | # Generate small world sparsity
285 | self.sparsity = sparsity
286 | nnz = round((1.0 - sparsity) * in_features * out_features)
287 | assert nnz > min(
288 | in_features, out_features,
289 | ), "The matrix is too sparse for small-world algorithm; please decrease sparsity"
290 | offset = abs(out_features - in_features) / 2.0
291 |
292 | # Node labels
293 | inputs = torch.arange(
294 | 1 + offset * (out_features > in_features),
295 | in_features + 1 + offset * (out_features > in_features),
296 | device=coalesce_device,
297 | )
298 | outputs = torch.arange(
299 | 1 + offset * (out_features < in_features),
300 | out_features + 1 + offset * (out_features < in_features),
301 | device=coalesce_device,
302 | )
303 |
304 | # Creating chunks for small world algorithm
305 | total_data = in_features * out_features # Total params
306 | chunks = math.ceil(total_data / self.max_size)
307 | split_div = max(in_features, out_features) // chunks # Full chunks
308 | split_mod = max(in_features, out_features) % chunks # Remaining chunk
309 | idx = (
310 | torch.repeat_interleave(torch.Tensor([split_div]), chunks)
311 | .int()
312 | .to(device=coalesce_device)
313 | )
314 | idx[:split_mod] += 1
315 | idx = torch.cumsum(idx, dim=0)
316 | idx = torch.cat([torch.LongTensor([0]).to(device=coalesce_device), idx])
317 |
318 | count = 0
319 |
320 | rows = torch.empty(0).long().to(device=coalesce_device)
321 | cols = torch.empty(0).long().to(device=coalesce_device)
322 |
323 | for i in range(chunks):
324 | inputs_ = (
325 | inputs[idx[i] : idx[i + 1]]
326 | if out_features <= in_features
327 | else inputs
328 | )
329 | outputs_ = (
330 | outputs[idx[i] : idx[i + 1]]
331 | if out_features > in_features
332 | else outputs
333 | )
334 |
335 | y = small_world_chunker(inputs_, outputs_, round(nnz / chunks))
336 | ref = torch.rand_like(y)
337 |
338 | # Refer to Eq.7 from Bipartite_small_world_network write-up
339 | mask = torch.empty(y.shape, dtype=bool).to(device=coalesce_device)
340 | mask[y < ref] = False
341 | mask[y >= ref] = True
342 |
343 | rows_, cols_ = mask.to_sparse().indices()
344 |
345 | rows = torch.cat([rows, rows_ + idx[i]])
346 | cols = torch.cat([cols, cols_])
347 |
348 | indices = torch.stack((cols, rows))
349 | nnz = indices.shape[1]
350 |
351 | values = torch.empty(nnz, device=coalesce_device)
352 | # indices, values = torch_sparse.coalesce(indices, values, out_features, in_features)
353 |
354 | self.register_buffer("indices", indices.cpu())
355 | self.weights = nn.Parameter(values.cpu())
356 |
357 | if bias:
358 | self.bias = nn.Parameter(torch.Tensor(out_features))
359 | else:
360 | self.register_parameter("bias", None)
361 |
362 | if self.dynamic:
363 | self.deltaT = deltaT
364 | self.Tend = Tend
365 | self.alpha = alpha
366 | self.itr_count = 0
367 |
368 | self.reset_parameters()
369 |
370 | def reset_parameters(self):
371 | bound = 1 / self.in_features ** 0.5
372 | nn.init.uniform_(self.weights, -bound, bound)
373 | if self.bias is not None:
374 | nn.init.uniform_(self.bias, -bound, bound)
375 |
376 | @property
377 | def weight(self):
378 | """ returns a torch.sparse.FloatTensor view of the underlying weight matrix
379 | This is only for inspection purposes and should not be modified or used in any autograd operations
380 | """
381 | weight = torch.sparse.FloatTensor(
382 | self.indices, self.weights, (self.out_features, self.in_features),
383 | )
384 | return weight.coalesce().detach()
385 |
386 | def forward(self, inputs):
387 | if self.training and self.dynamic:
388 | self.itr_count += 1
389 | output_shape = list(inputs.shape)
390 | output_shape[-1] = self.out_features
391 |
392 | # Handle dynamic sparsity
393 | if (
394 | self.training
395 | and self.dynamic
396 | and self.itr_count < self.Tend
397 | and self.itr_count % self.deltaT == 0
398 | ):
399 | # Drop criterion
400 | f_decay = (
401 | self.alpha * (1 + math.cos(self.itr_count * math.pi / self.Tend)) / 2
402 | )
403 | k = int(f_decay * (1 - self.sparsity) * self.weights.view(-1, 1).shape[0])
404 | n = self.weights.shape[0]
405 |
406 | neg_weights = -1 * torch.abs(self.weights)
407 | _, lm_indices = torch.topk(neg_weights, n - k, largest=False, sorted=False)
408 |
409 | self.indices = torch.index_select(self.indices, 1, lm_indices)
410 | self.weights = nn.Parameter(torch.index_select(self.weights, 0, lm_indices))
411 |
412 | device = inputs.device
413 | # Growth criterion
414 | new_weights = torch.zeros(k).to(device=device)
415 | self.weights = nn.Parameter(torch.cat((self.weights, new_weights), dim=0))
416 |
417 | new_indices = torch.zeros((2, k), dtype=torch.long).to(device=device)
418 | self.indices = torch.cat((self.indices, new_indices), dim=1)
419 | output = GrowConnections.apply(
420 | inputs,
421 | self.weights,
422 | k,
423 | self.indices,
424 | (self.out_features, self.in_features),
425 | self.max_size,
426 | )
427 | else:
428 | if len(output_shape) == 1:
429 | inputs = inputs.view(1, -1)
430 | inputs = inputs.flatten(end_dim=-2)
431 |
432 | # output = torch_sparse.spmm(self.indices, self.weights, self.out_features, self.in_features, inputs.t()).t()
433 | target = torch.sparse.FloatTensor(
434 | self.indices,
435 | self.weights,
436 | torch.Size([self.out_features, self.in_features]),
437 | ).to_dense()
438 | output = torch.mm(target, inputs.t()).t()
439 |
440 | if self.bias is not None:
441 | output += self.bias
442 |
443 | return output.view(output_shape)
444 |
445 | def extra_repr(self):
446 | return "in_features={}, out_features={}, bias={}, sparsity={}, connectivity={}, small_world={}".format(
447 | self.in_features,
448 | self.out_features,
449 | self.bias is not None,
450 | self.sparsity,
451 | self.connectivity,
452 | self.small_world,
453 | )
454 |
455 |
--------------------------------------------------------------------------------
/tutorials/SparseLinearDemo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# SparseLinear Demonstration using MNIST \n",
8 | "\n",
9 | "##### Training models consisting of sparsely connected linear layers\n",
10 | "\n",
11 | "## Table of Contents\n",
12 | "\n",
13 | "- [Introduction](#intro)\n",
14 | "- [Setup](#setup)\n",
15 | "- [Time and memory efficiency](#efficiency)\n",
16 | "- [Training with random inputs](#random)\n",
17 | "- [Training on MNIST](#mnist)\n",
18 | "- [Training sparse models with user-defined connections](#user)\n",
19 | "- [Training sparse models with dynamic connections](#dynamic)\n",
20 | "- [Training sparse models with small-world connections](#sw)\n",
21 | "- [Utilizing the activation sparsity feature](#activation)\n",
22 | "- [Training very wide and sparse models](#big)"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "## Introduction \n",
30 | "\n",
31 | "SparseLinear is a PyTorch package that allows a user to create extremely wide and sparse linear layers efficiently. A sparsely connected network is a network where each node is connected to some fraction of available nodes.\n",
32 | "\n",
33 | "The provided package is built on top of [PyTorch Sparse](https://github.com/rusty1s/pytorch_sparse), which provides optimized sparse matrix operations with autograd support in PyTorch.\n",
34 | "\n",
35 | "In this tutorial, we demonstrate its basic usage along with steps to train using the package features. Note that it is advisable to run these on the GPU instead of the CPU owing to much faster training times on the former."
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Setup \n",
43 | "\n",
44 |         "We import PyTorch, which contains the (dense) Linear module, and load the device."
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 1,
50 | "metadata": {
51 | "tags": []
52 | },
53 | "outputs": [
54 | {
55 | "name": "stdout",
56 | "output_type": "stream",
57 | "text": [
58 | "cuda:0\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "import torch\n",
64 | "import torch.nn as nn\n",
65 | "\n",
66 | "import warnings\n",
67 | "warnings.filterwarnings('ignore')\n",
68 | "\n",
69 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
70 | "print(device)"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "We create the linear layer and demonstrate some of its built-in attributes. "
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 2,
83 | "metadata": {},
84 | "outputs": [
85 | {
86 | "data": {
87 | "text/plain": [
88 | "('in_features=10, out_features=20, bias=True',\n",
89 | " torch.Size([20, 10]),\n",
90 | " torch.Size([20]))"
91 | ]
92 | },
93 | "execution_count": 2,
94 | "metadata": {},
95 | "output_type": "execute_result"
96 | }
97 | ],
98 | "source": [
99 | "fc1 = nn.Linear(10, 20)\n",
100 | "fc1.extra_repr(), fc1.weight.shape, fc1.bias.shape"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {},
106 | "source": [
107 |             "In a similar manner, we now import the `sparselinear` package. As can be observed, the custom layer's weights and biases can be accessed in the same manner as before. The new layer also returns some extra attributes, which we will discuss later. "
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 3,
113 | "metadata": {
114 | "tags": []
115 | },
116 | "outputs": [],
117 | "source": [
118 | "import sparselinear as sl"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 4,
124 | "metadata": {},
125 | "outputs": [
126 | {
127 | "data": {
128 | "text/plain": [
129 | "('in_features=10, out_features=20, bias=True, sparsity=0.9, connectivity=None, small_world=False',\n",
130 | " torch.Size([20, 10]),\n",
131 | " torch.Size([20]))"
132 | ]
133 | },
134 | "execution_count": 4,
135 | "metadata": {},
136 | "output_type": "execute_result"
137 | }
138 | ],
139 | "source": [
140 | "sl1 = sl.SparseLinear(10,20)\n",
141 | "sl1.extra_repr(), sl1.weight.shape, sl1.bias.shape"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {},
147 | "source": [
148 | "We now take a look at the two weight matrices."
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 5,
154 | "metadata": {},
155 | "outputs": [
156 | {
157 | "data": {
158 | "text/plain": [
159 | "Parameter containing:\n",
160 | "tensor([[ 2.2700e-01, 1.9950e-04, 1.5138e-01, -2.3028e-01, 3.0498e-01,\n",
161 | " -1.1592e-01, 8.9561e-03, -1.9644e-01, -1.8278e-01, -1.2167e-01],\n",
162 | " [ 1.8054e-01, 4.9672e-03, -2.7930e-01, 1.7971e-02, -2.5313e-01,\n",
163 | " -1.6389e-01, 2.8138e-02, 2.3216e-01, 8.5033e-02, 2.6193e-01],\n",
164 | " [-1.3997e-01, -2.0780e-01, -1.3777e-01, 9.5758e-02, -1.1465e-01,\n",
165 | " -3.0299e-01, 2.3639e-01, 2.3740e-01, 3.4879e-02, -2.8988e-01],\n",
166 | " [ 3.4024e-02, -3.4284e-02, -3.1449e-01, -7.3634e-02, 1.0884e-01,\n",
167 | " 3.4649e-02, 2.2210e-01, -2.2692e-01, 1.7318e-01, 1.0567e-01],\n",
168 | " [ 1.8497e-01, 8.6446e-02, -1.3994e-02, -1.8335e-01, 7.1342e-02,\n",
169 | " -5.4367e-02, -1.2261e-01, -1.2711e-01, 1.2817e-01, 3.0136e-01],\n",
170 | " [ 2.7756e-01, -2.6505e-01, 2.1932e-02, 2.2353e-01, -2.0779e-01,\n",
171 | " 2.9041e-01, -2.9108e-01, 2.5556e-02, 2.6355e-02, 9.2430e-02],\n",
172 | " [-6.9308e-02, 1.4349e-01, 2.1799e-01, 9.2573e-02, -1.1946e-01,\n",
173 | " -8.8171e-02, 1.8941e-01, 2.6366e-01, -2.2858e-01, -7.3599e-02],\n",
174 | " [ 2.7514e-01, 5.0453e-02, -2.7328e-01, 1.8520e-01, -1.9852e-01,\n",
175 | " -1.9735e-01, 2.4275e-01, -3.9498e-02, -9.1360e-02, -1.1861e-01],\n",
176 | " [-1.1171e-01, -9.5010e-02, 8.9707e-02, 4.6313e-02, 1.4619e-01,\n",
177 | " 8.1823e-02, 1.7853e-01, -8.7963e-02, -1.1446e-01, 2.6627e-01],\n",
178 | " [ 2.3615e-01, -3.6197e-02, -8.4897e-03, -1.6606e-01, 2.2675e-01,\n",
179 | " 2.4547e-01, -9.0289e-03, 2.4520e-01, -5.2978e-02, 2.1697e-01],\n",
180 | " [-3.0513e-01, 1.3863e-01, 7.8270e-02, 1.4909e-01, 1.9973e-01,\n",
181 | " -3.0507e-01, 2.5586e-01, -1.7424e-01, -1.5309e-01, 5.5867e-04],\n",
182 | " [-7.0133e-02, -1.9032e-02, -1.2506e-01, -1.7848e-01, -2.9393e-02,\n",
183 | " -1.2914e-01, 2.6337e-01, 2.6768e-02, 1.2672e-01, 2.7603e-01],\n",
184 | " [-2.8693e-01, -2.3720e-01, 2.5797e-01, 1.4165e-01, 1.3373e-01,\n",
185 | " 1.8109e-01, -1.6957e-01, -1.7997e-02, 8.7014e-02, -2.4652e-01],\n",
186 | " [-2.5947e-01, -2.8787e-01, -2.5539e-01, 1.6618e-01, 1.8108e-01,\n",
187 | " 2.5726e-01, -8.1458e-02, -5.7938e-02, -1.2392e-01, 2.4684e-01],\n",
188 | " [-1.4422e-01, 2.5416e-01, -3.1530e-01, -3.4516e-02, 2.4788e-03,\n",
189 | " -4.7461e-02, -4.3661e-02, 6.9811e-02, -1.5750e-01, -2.6075e-02],\n",
190 | " [-2.1888e-02, 7.3086e-02, -1.8007e-01, 1.3642e-01, 3.5587e-02,\n",
191 | " -9.6858e-02, -2.9306e-01, 3.6169e-02, 1.5752e-01, -1.6869e-01],\n",
192 | " [-1.4439e-01, 1.7146e-01, -2.5298e-01, -8.1204e-02, -2.2055e-01,\n",
193 | " 7.1367e-02, -1.9483e-01, 6.2341e-02, 2.7298e-01, -1.9703e-01],\n",
194 | " [ 2.4394e-01, -1.1722e-01, 1.9469e-01, -1.3386e-01, 2.3983e-02,\n",
195 | " -1.6841e-01, -2.1754e-01, -3.1110e-01, -3.0143e-01, -2.0946e-01],\n",
196 | " [-5.3311e-02, -2.9350e-01, -2.7600e-01, -2.5163e-02, 1.9405e-01,\n",
197 | " 2.2620e-01, -2.3203e-01, 2.2993e-01, 2.7201e-02, -2.3870e-01],\n",
198 | " [-1.0589e-01, 2.1078e-01, 1.7010e-01, 1.1657e-01, 9.0184e-02,\n",
199 | " 2.8292e-01, -9.2838e-02, 6.3635e-02, 6.2464e-02, -3.0613e-01]],\n",
200 | " requires_grad=True)"
201 | ]
202 | },
203 | "execution_count": 5,
204 | "metadata": {},
205 | "output_type": "execute_result"
206 | }
207 | ],
208 | "source": [
209 | "fc1.weight"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 6,
215 | "metadata": {},
216 | "outputs": [
217 | {
218 | "data": {
219 | "text/plain": [
220 | "tensor(indices=tensor([[ 0, 0, 1, 2, 3, 3, 4, 5, 5, 6, 7, 11, 12, 12,\n",
221 | " 13, 14, 14, 15, 15, 19],\n",
222 | " [ 3, 8, 1, 4, 0, 7, 1, 5, 8, 7, 0, 2, 3, 4,\n",
223 | " 7, 3, 7, 4, 7, 9]]),\n",
224 | " values=tensor([-0.0633, -0.0889, -0.1319, -0.0976, 0.0678, -0.0097,\n",
225 | " -0.1309, -0.0588, -0.2626, -0.1929, 0.2060, -0.2528,\n",
226 | " -0.0253, 0.1744, 0.2165, 0.1699, -0.0991, -0.2581,\n",
227 | " 0.3090, 0.0959]),\n",
228 | " size=(20, 10), nnz=20, layout=torch.sparse_coo)"
229 | ]
230 | },
231 | "execution_count": 6,
232 | "metadata": {},
233 | "output_type": "execute_result"
234 | }
235 | ],
236 | "source": [
237 | "sl1.weight"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "metadata": {},
243 | "source": [
244 | "As can be seen, the first weight matrix has 200 non-zero entries while the second one has 20 non-zero entries as specified by `nnz`. The indices tensor keeps track of all the indices where a non-zero entry is present with the corresponding entry in the values tensor providing the entry at that index."
245 | ]
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "metadata": {},
250 | "source": [
251 | "## Time and Memory Efficiency \n",
252 | "\n",
253 |     "The `SparseLinear` class is ideal for very wide and sparse layers. Since we utilize sparse tensors and only store non-zero values (and their corresponding indices), `SparseLinear` is much more efficient in terms of memory consumption than simply applying a mask over a standard dense weight matrix -- as is often done by researchers and practitioners. Since we only perform computations on non-zero values, we see speedups in computation time for large matrices as well. As hardware becomes more well-suited for sparse computations, these speedups will likely increase.\n",
254 | "\n",
255 | "To show this, we create a (20000, 20000) `SparseLinear` layer with 99% sparsity and compare its runtime to that of a standard `Linear` layer. Later in this notebook, we create massive layers that would not be possible with standard `Linear` layers due to memory inefficiencies.\n",
256 | "\n",
257 | "We initialize two layers and define the input."
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 7,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "sl2 = sl.SparseLinear(20000, 20000, sparsity=.99).cuda()\n",
267 | "\n",
268 | "# Reduce weight dimensions if memory errors are raised\n",
269 | "fc2 = nn.Linear(20000, 20000).cuda()\n",
270 | "\n",
271 | "x = torch.rand(20000, device=device)"
272 | ]
273 | },
274 | {
275 | "cell_type": "markdown",
276 | "metadata": {},
277 | "source": [
278 | "We time the inference steps."
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 8,
284 | "metadata": {},
285 | "outputs": [
286 | {
287 | "name": "stdout",
288 | "output_type": "stream",
289 | "text": [
290 | "583 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
291 | "1.85 ms ± 93.1 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
292 | ]
293 | }
294 | ],
295 | "source": [
296 | "%timeit y = sl2(x)\n",
297 | "%timeit y = fc2(x)"
298 | ]
299 | },
300 | {
301 | "cell_type": "markdown",
302 | "metadata": {},
303 | "source": [
304 | "We time the training step for SparseLinear."
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 9,
310 | "metadata": {},
311 | "outputs": [
312 | {
313 | "name": "stdout",
314 | "output_type": "stream",
315 | "text": [
316 | "789 µs ± 666 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
317 | ]
318 | }
319 | ],
320 | "source": [
321 | "%%timeit\n",
322 | "y = sl2(x)\n",
323 | "y.sum().backward()"
324 | ]
325 | },
326 | {
327 | "cell_type": "markdown",
328 | "metadata": {},
329 | "source": [
330 | "We time the training step for Linear."
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": 10,
336 | "metadata": {},
337 | "outputs": [
338 | {
339 | "name": "stdout",
340 | "output_type": "stream",
341 | "text": [
342 | "9.29 ms ± 86 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
343 | ]
344 | }
345 | ],
346 | "source": [
347 | "%%timeit\n",
348 | "y = fc2(x)\n",
349 | "y.sum().backward()"
350 | ]
351 | },
352 | {
353 | "cell_type": "markdown",
354 | "metadata": {},
355 | "source": [
356 | "We delete layers to save GPU memory when running this notebook."
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": 11,
362 | "metadata": {},
363 | "outputs": [],
364 | "source": [
365 | "del sl2, fc2"
366 | ]
367 | },
368 | {
369 | "cell_type": "markdown",
370 | "metadata": {},
371 | "source": [
372 | "## Training with random inputs \n",
373 | "\n",
374 | "Next, we demonstrate how to train a two-layer network using the `SparseLinear` module provided in the package. The code has been built upon the PyTorch [tutorial](https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_nn.html) to highlight the parallels between the `nn.Linear` and `sl.SparseLinear` modules."
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 12,
380 | "metadata": {
381 | "tags": []
382 | },
383 | "outputs": [
384 | {
385 | "name": "stdout",
386 | "output_type": "stream",
387 | "text": [
388 | "Dense model loss: 4.417; Sparse model loss: 5.111\n",
389 | "Dense model loss: 0.075; Sparse model loss: 0.031\n",
390 | "Dense model loss: 0.002; Sparse model loss: 0.000\n",
391 | "Dense model loss: 0.000; Sparse model loss: 0.000\n",
392 | "Dense model loss: 0.000; Sparse model loss: 0.000\n"
393 | ]
394 | }
395 | ],
396 | "source": [
397 | "# N is batch size; D_in is input dimension;\n",
398 | "# H is hidden dimension; D_out is output dimension.\n",
399 | "N, D_in, H, D_out = 64, 200, 1000, 10\n",
400 | "\n",
401 | "# Create random Tensors to hold inputs and outputs.\n",
402 | "x = torch.randn(N, D_in)\n",
403 | "y = torch.randn(N, D_out)\n",
404 | "\n",
405 | "# Use the nn package to define our dense model as a sequence of layers. \n",
406 | "model_dense = torch.nn.Sequential(\n",
407 | " torch.nn.Linear(D_in, H),\n",
408 | " torch.nn.ReLU(),\n",
409 | " torch.nn.Linear(H, D_out),\n",
410 | ")\n",
411 | "\n",
412 | "# Use the sl package to define our sparse model as a sequence of layers.\n",
413 | "# Note that the default sparsity is 90%.\n",
414 | "model_sparse = torch.nn.Sequential(\n",
415 | " sl.SparseLinear(D_in, H),\n",
416 | " torch.nn.ReLU(),\n",
417 | " sl.SparseLinear(H, D_out),\n",
418 | ")\n",
419 | "\n",
420 | "# We will use Mean Squared Error (MSE) as our loss function.\n",
421 | "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
422 | "\n",
423 | "# We define our learning rates.\n",
424 | "# Note that sparse and dense models may require different learning rates.\n",
425 | "learning_rate_dense = 1e-4\n",
426 | "learning_rate_sparse = 1e-3\n",
427 | "\n",
428 | "for t in range(500):\n",
429 | " # Forward pass\n",
430 | " y_pred_dense = model_dense(x)\n",
431 | " y_pred_sparse = model_sparse(x)\n",
432 | "\n",
433 | " # Compute and print loss\n",
434 | " loss_dense = loss_fn(y_pred_dense, y)\n",
435 | " loss_sparse = loss_fn(y_pred_sparse, y)\n",
436 | " if t % 100 == 99:\n",
437 | " print(\"Dense model loss: %.3f; Sparse model loss: %.3f\" %(loss_dense.item(), loss_sparse.item()))\n",
438 | "\n",
439 | " # Zero the gradients before running the backward pass.\n",
440 | " model_dense.zero_grad()\n",
441 | " model_sparse.zero_grad()\n",
442 | "\n",
443 | " # Backward pass\n",
444 | " loss_dense.backward()\n",
445 | " loss_sparse.backward()\n",
446 | "\n",
447 | " # Update the weights using gradient descent\n",
448 | " with torch.no_grad():\n",
449 | " for param in model_dense.parameters():\n",
450 | " param -= learning_rate_dense * param.grad\n",
451 | " \n",
452 | " for param in model_sparse.parameters():\n",
453 | " param -= learning_rate_sparse * param.grad"
454 | ]
455 | },
456 | {
457 | "cell_type": "markdown",
458 | "metadata": {},
459 | "source": [
460 | "As we can see, the loss value in both models decreases. Let's now build models using this module and train on the MNIST digit classification task."
461 | ]
462 | },
463 | {
464 | "cell_type": "markdown",
465 | "metadata": {},
466 | "source": [
467 | "## Training on MNIST \n",
468 | "\n",
469 | "We start by doing the initial imports, generating transforms, creating the dataset along with the dataloader, defining the loss function and some other helper functions."
470 | ]
471 | },
472 | {
473 | "cell_type": "code",
474 | "execution_count": 13,
475 | "metadata": {},
476 | "outputs": [],
477 | "source": [
478 | "import time\n",
479 | "import torchvision\n",
480 | "import torchvision.transforms as transforms\n",
481 | "import torch.optim as optim\n",
482 | "import torch.nn.functional as F\n",
483 | "from torch.utils.data import sampler"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": 14,
489 | "metadata": {
490 | "tags": []
491 | },
492 | "outputs": [],
493 | "source": [
494 | "tf = transforms.Compose([transforms.ToTensor(),\n",
495 | " transforms.Normalize((0.1307,), (0.3081,))])\n",
496 | "\n",
497 | "batch_size = 64\n",
498 | "trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=tf)\n",
499 | "testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=tf)\n",
500 | "train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)\n",
501 | "test_dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=True)"
502 | ]
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": 15,
507 | "metadata": {
508 | "tags": []
509 | },
510 | "outputs": [],
511 | "source": [
512 | "def train_model(model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs=20):\n",
513 | " since = time.time()\n",
514 | " for epoch in range(num_epochs):\n",
515 | " cum_loss, total, correct = 0, 0, 0\n",
516 | " model.train()\n",
517 | " \n",
518 | " # Training epoch\n",
519 | " for i, (images, labels) in enumerate(train_dataloader, 0):\n",
520 | " images = images.to(device)\n",
521 | " labels = labels.to(device)\n",
522 | "\n",
523 | " # Forward pass & statistics\n",
524 | " out = model(images)\n",
525 | " predicted = out.argmax(dim=1)\n",
526 | " correct += (predicted == labels).sum().item()\n",
527 | " total += labels.size(0)\n",
528 | " loss = criterion(out, labels)\n",
529 | " cum_loss += loss.item()\n",
530 | "\n",
531 | " # Backwards pass & update\n",
532 | " loss.backward()\n",
533 | " optimizer.step()\n",
534 | " optimizer.zero_grad()\n",
535 | " \n",
536 | " epoch_loss = images.shape[0] * cum_loss / total\n",
537 | " epoch_acc = 100 * (correct / total)\n",
538 | " print('Epoch %d' % (epoch + 1))\n",
539 | " print('Training Loss: {:.4f}; Training Acc: {:.4f}'.format(epoch_loss, epoch_acc))\n",
540 | " \n",
541 | " cum_loss, total, correct = 0, 0, 0\n",
542 | " model.eval()\n",
543 | " \n",
544 | " # Test epoch\n",
545 | " for i, (images, labels) in enumerate(test_dataloader, 0):\n",
546 | " images = images.to(device)\n",
547 | " labels = labels.to(device)\n",
548 | "\n",
549 | " # Forward pass & statistics\n",
550 | " out = model(images)\n",
551 | " predicted = out.argmax(dim=1)\n",
552 | " correct += (predicted == labels).sum().item()\n",
553 | " total += labels.size(0)\n",
554 | " loss = criterion(out, labels)\n",
555 | " cum_loss += loss.item()\n",
556 | " \n",
557 | " epoch_loss = images.shape[0] * cum_loss / total\n",
558 | " epoch_acc = 100 * (correct / total)\n",
559 | " \n",
560 | " print('Test Loss: {:.4f}; Test Acc: {:.4f}'.format(epoch_loss, epoch_acc))\n",
561 | " print('------------')\n",
562 | " \n",
563 | " time_elapsed = time.time() - since\n",
564 | " print('\\nTraining complete in {:.0f}m {:.0f}s'.format(\n",
565 | " time_elapsed // 60, time_elapsed % 60))"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": 16,
571 | "metadata": {},
572 | "outputs": [],
573 | "source": [
574 | "def flatten(x):\n",
575 | "\tN = x.shape[0]\n",
576 | "\treturn x.view(N, -1)\n",
577 | "\n",
578 | "class Flatten(nn.Module):\n",
579 | " def forward(self, x):\n",
580 | " return flatten(x)"
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "execution_count": 17,
586 | "metadata": {},
587 | "outputs": [],
588 | "source": [
589 | "criterion = nn.CrossEntropyLoss()"
590 | ]
591 | },
592 | {
593 | "cell_type": "markdown",
594 | "metadata": {},
595 | "source": [
596 | "### Training a dense model\n",
597 | "\n",
598 | "We start off with training a two-layer fully connected network. "
599 | ]
600 | },
601 | {
602 | "cell_type": "code",
603 | "execution_count": 18,
604 | "metadata": {},
605 | "outputs": [],
606 | "source": [
607 | "model = nn.Sequential(\n",
608 | "\tFlatten(),\n",
609 | "\tnn.Linear(784, 2000),\n",
610 | " nn.LayerNorm(2000),\n",
611 | "\tnn.ReLU(),\n",
612 | " nn.Linear(2000, 10),\n",
613 | ")\n",
614 | "model = model.to(device)"
615 | ]
616 | },
617 | {
618 | "cell_type": "markdown",
619 | "metadata": {},
620 | "source": [
621 | "After we set everything up, we declare the optimizer and start training the dense model. We use SGD as the optimizer since we found its behavior to be slightly better than that of others. However, one is free to choose any optimizer as long as there exists an implementation for it to handle sparse tensors. "
622 | ]
623 | },
624 | {
625 | "cell_type": "code",
626 | "execution_count": 19,
627 | "metadata": {
628 | "tags": []
629 | },
630 | "outputs": [
631 | {
632 | "name": "stdout",
633 | "output_type": "stream",
634 | "text": [
635 | "Epoch 1\n",
636 | "Training Loss: 0.2062; Training Acc: 93.5916\n",
637 | "Test Loss: 0.1250; Test Acc: 95.9736\n",
638 | "------------\n",
639 | "Epoch 2\n",
640 | "Training Loss: 0.0778; Training Acc: 97.6387\n",
641 | "Test Loss: 0.0949; Test Acc: 96.9351\n",
642 | "------------\n",
643 | "Epoch 3\n",
644 | "Training Loss: 0.0491; Training Acc: 98.4775\n",
645 | "Test Loss: 0.0785; Test Acc: 97.5160\n",
646 | "------------\n",
647 | "Epoch 4\n",
648 | "Training Loss: 0.0311; Training Acc: 99.0678\n",
649 | "Test Loss: 0.0664; Test Acc: 97.9667\n",
650 | "------------\n",
651 | "Epoch 5\n",
652 | "Training Loss: 0.0201; Training Acc: 99.4614\n",
653 | "Test Loss: 0.0704; Test Acc: 97.7564\n",
654 | "------------\n",
655 | "Epoch 6\n",
656 | "Training Loss: 0.0131; Training Acc: 99.7065\n",
657 | "Test Loss: 0.0549; Test Acc: 98.3173\n",
658 | "------------\n",
659 | "Epoch 7\n",
660 | "Training Loss: 0.0081; Training Acc: 99.8749\n",
661 | "Test Loss: 0.0592; Test Acc: 98.2372\n",
662 | "------------\n",
663 | "Epoch 8\n",
664 | "Training Loss: 0.0057; Training Acc: 99.9350\n",
665 | "Test Loss: 0.0536; Test Acc: 98.3674\n",
666 | "------------\n",
667 | "Epoch 9\n",
668 | "Training Loss: 0.0035; Training Acc: 99.9850\n",
669 | "Test Loss: 0.0551; Test Acc: 98.3173\n",
670 | "------------\n",
671 | "Epoch 10\n",
672 | "Training Loss: 0.0025; Training Acc: 99.9983\n",
673 | "Test Loss: 0.0518; Test Acc: 98.4976\n",
674 | "------------\n",
675 | "Epoch 11\n",
676 | "Training Loss: 0.0020; Training Acc: 100.0000\n",
677 | "Test Loss: 0.0526; Test Acc: 98.3974\n",
678 | "------------\n",
679 | "Epoch 12\n",
680 | "Training Loss: 0.0017; Training Acc: 99.9983\n",
681 | "Test Loss: 0.0530; Test Acc: 98.4976\n",
682 | "------------\n",
683 | "Epoch 13\n",
684 | "Training Loss: 0.0015; Training Acc: 100.0000\n",
685 | "Test Loss: 0.0531; Test Acc: 98.5377\n",
686 | "------------\n",
687 | "Epoch 14\n",
688 | "Training Loss: 0.0013; Training Acc: 100.0000\n",
689 | "Test Loss: 0.0532; Test Acc: 98.5076\n",
690 | "------------\n",
691 | "Epoch 15\n",
692 | "Training Loss: 0.0012; Training Acc: 100.0000\n",
693 | "Test Loss: 0.0544; Test Acc: 98.4675\n",
694 | "------------\n",
695 | "Epoch 16\n",
696 | "Training Loss: 0.0011; Training Acc: 100.0000\n",
697 | "Test Loss: 0.0539; Test Acc: 98.4776\n",
698 | "------------\n",
699 | "Epoch 17\n",
700 | "Training Loss: 0.0010; Training Acc: 100.0000\n",
701 | "Test Loss: 0.0537; Test Acc: 98.5276\n",
702 | "------------\n",
703 | "Epoch 18\n",
704 | "Training Loss: 0.0009; Training Acc: 100.0000\n",
705 | "Test Loss: 0.0546; Test Acc: 98.4976\n",
706 | "------------\n",
707 | "Epoch 19\n",
708 | "Training Loss: 0.0009; Training Acc: 100.0000\n",
709 | "Test Loss: 0.0545; Test Acc: 98.5076\n",
710 | "------------\n",
711 | "Epoch 20\n",
712 | "Training Loss: 0.0008; Training Acc: 100.0000\n",
713 | "Test Loss: 0.0544; Test Acc: 98.4776\n",
714 | "------------\n",
715 | "\n",
716 | "Training complete in 4m 56s\n"
717 | ]
718 | }
719 | ],
720 | "source": [
721 | "learning_rate = 1e-2\n",
722 | "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n",
723 | "\n",
724 | "#Perform the training \n",
725 | "train_model(model, optimizer, criterion, train_dataloader, test_dataloader)"
726 | ]
727 | },
728 | {
729 | "cell_type": "markdown",
730 | "metadata": {},
731 | "source": [
732 | "### Training a Sparse Model\n",
733 | "\n",
734 | "(with the default configuration)\n",
735 | "\n",
736 | "In the same way that we declared a dense model, we now declare a sparse model with the same number of input and output features but far fewer parameters. "
737 | ]
738 | },
739 | {
740 | "cell_type": "code",
741 | "execution_count": 20,
742 | "metadata": {},
743 | "outputs": [],
744 | "source": [
745 | "sparse_model = nn.Sequential(\n",
746 | "\tFlatten(),\n",
747 | "\tsl.SparseLinear(784, 2000),\n",
748 | " nn.LayerNorm(2000),\n",
749 | "\tnn.ReLU(),\n",
750 | " sl.SparseLinear(2000, 10)\n",
751 | ")\n",
752 | "sparse_model = sparse_model.to(device)"
753 | ]
754 | },
755 | {
756 | "cell_type": "markdown",
757 | "metadata": {},
758 | "source": [
759 | "We now train this model. Note that the learning rate is an order of magnitude higher. This is something we have found to be a rule of thumb while training these models. "
760 | ]
761 | },
762 | {
763 | "cell_type": "code",
764 | "execution_count": 21,
765 | "metadata": {
766 | "tags": []
767 | },
768 | "outputs": [
769 | {
770 | "name": "stdout",
771 | "output_type": "stream",
772 | "text": [
773 | "Epoch 1\n",
774 | "Training Loss: 0.2103; Training Acc: 93.6383\n",
775 | "Test Loss: 0.1184; Test Acc: 96.3141\n",
776 | "------------\n",
777 | "Epoch 2\n",
778 | "Training Loss: 0.0896; Training Acc: 97.2269\n",
779 | "Test Loss: 0.1022; Test Acc: 96.7949\n",
780 | "------------\n",
781 | "Epoch 3\n",
782 | "Training Loss: 0.0604; Training Acc: 98.0340\n",
783 | "Test Loss: 0.0880; Test Acc: 97.4159\n",
784 | "------------\n",
785 | "Epoch 4\n",
786 | "Training Loss: 0.0424; Training Acc: 98.6860\n",
787 | "Test Loss: 0.0799; Test Acc: 97.6262\n",
788 | "------------\n",
789 | "Epoch 5\n",
790 | "Training Loss: 0.0313; Training Acc: 99.0061\n",
791 | "Test Loss: 0.0782; Test Acc: 97.7764\n",
792 | "------------\n",
793 | "Epoch 6\n",
794 | "Training Loss: 0.0228; Training Acc: 99.2696\n",
795 | "Test Loss: 0.0843; Test Acc: 97.6963\n",
796 | "------------\n",
797 | "Epoch 7\n",
798 | "Training Loss: 0.0163; Training Acc: 99.5147\n",
799 | "Test Loss: 0.0861; Test Acc: 97.8466\n",
800 | "------------\n",
801 | "Epoch 8\n",
802 | "Training Loss: 0.0117; Training Acc: 99.6565\n",
803 | "Test Loss: 0.0917; Test Acc: 97.7364\n",
804 | "------------\n",
805 | "Epoch 9\n",
806 | "Training Loss: 0.0074; Training Acc: 99.8332\n",
807 | "Test Loss: 0.0913; Test Acc: 97.8666\n",
808 | "------------\n",
809 | "Epoch 10\n",
810 | "Training Loss: 0.0044; Training Acc: 99.9200\n",
811 | "Test Loss: 0.0819; Test Acc: 97.9667\n",
812 | "------------\n",
813 | "Epoch 11\n",
814 | "Training Loss: 0.0023; Training Acc: 99.9800\n",
815 | "Test Loss: 0.0852; Test Acc: 98.0970\n",
816 | "------------\n",
817 | "Epoch 12\n",
818 | "Training Loss: 0.0012; Training Acc: 99.9950\n",
819 | "Test Loss: 0.0894; Test Acc: 98.0469\n",
820 | "------------\n",
821 | "Epoch 13\n",
822 | "Training Loss: 0.0008; Training Acc: 99.9967\n",
823 | "Test Loss: 0.0864; Test Acc: 98.1871\n",
824 | "------------\n",
825 | "Epoch 14\n",
826 | "Training Loss: 0.0006; Training Acc: 100.0000\n",
827 | "Test Loss: 0.0874; Test Acc: 98.1470\n",
828 | "------------\n",
829 | "Epoch 15\n",
830 | "Training Loss: 0.0005; Training Acc: 100.0000\n",
831 | "Test Loss: 0.0884; Test Acc: 98.1270\n",
832 | "------------\n",
833 | "Epoch 16\n",
834 | "Training Loss: 0.0004; Training Acc: 100.0000\n",
835 | "Test Loss: 0.0891; Test Acc: 98.1571\n",
836 | "------------\n",
837 | "Epoch 17\n",
838 | "Training Loss: 0.0004; Training Acc: 100.0000\n",
839 | "Test Loss: 0.0901; Test Acc: 98.1470\n",
840 | "------------\n",
841 | "Epoch 18\n",
842 | "Training Loss: 0.0004; Training Acc: 100.0000\n",
843 | "Test Loss: 0.0910; Test Acc: 98.1070\n",
844 | "------------\n",
845 | "Epoch 19\n",
846 | "Training Loss: 0.0003; Training Acc: 100.0000\n",
847 | "Test Loss: 0.0916; Test Acc: 98.0970\n",
848 | "------------\n",
849 | "Epoch 20\n",
850 | "Training Loss: 0.0003; Training Acc: 100.0000\n",
851 | "Test Loss: 0.0918; Test Acc: 98.1571\n",
852 | "------------\n",
853 | "\n",
854 | "Training complete in 5m 12s\n"
855 | ]
856 | }
857 | ],
858 | "source": [
859 | "learning_rate = 1e-1\n",
860 | "optimizer = optim.SGD(sparse_model.parameters(), lr=learning_rate, momentum=0.9)\n",
861 | "\n",
862 | "train_model(sparse_model, optimizer, criterion, train_dataloader, test_dataloader)"
863 | ]
864 | },
865 | {
866 | "cell_type": "markdown",
867 | "metadata": {},
868 | "source": [
869 | "As can be seen, the two models perform comparably.\n",
870 | "\n",
871 | "\n",
872 | "However, while the dense model has a total of 1590010 parameters, the sparse model only has 160810 parameters. This translates to **~89.8%** parameter reduction in the sparse model! \n",
873 | "\n",
874 | "We display the weight parameter counts of the two layers below. "
875 | ]
876 | },
877 | {
878 | "cell_type": "code",
879 | "execution_count": 22,
880 | "metadata": {},
881 | "outputs": [
882 | {
883 | "data": {
884 | "text/plain": [
885 | "(torch.Size([156800]),\n",
886 | " torch.Size([2000, 784]),\n",
887 | " torch.Size([10, 2000]),\n",
888 | " torch.Size([2000]))"
889 | ]
890 | },
891 | "execution_count": 22,
892 | "metadata": {},
893 | "output_type": "execute_result"
894 | }
895 | ],
896 | "source": [
897 | "sparse_model[1].weights.shape, model[1].weight.shape, model[4].weight.shape, sparse_model[4].weights.shape"
898 | ]
899 | },
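  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to check such counts is to sum `p.numel()` over `parameters()`. The cell below is only a sketch and assumes `model` and `sparse_model` are still in scope; it counts every registered parameter (including the `LayerNorm` weights and biases), so expect totals slightly above the figures quoted above, which appear to count only the linear layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total registered parameters per model (also includes the LayerNorm weight and bias)\n",
    "dense_total = sum(p.numel() for p in model.parameters())\n",
    "sparse_total = sum(p.numel() for p in sparse_model.parameters())\n",
    "dense_total, sparse_total, 1 - sparse_total / dense_total"
   ]
  },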
900 | {
901 | "cell_type": "markdown",
902 | "metadata": {},
903 | "source": [
904 | "## Training sparse models with user-defined connections \n",
905 | "\n",
906 | "Instead of using the random set of connections created at initialization between the input and output neurons, one can define custom connections for the sparse linear layer by passing a long tensor of shape (2, `nnz`) to the `connectivity` argument, where each of the `nnz` columns specifies one input-to-output connection. "
907 | ]
908 | },
909 | {
910 | "cell_type": "markdown",
911 | "metadata": {},
912 | "source": [
913 | "Below we create a connectivity tensor in which each input neuron is connected to 200 randomly chosen output neurons. Of course, this is just a small demonstration, and one can experiment with different connectivity patterns. "
914 | ]
915 | },
916 | {
917 | "cell_type": "code",
918 | "execution_count": 23,
919 | "metadata": {},
920 | "outputs": [],
921 | "source": [
922 | "num_connections = 200  # outgoing connections per input neuron\n",
923 | "col = torch.arange(784).repeat_interleave(num_connections).view(1,-1).long()  # input (source) neuron indices\n",
924 | "row = torch.randint(low=0, high=2000, size=(784*num_connections,)).view(1,-1).long()  # random output (target) neuron indices\n",
925 | "connections = torch.cat((row, col), dim=0)  # shape (2, nnz) with nnz = 784 * 200"
926 | ]
927 | },
928 | {
929 | "cell_type": "markdown",
930 | "metadata": {},
931 | "source": [
932 | "We provide this connectivity tensor as an input to the `SparseLinear` module and follow the same training procedure as before. "
933 | ]
934 | },
935 | {
936 | "cell_type": "code",
937 | "execution_count": 24,
938 | "metadata": {},
939 | "outputs": [],
940 | "source": [
941 | "sparse_model_user = nn.Sequential(\n",
942 | "    Flatten(),\n",
943 | "    sl.SparseLinear(784, 2000, connectivity=connections),\n",
944 | "    nn.LayerNorm(2000),\n",
945 | "    nn.ReLU(),\n",
946 | " sl.SparseLinear(2000, 10)\n",
947 | ")\n",
948 | "sparse_model_user = sparse_model_user.to(device)"
949 | ]
950 | },
951 | {
952 | "cell_type": "code",
953 | "execution_count": 25,
954 | "metadata": {},
955 | "outputs": [
956 | {
957 | "name": "stdout",
958 | "output_type": "stream",
959 | "text": [
960 | "Epoch 1\n",
961 | "Training Loss: 0.2136; Training Acc: 93.6283\n",
962 | "Test Loss: 0.1089; Test Acc: 96.8249\n",
963 | "------------\n",
964 | "Epoch 2\n",
965 | "Training Loss: 0.0902; Training Acc: 97.3302\n",
966 | "Test Loss: 0.1025; Test Acc: 96.8049\n",
967 | "------------\n",
968 | "Epoch 3\n",
969 | "Training Loss: 0.0619; Training Acc: 98.0606\n",
970 | "Test Loss: 0.0877; Test Acc: 97.3658\n",
971 | "------------\n",
972 | "Epoch 4\n",
973 | "Training Loss: 0.0437; Training Acc: 98.5859\n",
974 | "Test Loss: 0.0814; Test Acc: 97.5160\n",
975 | "------------\n",
976 | "Epoch 5\n",
977 | "Training Loss: 0.0340; Training Acc: 98.8661\n",
978 | "Test Loss: 0.0854; Test Acc: 97.5461\n",
979 | "------------\n",
980 | "Epoch 6\n",
981 | "Training Loss: 0.0239; Training Acc: 99.2496\n",
982 | "Test Loss: 0.0815; Test Acc: 97.6062\n",
983 | "------------\n",
984 | "Epoch 7\n",
985 | "Training Loss: 0.0165; Training Acc: 99.4997\n",
986 | "Test Loss: 0.0901; Test Acc: 97.5060\n",
987 | "------------\n",
988 | "Epoch 8\n",
989 | "Training Loss: 0.0122; Training Acc: 99.6298\n",
990 | "Test Loss: 0.0873; Test Acc: 97.7063\n",
991 | "------------\n",
992 | "Epoch 9\n",
993 | "Training Loss: 0.0080; Training Acc: 99.8016\n",
994 | "Test Loss: 0.0917; Test Acc: 97.6963\n",
995 | "------------\n",
996 | "Epoch 10\n",
997 | "Training Loss: 0.0041; Training Acc: 99.9450\n",
998 | "Test Loss: 0.0837; Test Acc: 97.8766\n",
999 | "------------\n",
1000 | "Epoch 11\n",
1001 | "Training Loss: 0.0020; Training Acc: 99.9833\n",
1002 | "Test Loss: 0.0864; Test Acc: 97.8566\n",
1003 | "------------\n",
1004 | "Epoch 12\n",
1005 | "Training Loss: 0.0012; Training Acc: 99.9967\n",
1006 | "Test Loss: 0.0852; Test Acc: 97.9267\n",
1007 | "------------\n",
1008 | "Epoch 13\n",
1009 | "Training Loss: 0.0008; Training Acc: 100.0000\n",
1010 | "Test Loss: 0.0849; Test Acc: 97.9167\n",
1011 | "------------\n",
1012 | "Epoch 14\n",
1013 | "Training Loss: 0.0006; Training Acc: 100.0000\n",
1014 | "Test Loss: 0.0871; Test Acc: 97.9167\n",
1015 | "------------\n",
1016 | "Epoch 15\n",
1017 | "Training Loss: 0.0005; Training Acc: 100.0000\n",
1018 | "Test Loss: 0.0881; Test Acc: 97.8766\n",
1019 | "------------\n",
1020 | "Epoch 16\n",
1021 | "Training Loss: 0.0005; Training Acc: 100.0000\n",
1022 | "Test Loss: 0.0884; Test Acc: 97.8966\n",
1023 | "------------\n",
1024 | "Epoch 17\n",
1025 | "Training Loss: 0.0004; Training Acc: 100.0000\n",
1026 | "Test Loss: 0.0896; Test Acc: 97.8866\n",
1027 | "------------\n",
1028 | "Epoch 18\n",
1029 | "Training Loss: 0.0004; Training Acc: 100.0000\n",
1030 | "Test Loss: 0.0903; Test Acc: 97.8966\n",
1031 | "------------\n",
1032 | "Epoch 19\n",
1033 | "Training Loss: 0.0004; Training Acc: 100.0000\n",
1034 | "Test Loss: 0.0905; Test Acc: 97.8966\n",
1035 | "------------\n",
1036 | "Epoch 20\n",
1037 | "Training Loss: 0.0003; Training Acc: 100.0000\n",
1038 | "Test Loss: 0.0918; Test Acc: 97.8766\n",
1039 | "------------\n",
1040 | "\n",
1041 | "Training complete in 5m 19s\n"
1042 | ]
1043 | }
1044 | ],
1045 | "source": [
1046 | "learning_rate = 1e-1\n",
1047 | "optimizer = optim.SGD(sparse_model_user.parameters(), lr=learning_rate, momentum=0.9)\n",
1048 | "\n",
1049 | "train_model(sparse_model_user, optimizer, criterion, train_dataloader, test_dataloader)"
1050 | ]
1051 | },
1052 | {
1053 | "cell_type": "markdown",
1054 | "metadata": {},
1055 | "source": [
1056 | "## Training a sparse model with dynamic connections \n",
1057 | "\n",
1058 | "The default `SparseLinear` layer creates a random set of connections between the input and output neurons at initialization. An improvement over this strategy is to prune unneeded connections and grow (hopefully) useful ones as training progresses. We implement the [Rigging the Lottery](https://arxiv.org/pdf/1911.11134.pdf) (RigL) algorithm to achieve this. Setting `dynamic` to `True` alters the layer's connections dynamically during training."
1059 | ]
1060 | },
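  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, each RigL update drops a fraction of the weakest active connections and regrows the same number at the inactive positions with the largest gradient magnitude. The cell below is only a rough sketch of that idea in plain PyTorch; it is not the `SparseLinear` implementation, and the function name and arguments are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def rigl_step_sketch(active_values, dense_grad, active_mask, prune_frac=0.3):\n",
    "    # Illustrative sketch of one RigL drop/grow decision (not SparseLinear's code).\n",
    "    #   active_values: values of the currently active weights (1-D tensor)\n",
    "    #   dense_grad:    gradient w.r.t. every possible connection (1-D tensor)\n",
    "    #   active_mask:   boolean mask marking which positions in dense_grad are active\n",
    "    n = int(prune_frac * active_values.numel())\n",
    "    drop = torch.topk(active_values.abs(), n, largest=False).indices  # weakest active weights\n",
    "    candidates = dense_grad.abs().masked_fill(active_mask, float('-inf'))  # exclude active positions\n",
    "    grow = torch.topk(candidates, n, largest=True).indices  # largest-gradient inactive positions\n",
    "    return drop, grow"
   ]
  },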
1061 | {
1062 | "cell_type": "code",
1063 | "execution_count": 26,
1064 | "metadata": {
1065 | "tags": []
1066 | },
1067 | "outputs": [],
1068 | "source": [
1069 | "sparse_model_dynamic = nn.Sequential(\n",
1070 | "    Flatten(),\n",
1071 | "    sl.SparseLinear(784, 2000, dynamic=True),\n",
1072 | "    nn.LayerNorm(2000),\n",
1073 | "    nn.ReLU(),\n",
1074 | " sl.SparseLinear(2000, 10, dynamic=True)\n",
1075 | ")\n",
1076 | "sparse_model_dynamic = sparse_model_dynamic.to(device)"
1077 | ]
1078 | },
1079 | {
1080 | "cell_type": "code",
1081 | "execution_count": 27,
1082 | "metadata": {
1083 | "tags": []
1084 | },
1085 | "outputs": [
1086 | {
1087 | "name": "stdout",
1088 | "output_type": "stream",
1089 | "text": [
1090 | "Epoch 1\n",
1091 | "Training Loss: 0.2137; Training Acc: 93.5616\n",
1092 | "Test Loss: 0.1246; Test Acc: 96.0737\n",
1093 | "------------\n",
1094 | "Epoch 2\n",
1095 | "Training Loss: 0.0915; Training Acc: 97.2269\n",
1096 | "Test Loss: 0.0815; Test Acc: 97.4960\n",
1097 | "------------\n",
1098 | "Epoch 3\n",
1099 | "Training Loss: 0.0613; Training Acc: 98.1057\n",
1100 | "Test Loss: 0.0907; Test Acc: 97.2155\n",
1101 | "------------\n",
1102 | "Epoch 4\n",
1103 | "Training Loss: 0.0452; Training Acc: 98.5676\n",
1104 | "Test Loss: 0.0839; Test Acc: 97.3658\n",
1105 | "------------\n",
1106 | "Epoch 5\n",
1107 | "Training Loss: 0.0336; Training Acc: 98.9695\n",
1108 | "Test Loss: 0.0789; Test Acc: 97.5761\n",
1109 | "------------\n",
1110 | "Epoch 6\n",
1111 | "Training Loss: 0.0219; Training Acc: 99.3647\n",
1112 | "Test Loss: 0.0681; Test Acc: 97.8365\n",
1113 | "------------\n",
1114 | "Epoch 7\n",
1115 | "Training Loss: 0.0128; Training Acc: 99.7182\n",
1116 | "Test Loss: 0.0670; Test Acc: 97.9267\n",
1117 | "------------\n",
1118 | "Epoch 8\n",
1119 | "Training Loss: 0.0116; Training Acc: 99.7665\n",
1120 | "Test Loss: 0.0672; Test Acc: 97.9267\n",
1121 | "------------\n",
1122 | "Epoch 9\n",
1123 | "Training Loss: 0.0110; Training Acc: 99.7816\n",
1124 | "Test Loss: 0.0667; Test Acc: 97.9367\n",
1125 | "------------\n",
1126 | "Epoch 10\n",
1127 | "Training Loss: 0.0106; Training Acc: 99.7932\n",
1128 | "Test Loss: 0.0669; Test Acc: 97.9567\n",
1129 | "------------\n",
1130 | "Epoch 11\n",
1131 | "Training Loss: 0.0102; Training Acc: 99.8032\n",
1132 | "Test Loss: 0.0672; Test Acc: 97.9667\n",
1133 | "------------\n",
1134 | "Epoch 12\n",
1135 | "Training Loss: 0.0100; Training Acc: 99.8299\n",
1136 | "Test Loss: 0.0673; Test Acc: 98.0068\n",
1137 | "------------\n",
1138 | "Epoch 13\n",
1139 | "Training Loss: 0.0097; Training Acc: 99.8199\n",
1140 | "Test Loss: 0.0674; Test Acc: 97.9968\n",
1141 | "------------\n",
1142 | "Epoch 14\n",
1143 | "Training Loss: 0.0095; Training Acc: 99.8332\n",
1144 | "Test Loss: 0.0677; Test Acc: 98.0068\n",
1145 | "------------\n",
1146 | "Epoch 15\n",
1147 | "Training Loss: 0.0093; Training Acc: 99.8432\n",
1148 | "Test Loss: 0.0677; Test Acc: 98.0068\n",
1149 | "------------\n",
1150 | "Epoch 16\n",
1151 | "Training Loss: 0.0092; Training Acc: 99.8466\n",
1152 | "Test Loss: 0.0679; Test Acc: 97.9868\n",
1153 | "------------\n",
1154 | "Epoch 17\n",
1155 | "Training Loss: 0.0090; Training Acc: 99.8483\n",
1156 | "Test Loss: 0.0685; Test Acc: 98.0469\n",
1157 | "------------\n",
1158 | "Epoch 18\n",
1159 | "Training Loss: 0.0089; Training Acc: 99.8466\n",
1160 | "Test Loss: 0.0687; Test Acc: 97.9768\n",
1161 | "------------\n",
1162 | "Epoch 19\n",
1163 | "Training Loss: 0.0088; Training Acc: 99.8599\n",
1164 | "Test Loss: 0.0685; Test Acc: 98.0168\n",
1165 | "------------\n",
1166 | "Epoch 20\n",
1167 | "Training Loss: 0.0087; Training Acc: 99.8533\n",
1168 | "Test Loss: 0.0689; Test Acc: 98.0469\n",
1169 | "------------\n",
1170 | "\n",
1171 | "Training complete in 5m 15s\n"
1172 | ]
1173 | }
1174 | ],
1175 | "source": [
1176 | "learning_rate = 5e-3\n",
1177 | "optimizer = optim.SGD(sparse_model_dynamic.parameters(), lr=learning_rate, momentum=0.9)\n",
1178 | "train_model(sparse_model_dynamic, optimizer, criterion, train_dataloader, test_dataloader)"
1179 | ]
1180 | },
1181 | {
1182 | "cell_type": "markdown",
1183 | "metadata": {},
1184 | "source": [
1185 | "## Training a sparse model with small-world connections \n",
1186 | "\n",
1187 | "Some sparsity patterns tend to perform better than others. Small-world sparsity yields a network that is mostly locally connected, with a few long-range connections scattered in (see [small-world networks](https://en.wikipedia.org/wiki/Small-world_network)). We implement an initialization strategy that incorporates small-world sparsity into the layer. To use it, set `small_world` to `True`. "
1188 | ]
1189 | },
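  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition only, the cell below hand-builds a 'mostly local, plus a few random long-range' bipartite pattern in the same (2, `nnz`) format used for the `connectivity` argument earlier. This is not the initialization `SparseLinear` performs when `small_world=True`; it is just a rough illustration, and all names and constants here are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Illustration only: a hand-rolled 'mostly local + a few long-range' pattern,\n",
    "# not the scheme SparseLinear uses internally when small_world=True.\n",
    "in_f, out_f, k_local, k_global = 784, 2000, 8, 2\n",
    "src = torch.arange(in_f).repeat_interleave(k_local + k_global)  # input neuron indices\n",
    "centre = torch.arange(in_f) * out_f // in_f  # aligned position of each input in the output layer\n",
    "local = (centre.repeat_interleave(k_local)\n",
    "         + torch.randint(-20, 21, (in_f * k_local,))) % out_f  # nearby output neurons\n",
    "far = torch.randint(0, out_f, (in_f * k_global,))  # random long-range output neurons\n",
    "dst = torch.cat([local.view(in_f, k_local), far.view(in_f, k_global)], dim=1).flatten()\n",
    "small_world_connections = torch.stack([dst, src])  # shape (2, nnz), like `connections` above"
   ]
  },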
1190 | {
1191 | "cell_type": "code",
1192 | "execution_count": 28,
1193 | "metadata": {
1194 | "tags": []
1195 | },
1196 | "outputs": [],
1197 | "source": [
1198 | "sparse_model_sw = nn.Sequential(\n",
1199 | "    Flatten(),\n",
1200 | "    sl.SparseLinear(784, 2000, small_world=True),\n",
1201 | "    nn.LayerNorm(2000),\n",
1202 | "    nn.ReLU(),\n",
1203 | " sl.SparseLinear(2000, 10, small_world=True)\n",
1204 | ")\n",
1205 | "sparse_model_sw = sparse_model_sw.to(device)"
1206 | ]
1207 | },
1208 | {
1209 | "cell_type": "code",
1210 | "execution_count": 29,
1211 | "metadata": {
1212 | "tags": []
1213 | },
1214 | "outputs": [
1215 | {
1216 | "name": "stdout",
1217 | "output_type": "stream",
1218 | "text": [
1219 | "Epoch 1\n",
1220 | "Training Loss: 0.2043; Training Acc: 93.7817\n",
1221 | "Test Loss: 0.1040; Test Acc: 96.7748\n",
1222 | "------------\n",
1223 | "Epoch 2\n",
1224 | "Training Loss: 0.0856; Training Acc: 97.3202\n",
1225 | "Test Loss: 0.0853; Test Acc: 97.2756\n",
1226 | "------------\n",
1227 | "Epoch 3\n",
1228 | "Training Loss: 0.0573; Training Acc: 98.2057\n",
1229 | "Test Loss: 0.0816; Test Acc: 97.5661\n",
1230 | "------------\n",
1231 | "Epoch 4\n",
1232 | "Training Loss: 0.0404; Training Acc: 98.7176\n",
1233 | "Test Loss: 0.0786; Test Acc: 97.5761\n",
1234 | "------------\n",
1235 | "Epoch 5\n",
1236 | "Training Loss: 0.0297; Training Acc: 99.0211\n",
1237 | "Test Loss: 0.0741; Test Acc: 97.7464\n",
1238 | "------------\n",
1239 | "Epoch 6\n",
1240 | "Training Loss: 0.0211; Training Acc: 99.3246\n",
1241 | "Test Loss: 0.0740; Test Acc: 97.9267\n",
1242 | "------------\n",
1243 | "Epoch 7\n",
1244 | "Training Loss: 0.0162; Training Acc: 99.4597\n",
1245 | "Test Loss: 0.0808; Test Acc: 97.8666\n",
1246 | "------------\n",
1247 | "Epoch 8\n",
1248 | "Training Loss: 0.0108; Training Acc: 99.7132\n",
1249 | "Test Loss: 0.0900; Test Acc: 97.6462\n",
1250 | "------------\n",
1251 | "Epoch 9\n",
1252 | "Training Loss: 0.0090; Training Acc: 99.7148\n",
1253 | "Test Loss: 0.0858; Test Acc: 97.8666\n",
1254 | "------------\n",
1255 | "Epoch 10\n",
1256 | "Training Loss: 0.0055; Training Acc: 99.8666\n",
1257 | "Test Loss: 0.0830; Test Acc: 98.0469\n",
1258 | "------------\n",
1259 | "Epoch 11\n",
1260 | "Training Loss: 0.0029; Training Acc: 99.9516\n",
1261 | "Test Loss: 0.0785; Test Acc: 98.0068\n",
1262 | "------------\n",
1263 | "Epoch 12\n",
1264 | "Training Loss: 0.0013; Training Acc: 99.9950\n",
1265 | "Test Loss: 0.0808; Test Acc: 98.2472\n",
1266 | "------------\n",
1267 | "Epoch 13\n",
1268 | "Training Loss: 0.0007; Training Acc: 100.0000\n",
1269 | "Test Loss: 0.0796; Test Acc: 98.2071\n",
1270 | "------------\n",
1271 | "Epoch 14\n",
1272 | "Training Loss: 0.0005; Training Acc: 100.0000\n",
1273 | "Test Loss: 0.0811; Test Acc: 98.1571\n",
1274 | "------------\n",
1275 | "Epoch 15\n",
1276 | "Training Loss: 0.0005; Training Acc: 100.0000\n",
1277 | "Test Loss: 0.0823; Test Acc: 98.1871\n",
1278 | "------------\n",
1279 | "Epoch 16\n",
1280 | "Training Loss: 0.0004; Training Acc: 100.0000\n",
1281 | "Test Loss: 0.0827; Test Acc: 98.1871\n",
1282 | "------------\n",
1283 | "Epoch 17\n",
1284 | "Training Loss: 0.0004; Training Acc: 100.0000\n",
1285 | "Test Loss: 0.0841; Test Acc: 98.2071\n",
1286 | "------------\n",
1287 | "Epoch 18\n",
1288 | "Training Loss: 0.0003; Training Acc: 100.0000\n",
1289 | "Test Loss: 0.0844; Test Acc: 98.1671\n",
1290 | "------------\n",
1291 | "Epoch 19\n",
1292 | "Training Loss: 0.0003; Training Acc: 100.0000\n",
1293 | "Test Loss: 0.0849; Test Acc: 98.1871\n",
1294 | "------------\n",
1295 | "Epoch 20\n",
1296 | "Training Loss: 0.0003; Training Acc: 100.0000\n",
1297 | "Test Loss: 0.0854; Test Acc: 98.2171\n",
1298 | "------------\n",
1299 | "\n",
1300 | "Training complete in 5m 21s\n"
1301 | ]
1302 | }
1303 | ],
1304 | "source": [
1305 | "learning_rate = 1e-1\n",
1306 | "optimizer = optim.SGD(sparse_model_sw.parameters(), lr=learning_rate, momentum=0.9)\n",
1307 | "train_model(sparse_model_sw, optimizer, criterion, train_dataloader, test_dataloader)"
1308 | ]
1309 | },
1310 | {
1311 | "cell_type": "markdown",
1312 | "metadata": {},
1313 | "source": [
1314 | "## Utilizing the activation sparsity feature \n",
1315 | "\n",
1316 | "The `SparseLinear` layer targets parameter sparsity; it makes no stipulation about the sparsity (or density) of the activations. We therefore include an optional module for sparse activations based on the k-winners layer described in [this paper](https://arxiv.org/pdf/1903.11257.pdf), which we use below to train both linear and sparse linear models."
1317 | ]
1318 | },
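  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The core idea of a k-winners activation, sketched below for intuition: keep only the k largest activations in each sample and zero out the rest. This is a simplified illustration rather than the `ActivationSparsity` implementation (details such as the boosting of rarely active units are omitted), and the function name is made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def k_winners_sketch(x, act_sparsity=0.65):\n",
    "    # Keep the k largest activations per row and zero the rest\n",
    "    # (conceptual sketch only, not the ActivationSparsity module's code).\n",
    "    k = int((1 - act_sparsity) * x.shape[1])\n",
    "    idx = torch.topk(x, k, dim=1).indices\n",
    "    mask = torch.zeros_like(x).scatter_(1, idx, 1.0)\n",
    "    return x * mask"
   ]
  },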
1319 | {
1320 | "cell_type": "code",
1321 | "execution_count": 30,
1322 | "metadata": {},
1323 | "outputs": [],
1324 | "source": [
1325 | "import activationsparsity as asy"
1326 | ]
1327 | },
1328 | {
1329 | "cell_type": "markdown",
1330 | "metadata": {},
1331 | "source": [
1332 | "Below we train a linear model using this activation sparsity feature. The default is `act_sparsity=0.65`, which for the 2000-unit layer below corresponds to `k = (1 - 0.65) * 2000 = 700` winning units. "
1333 | ]
1334 | },
1335 | {
1336 | "cell_type": "code",
1337 | "execution_count": 31,
1338 | "metadata": {},
1339 | "outputs": [],
1340 | "source": [
1341 | "model_asy = nn.Sequential(\n",
1342 | "    Flatten(),\n",
1343 | " nn.Linear(784, 2000),\n",
1344 | " nn.LayerNorm(2000),\n",
1345 | " asy.ActivationSparsity(),\n",
1346 | "    nn.Linear(2000, 10)\n",
1347 | ")\n",
1348 | "model_asy = model_asy.to(device)"
1349 | ]
1350 | },
1351 | {
1352 | "cell_type": "code",
1353 | "execution_count": 32,
1354 | "metadata": {},
1355 | "outputs": [
1356 | {
1357 | "name": "stdout",
1358 | "output_type": "stream",
1359 | "text": [
1360 | "Epoch 1\n",
1361 | "Training Loss: 0.2241; Training Acc: 93.3548\n",
1362 | "Test Loss: 0.1153; Test Acc: 96.5845\n",
1363 | "------------\n",
1364 | "Epoch 2\n",
1365 | "Training Loss: 0.0901; Training Acc: 97.3102\n",
1366 | "Test Loss: 0.0894; Test Acc: 97.1855\n",
1367 | "------------\n",
1368 | "Epoch 3\n",
1369 | "Training Loss: 0.0590; Training Acc: 98.3825\n",
1370 | "Test Loss: 0.0785; Test Acc: 97.6462\n",
1371 | "------------\n",
1372 | "Epoch 4\n",
1373 | "Training Loss: 0.0415; Training Acc: 98.8877\n",
1374 | "Test Loss: 0.0671; Test Acc: 97.8866\n",
1375 | "------------\n",
1376 | "Epoch 5\n",
1377 | "Training Loss: 0.0303; Training Acc: 99.2529\n",
1378 | "Test Loss: 0.0616; Test Acc: 98.0569\n",
1379 | "------------\n",
1380 | "Epoch 6\n",
1381 | "Training Loss: 0.0222; Training Acc: 99.5114\n",
1382 | "Test Loss: 0.0602; Test Acc: 98.1270\n",
1383 | "------------\n",
1384 | "Epoch 7\n",
1385 | "Training Loss: 0.0168; Training Acc: 99.6748\n",
1386 | "Test Loss: 0.0611; Test Acc: 98.0569\n",
1387 | "------------\n",
1388 | "Epoch 8\n",
1389 | "Training Loss: 0.0128; Training Acc: 99.8366\n",
1390 | "Test Loss: 0.0604; Test Acc: 98.0970\n",
1391 | "------------\n",
1392 | "Epoch 9\n",
1393 | "Training Loss: 0.0102; Training Acc: 99.8849\n",
1394 | "Test Loss: 0.0569; Test Acc: 98.2472\n",
1395 | "------------\n",
1396 | "Epoch 10\n",
1397 | "Training Loss: 0.0079; Training Acc: 99.9400\n",
1398 | "Test Loss: 0.0581; Test Acc: 98.1971\n",
1399 | "------------\n",
1400 | "Epoch 11\n",
1401 | "Training Loss: 0.0065; Training Acc: 99.9633\n",
1402 | "Test Loss: 0.0559; Test Acc: 98.3173\n",
1403 | "------------\n",
1404 | "Epoch 12\n",
1405 | "Training Loss: 0.0053; Training Acc: 99.9733\n",
1406 | "Test Loss: 0.0538; Test Acc: 98.3874\n",
1407 | "------------\n",
1408 | "Epoch 13\n",
1409 | "Training Loss: 0.0045; Training Acc: 99.9917\n",
1410 | "Test Loss: 0.0543; Test Acc: 98.3474\n",
1411 | "------------\n",
1412 | "Epoch 14\n",
1413 | "Training Loss: 0.0040; Training Acc: 99.9950\n",
1414 | "Test Loss: 0.0576; Test Acc: 98.3073\n",
1415 | "------------\n",
1416 | "Epoch 15\n",
1417 | "Training Loss: 0.0036; Training Acc: 99.9900\n",
1418 | "Test Loss: 0.0555; Test Acc: 98.3273\n",
1419 | "------------\n",
1420 | "Epoch 16\n",
1421 | "Training Loss: 0.0032; Training Acc: 99.9983\n",
1422 | "Test Loss: 0.0556; Test Acc: 98.3974\n",
1423 | "------------\n",
1424 | "Epoch 17\n",
1425 | "Training Loss: 0.0029; Training Acc: 99.9967\n",
1426 | "Test Loss: 0.0581; Test Acc: 98.3574\n",
1427 | "------------\n",
1428 | "Epoch 18\n",
1429 | "Training Loss: 0.0026; Training Acc: 100.0000\n",
1430 | "Test Loss: 0.0577; Test Acc: 98.2472\n",
1431 | "------------\n",
1432 | "Epoch 19\n",
1433 | "Training Loss: 0.0024; Training Acc: 100.0000\n",
1434 | "Test Loss: 0.0568; Test Acc: 98.2372\n",
1435 | "------------\n",
1436 | "Epoch 20\n",
1437 | "Training Loss: 0.0022; Training Acc: 99.9983\n",
1438 | "Test Loss: 0.0561; Test Acc: 98.3574\n",
1439 | "------------\n",
1440 | "\n",
1441 | "Training complete in 6m 5s\n"
1442 | ]
1443 | }
1444 | ],
1445 | "source": [
1446 | "learning_rate = 5e-3\n",
1447 | "optimizer = optim.SGD(model_asy.parameters(), lr=learning_rate, momentum=0.9)\n",
1448 | "\n",
1449 | "#Perform the training \n",
1450 | "train_model(model_asy, optimizer, criterion, train_dataloader, test_dataloader)"
1451 | ]
1452 | },
1453 | {
1454 | "cell_type": "markdown",
1455 | "metadata": {},
1456 | "source": [
1457 | "Now we train another model that uses the sparse linear module together with this activation sparsity. As before, the learning rate is an order of magnitude higher than for the dense linear model. "
1458 | ]
1459 | },
1460 | {
1461 | "cell_type": "code",
1462 | "execution_count": 33,
1463 | "metadata": {},
1464 | "outputs": [],
1465 | "source": [
1466 | "model_asy_sparse = nn.Sequential(\n",
1467 | "    Flatten(),\n",
1468 | "    sl.SparseLinear(784, 2000),\n",
1469 | " nn.LayerNorm(2000),\n",
1470 | " asy.ActivationSparsity(),\n",
1471 | " sl.SparseLinear(2000, 10),\n",
1472 | ")\n",
1473 | "model_asy_sparse = model_asy_sparse.to(device)"
1474 | ]
1475 | },
1476 | {
1477 | "cell_type": "code",
1478 | "execution_count": 34,
1479 | "metadata": {},
1480 | "outputs": [
1481 | {
1482 | "name": "stdout",
1483 | "output_type": "stream",
1484 | "text": [
1485 | "Epoch 1\n",
1486 | "Training Loss: 0.2218; Training Acc: 93.2080\n",
1487 | "Test Loss: 0.1237; Test Acc: 96.1138\n",
1488 | "------------\n",
1489 | "Epoch 2\n",
1490 | "Training Loss: 0.0960; Training Acc: 97.0318\n",
1491 | "Test Loss: 0.0965; Test Acc: 97.0954\n",
1492 | "------------\n",
1493 | "Epoch 3\n",
1494 | "Training Loss: 0.0670; Training Acc: 97.9172\n",
1495 | "Test Loss: 0.0930; Test Acc: 97.1855\n",
1496 | "------------\n",
1497 | "Epoch 4\n",
1498 | "Training Loss: 0.0501; Training Acc: 98.4075\n",
1499 | "Test Loss: 0.0898; Test Acc: 97.2556\n",
1500 | "------------\n",
1501 | "Epoch 5\n",
1502 | "Training Loss: 0.0391; Training Acc: 98.7777\n",
1503 | "Test Loss: 0.0747; Test Acc: 97.7163\n",
1504 | "------------\n",
1505 | "Epoch 6\n",
1506 | "Training Loss: 0.0303; Training Acc: 99.0261\n",
1507 | "Test Loss: 0.0824; Test Acc: 97.5661\n",
1508 | "------------\n",
1509 | "Epoch 7\n",
1510 | "Training Loss: 0.0234; Training Acc: 99.2596\n",
1511 | "Test Loss: 0.0777; Test Acc: 97.7764\n",
1512 | "------------\n",
1513 | "Epoch 8\n",
1514 | "Training Loss: 0.0173; Training Acc: 99.4981\n",
1515 | "Test Loss: 0.0848; Test Acc: 97.4960\n",
1516 | "------------\n",
1517 | "Epoch 9\n",
1518 | "Training Loss: 0.0138; Training Acc: 99.6398\n",
1519 | "Test Loss: 0.0780; Test Acc: 97.8766\n",
1520 | "------------\n",
1521 | "Epoch 10\n",
1522 | "Training Loss: 0.0096; Training Acc: 99.7699\n",
1523 | "Test Loss: 0.0840; Test Acc: 97.7764\n",
1524 | "------------\n",
1525 | "Epoch 11\n",
1526 | "Training Loss: 0.0083; Training Acc: 99.8066\n",
1527 | "Test Loss: 0.0860; Test Acc: 97.8365\n",
1528 | "------------\n",
1529 | "Epoch 12\n",
1530 | "Training Loss: 0.0068; Training Acc: 99.8616\n",
1531 | "Test Loss: 0.0851; Test Acc: 97.8866\n",
1532 | "------------\n",
1533 | "Epoch 13\n",
1534 | "Training Loss: 0.0049; Training Acc: 99.9200\n",
1535 | "Test Loss: 0.0787; Test Acc: 98.0369\n",
1536 | "------------\n",
1537 | "Epoch 14\n",
1538 | "Training Loss: 0.0033; Training Acc: 99.9666\n",
1539 | "Test Loss: 0.0841; Test Acc: 98.0168\n",
1540 | "------------\n",
1541 | "Epoch 15\n",
1542 | "Training Loss: 0.0025; Training Acc: 99.9783\n",
1543 | "Test Loss: 0.0845; Test Acc: 97.8866\n",
1544 | "------------\n",
1545 | "Epoch 16\n",
1546 | "Training Loss: 0.0022; Training Acc: 99.9850\n",
1547 | "Test Loss: 0.0857; Test Acc: 98.0469\n",
1548 | "------------\n",
1549 | "Epoch 17\n",
1550 | "Training Loss: 0.0017; Training Acc: 99.9933\n",
1551 | "Test Loss: 0.0850; Test Acc: 98.1270\n",
1552 | "------------\n",
1553 | "Epoch 18\n",
1554 | "Training Loss: 0.0013; Training Acc: 99.9967\n",
1555 | "Test Loss: 0.0890; Test Acc: 98.0268\n",
1556 | "------------\n",
1557 | "Epoch 19\n",
1558 | "Training Loss: 0.0012; Training Acc: 99.9967\n",
1559 | "Test Loss: 0.0891; Test Acc: 98.1370\n",
1560 | "------------\n",
1561 | "Epoch 20\n",
1562 | "Training Loss: 0.0010; Training Acc: 99.9983\n",
1563 | "Test Loss: 0.0882; Test Acc: 98.0669\n",
1564 | "------------\n",
1565 | "\n",
1566 | "Training complete in 6m 24s\n"
1567 | ]
1568 | }
1569 | ],
1570 | "source": [
1571 | "learning_rate = 5e-2\n",
1572 | "optimizer = optim.SGD(model_asy_sparse.parameters(), lr=learning_rate, momentum=0.9)\n",
1573 | "\n",
1574 | "#Perform the training \n",
1575 | "train_model(model_asy_sparse, optimizer, criterion, train_dataloader, test_dataloader)"
1576 | ]
1577 | },
1578 | {
1579 | "cell_type": "markdown",
1580 | "metadata": {},
1581 | "source": [
1582 | "## Training very wide and sparse models \n",
1583 | "\n",
1584 | "The main advantage of using sparse tensors is that they enable us to train very wide models. Below we demonstrate an example of such a model. It is only a demonstration, but the key takeaway is that layers this large can be built for more complex tasks, where the benefits would be more pronounced. "
1585 | ]
1586 | },
1587 | {
1588 | "cell_type": "code",
1589 | "execution_count": 39,
1590 | "metadata": {},
1591 | "outputs": [],
1592 | "source": [
1593 | "import torch.nn as nn\n",
1594 | "import torch.nn.functional as F\n",
1595 | "\n",
1596 | "class Net(nn.Module):\n",
1597 | " def __init__(self):\n",
1598 | " super(Net, self).__init__()\n",
1599 | " self.sc1 = sl.SparseLinear(10 * 28 * 28, 50000, sparsity=0.999)\n",
1600 | " self.sc2 = sl.SparseLinear(50000, 50000, sparsity=0.999)\n",
1601 | " self.sc3 = sl.SparseLinear(50000, 50000, sparsity=0.999)\n",
1602 | " self.sc4 = sl.SparseLinear(50000, 50000, sparsity=0.999)\n",
1603 | " self.sc5 = sl.SparseLinear(50000, 50000, sparsity=0.999)\n",
1604 | " \n",
1605 | " self.input_scaling = nn.Parameter(torch.ones(10 * 28 * 28))\n",
1606 | " self.input_shifting = nn.Parameter(torch.zeros(10 * 28 * 28))\n",
1607 | " self.ln1 = nn.LayerNorm(50000)\n",
1608 | " self.ln2 = nn.LayerNorm(50000)\n",
1609 | " self.ln3 = nn.LayerNorm(50000)\n",
1610 | " self.ln4 = nn.LayerNorm(50000)\n",
1611 | " \n",
1612 | " def forward(self, x):\n",
1613 | " x = x.view(-1, 28 * 28)\n",
1614 | " x = torch.repeat_interleave(x, 10, dim=1)\n",
1615 | " x = self.input_scaling * x + self.input_shifting\n",
1616 | " x = F.relu(self.ln1(self.sc1(x)))\n",
1617 | " x = F.relu(self.ln2(self.sc2(x)))\n",
1618 | " x = F.relu(self.ln3(self.sc3(x)))\n",
1619 | " x = F.relu(self.ln4(self.sc4(x)))\n",
1620 | " x = self.sc5(x)\n",
1621 | " x = x.view(x.shape[0], -1, 10).sum(dim=1) # sum 5000 outputs per class\n",
1622 | " return x\n",
1623 | "\n",
1624 | "sparse_big = Net().to(device)"
1625 | ]
1626 | },
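  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a sense of scale, the simple arithmetic below estimates how many connections these layers carry at `sparsity=0.999`, compared with a dense layer of the same width."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back-of-the-envelope connection counts at sparsity=0.999\n",
    "dense_hidden = 50000 * 50000  # a dense 50000 x 50000 layer would hold 2.5 billion weights\n",
    "sparse_hidden = int(dense_hidden * (1 - 0.999))  # ~2.5 million nonzeros per hidden sparse layer\n",
    "sparse_input = int(10 * 28 * 28 * 50000 * (1 - 0.999))  # ~392 thousand nonzeros in the first layer\n",
    "dense_hidden, sparse_hidden, sparse_input"
   ]
  },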
1627 | {
1628 | "cell_type": "code",
1629 | "execution_count": 40,
1630 | "metadata": {
1631 | "tags": []
1632 | },
1633 | "outputs": [
1634 | {
1635 | "name": "stdout",
1636 | "output_type": "stream",
1637 | "text": [
1638 | "Epoch 1\n",
1639 | "Training Loss: 0.1993; Training Acc: 93.9168\n",
1640 | "Test Loss: 0.1317; Test Acc: 95.9034\n",
1641 | "------------\n",
1642 | "Epoch 2\n",
1643 | "Training Loss: 0.0675; Training Acc: 98.0023\n",
1644 | "Test Loss: 0.0806; Test Acc: 97.5160\n",
1645 | "------------\n",
1646 | "Epoch 3\n",
1647 | "Training Loss: 0.0318; Training Acc: 99.2663\n",
1648 | "Test Loss: 0.0723; Test Acc: 97.6863\n",
1649 | "------------\n",
1650 | "Epoch 4\n",
1651 | "Training Loss: 0.0167; Training Acc: 99.7532\n",
1652 | "Test Loss: 0.0630; Test Acc: 97.9968\n",
1653 | "------------\n",
1654 | "Epoch 5\n",
1655 | "Training Loss: 0.0088; Training Acc: 99.9483\n",
1656 | "Test Loss: 0.0609; Test Acc: 98.0970\n",
1657 | "------------\n",
1658 | "Epoch 6\n",
1659 | "Training Loss: 0.0056; Training Acc: 99.9917\n",
1660 | "Test Loss: 0.0584; Test Acc: 98.2272\n",
1661 | "------------\n",
1662 | "Epoch 7\n",
1663 | "Training Loss: 0.0041; Training Acc: 100.0000\n",
1664 | "Test Loss: 0.0587; Test Acc: 98.1771\n",
1665 | "------------\n",
1666 | "Epoch 8\n",
1667 | "Training Loss: 0.0033; Training Acc: 100.0000\n",
1668 | "Test Loss: 0.0583; Test Acc: 98.1671\n",
1669 | "------------\n",
1670 | "Epoch 9\n",
1671 | "Training Loss: 0.0028; Training Acc: 100.0000\n",
1672 | "Test Loss: 0.0587; Test Acc: 98.1370\n",
1673 | "------------\n",
1674 | "Epoch 10\n",
1675 | "Training Loss: 0.0024; Training Acc: 100.0000\n",
1676 | "Test Loss: 0.0581; Test Acc: 98.1671\n",
1677 | "------------\n",
1678 | "Epoch 11\n",
1679 | "Training Loss: 0.0021; Training Acc: 100.0000\n",
1680 | "Test Loss: 0.0580; Test Acc: 98.1771\n",
1681 | "------------\n",
1682 | "Epoch 12\n",
1683 | "Training Loss: 0.0019; Training Acc: 100.0000\n",
1684 | "Test Loss: 0.0577; Test Acc: 98.2272\n",
1685 | "------------\n",
1686 | "Epoch 13\n",
1687 | "Training Loss: 0.0017; Training Acc: 100.0000\n",
1688 | "Test Loss: 0.0577; Test Acc: 98.2372\n",
1689 | "------------\n",
1690 | "Epoch 14\n",
1691 | "Training Loss: 0.0016; Training Acc: 100.0000\n",
1692 | "Test Loss: 0.0586; Test Acc: 98.1771\n",
1693 | "------------\n",
1694 | "Epoch 15\n",
1695 | "Training Loss: 0.0014; Training Acc: 100.0000\n",
1696 | "Test Loss: 0.0582; Test Acc: 98.1971\n",
1697 | "------------\n",
1698 | "Epoch 16\n",
1699 | "Training Loss: 0.0013; Training Acc: 100.0000\n",
1700 | "Test Loss: 0.0582; Test Acc: 98.2372\n",
1701 | "------------\n",
1702 | "Epoch 17\n",
1703 | "Training Loss: 0.0012; Training Acc: 100.0000\n",
1704 | "Test Loss: 0.0582; Test Acc: 98.2071\n",
1705 | "------------\n",
1706 | "Epoch 18\n",
1707 | "Training Loss: 0.0012; Training Acc: 100.0000\n",
1708 | "Test Loss: 0.0584; Test Acc: 98.1971\n",
1709 | "------------\n",
1710 | "Epoch 19\n",
1711 | "Training Loss: 0.0011; Training Acc: 100.0000\n",
1712 | "Test Loss: 0.0582; Test Acc: 98.2272\n",
1713 | "------------\n",
1714 | "Epoch 20\n",
1715 | "Training Loss: 0.0010; Training Acc: 100.0000\n",
1716 | "Test Loss: 0.0583; Test Acc: 98.1871\n",
1717 | "------------\n",
1718 | "\n",
1719 | "Training complete in 39m 29s\n"
1720 | ]
1721 | }
1722 | ],
1723 | "source": [
1724 | "learning_rate = 5e-5\n",
1725 | "optimizer = optim.SGD(sparse_big.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)\n",
1726 | "train_model(sparse_big, optimizer, criterion, train_dataloader, test_dataloader)"
1727 | ]
1728 | },
1729 | {
1730 | "cell_type": "markdown",
1731 | "metadata": {},
1732 | "source": [
1733 | "In conclusion, we demonstrated the `SparseLinear` layer. From a user's perspective, it is very similar to PyTorch's `Linear` layer. We also showed its extra features, namely user-defined connectivity, dynamic sparsity, small-world connectivity, and activation sparsity. \n",
1734 | "\n",
1735 | "Our experiments showed that, even with a huge reduction in parameters, we were able to achieve performance similar to that of much more heavily parameterised dense layers. \n",
1736 | "\n",
1737 | "We hope this excites and enables people to build highly scalable sparse networks!"
1738 | ]
1739 | },
1740 | {
1741 | "cell_type": "markdown",
1742 | "metadata": {},
1743 | "source": [
1744 | ""
1745 | ]
1746 | }
1747 | ],
1748 | "metadata": {
1749 | "kernelspec": {
1750 | "display_name": "Environment (conda_pytorch_p36)",
1751 | "language": "python",
1752 | "name": "conda_pytorch_p36"
1753 | },
1754 | "language_info": {
1755 | "codemirror_mode": {
1756 | "name": "ipython",
1757 | "version": 3
1758 | },
1759 | "file_extension": ".py",
1760 | "mimetype": "text/x-python",
1761 | "name": "python",
1762 | "nbconvert_exporter": "python",
1763 | "pygments_lexer": "ipython3",
1764 | "version": "3.6.10"
1765 | }
1766 | },
1767 | "nbformat": 4,
1768 | "nbformat_minor": 2
1769 | }
1770 |
--------------------------------------------------------------------------------