├── Bipartite_small_world_network.pdf ├── LICENSE ├── README.md ├── setup.py ├── sparselinear ├── __init__.py ├── activationsparsity.py └── sparselinear.py └── tutorials └── SparseLinearDemo.ipynb /Bipartite_small_world_network.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyeon95y/SparseLinear/35b4a9cf843ecc9ce560b37615f369fdc6aad4cf/Bipartite_small_world_network.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Rain Neuromorphics Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SparseLinear 2 | 3 | SparseLinear is a PyTorch package that allows a user to create extremely wide and sparse linear layers efficiently. A sparsely connected network is a network where each node is connected to a fraction of available nodes. This differs from a fully connected network, where each node in one layer is connected to every node in the next layer. 4 | 5 | The provided layer along with the dynamic activation sparsity module is compatible with backpropagation. The sparse linear layer is initialized with sparsity, supports unstructured sparsity and allows dynamic growth and pruning. We achieve this by building a linear layer on top of [PyTorch Sparse](https://github.com/rusty1s/pytorch_sparse), which provides optimized sparse matrix operations with autograd support in PyTorch. 6 | 7 | ## Table of Contents 8 | 9 | - [More about SparseLinear](#intro) 10 | - [More about Dynamic Activation](#kwin) 11 | - [Installation](#install) 12 | - [Getting Started](#usage) 13 | 14 | ## More about SparseLinear 15 | The default arguments initialize a sparse linear layer with random connections that applies a linear transformation to the incoming data 16 | 17 | #### Parameters 18 | 19 | - **in_features** - size of each input sample 20 | - **out_features** - size of each output sample 21 | - **bias** - If set to ``False``, the layer will not learn an additive bias. Default: ``True`` 22 | - **sparsity** - sparsity of weight matrix. Default: `0.9` 23 | - **connectivity** - user-defined sparsity matrix. Default: `None` 24 | - **small_world** - boolean flag to generate small-world sparsity. 
Default: ``False``
 25 | - **dynamic** - boolean flag to dynamically change the network structure. Default: ``False``
 26 | - **deltaT** - frequency for growing and pruning update step. Default: `6000`
 27 | - **Tend** - stopping time for growing and pruning algorithm update step. Default: `150000`
 28 | - **alpha** - f-decay parameter for cosine updates. Default: `0.1`
 29 | - **max_size** - maximum number of entries allowed before chunking occurs for small-world network generation and dynamic connections. Default: `1e8`
 30 | 
 31 | #### Shape
 32 | 
 33 | - Input: `(N, *, H_{in})` where `*` means any number of additional dimensions and `H_{in} = in_features`
 34 | - Output: `(N, *, H_{out})` where all but the last dimension are the same shape as the input and `H_{out} = out_features`
 35 | 
 36 | #### Variables
 37 | 
 38 | - **~SparseLinear.weight** - the learnable weights of the module of shape `(out_features, in_features)`. The values are initialized from `U(-sqrt(k), sqrt(k))`, where `k = 1 / in_features`
 39 | - **~SparseLinear.bias** - the learnable bias of the module of shape `(out_features)`. If `bias` is ``True``, the values are initialized from `U(-sqrt(k), sqrt(k))`, where `k = 1 / in_features`
 40 | 
 41 | #### Examples:
 42 | 
 43 | ```python
 44 | >>> m = sl.SparseLinear(20, 30)
 45 | >>> input = torch.randn(128, 20)
 46 | >>> output = m(input)
 47 | >>> print(output.size())
 48 | torch.Size([128, 30])
 49 | ```
 50 | 
 51 | The following customizations can also be made using the appropriate arguments:
 52 | 
 53 | #### User-defined Sparsity
 54 | 
 55 | One can choose to add self-defined static sparsity. The `connectivity` argument accepts a `(2, nnz)` LongTensor that represents the rows and columns of the nonzero elements in the layer.
 56 | 
 57 | #### Small-world Sparsity
 58 | 
 59 | The default static sparsity is random. With this flag, one can instead use small-world sparsity (see [here](https://en.wikipedia.org/wiki/Small-world_network)). To enable it, set `small_world` to `True`. Connections are made distance-dependent to ensure small-world behavior.
 60 | 
 61 | #### Dynamic Growing and Pruning Algorithm
 62 | 
 63 | This feature lets the user grow and prune connections during training, starting from a sparse configuration. The implementation is based on the [Rigging the Lottery](https://arxiv.org/pdf/1911.11134.pdf) algorithm. Set `dynamic` to `True` to dynamically alter the layer connections while training.
 64 | 
 65 | ## Dynamic Activation Sparsity
 66 | 
 67 | In addition, we provide a Dynamic Activation Sparsity module to utilize principled, per-layer activation sparsity. The implementation is based on the [K-Winners strategy](https://arxiv.org/pdf/1903.11257.pdf).
 68 | 
 69 | #### Parameters
 70 | 
 71 | - **alpha** - constant used in updating the duty cycle. Default: `0.1`
 72 | - **beta** - boosting factor for neurons not activated in the previous duty cycle. Default: `1.5`
 73 | - **act_sparsity** - fraction of the input used in calculating K for the K-Winners strategy. Default: `0.65`
 74 | 
 75 | #### Shape
 76 | 
 77 | - Input: `(N, *)` where `*` means any number of additional dimensions
 78 | - Output: `(N, *)`, same shape as the input
 79 | 
 80 | #### Examples:
 81 | 
 82 | ```python
 83 | >>> x = asy.ActivationSparsity()
 84 | >>> input = torch.randn(3,10)
 85 | >>> output = x(input)
 86 | ```
 87 | 
 88 | ## Installation
 89 | 
 90 | - Follow the installation instructions and install the PyTorch Sparse package from [here](https://github.com/rusty1s/pytorch_sparse). An optional import check is sketched below.
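  As a minimal, optional sanity check (an illustrative sketch, not part of the official instructions), you can confirm that the dependency is importable before installing `sparselinear`:

  ```python
  # Raises ImportError if PyTorch Sparse is not installed correctly
  import torch_sparse
  ```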
91 | - Then run ```pip install sparselinear``` 92 | 93 | ## Getting Started 94 | 95 | We provide a Jupyter notebook in [this](https://github.com/rain-neuromorphics/SparseLinear/blob/master/tutorials/SparseLinearDemo.ipynb) repository that demonstrates the basic functionalities of the sparse linear layer. We also show steps to train various models using the additional features of this package. 96 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | with open("README.md", "r") as fh: 6 | long_description = fh.read() 7 | 8 | install_requires = ['numpy', 'torch'] 9 | setup(name='sparselinear', 10 | version='0.0.5', 11 | description='Pytorch extension library for creating sparse linear layers', 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | author='Rain Neuromorphics', 15 | author_email='ross@rain-neuromorphics.com', 16 | url='https://github.com/rain-neuromorphics/SparseLinear', 17 | keywords=['pytorch', 'sparse', 'linear'], 18 | license='MIT', 19 | install_requires=install_requires, 20 | packages=find_packages(), 21 | python_requires='>=3.6', 22 | ) 23 | -------------------------------------------------------------------------------- /sparselinear/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparselinear import SparseLinear 2 | from .activationsparsity import ActivationSparsity 3 | 4 | -------------------------------------------------------------------------------- /sparselinear/activationsparsity.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ActivationSparsity(nn.Module): 6 | """Applies activation sparsity to the last dimension of input using K-winners strategy 7 | 8 | Args: 9 | alpha (float): constant used in updating duty-cycle 10 | Default: 0.1 11 | beta (float): boosting factor for neurons not activated in the previous duty cycle 12 | Default: 1.5 13 | act_sparsity (float): fraction of the input used in calculating K for K-Winners strategy 14 | Default: 0.65 15 | 16 | Shape: 17 | - Input: :math:`(N, *)` where `*` means, any number of additional dimensions 18 | - Output: :math:`(N, *)`, same shape as the input 19 | 20 | Examples:: 21 | 22 | >>> x = asy.ActivationSparsity(10) 23 | >>> input = torch.randn(3,10) 24 | >>> output = x(input) 25 | """ 26 | def __init__(self, alpha=0.1, beta=1.5, act_sparsity=0.65): 27 | super(ActivationSparsity, self).__init__() 28 | self.alpha = alpha 29 | self.beta = beta 30 | self.act_sparsity = act_sparsity 31 | self.duty_cycle = None 32 | 33 | def updateDC(self, inputs, duty_cycle): 34 | duty_cycle = (1 - self.alpha) * duty_cycle + self.alpha * (inputs.gt(0).sum(dim=0,dtype=torch.float)) 35 | return duty_cycle 36 | 37 | def forward(self, inputs): 38 | in_features = inputs.shape[-1] 39 | out_shape=list(inputs.shape) 40 | inputs = inputs.reshape(inputs.shape[0],-1) 41 | 42 | device = inputs.device 43 | 44 | if self.duty_cycle is None: 45 | self.duty_cycle = torch.zeros(in_features, requires_grad=True).to(device) 46 | 47 | k = math.floor((1-self.act_sparsity) * in_features) 48 | with torch.no_grad(): 49 | 50 | target = k / inputs.shape[-1] 51 | boost_coefficient = torch.exp(self.beta * (target - self.duty_cycle)) 52 | boosted_input = inputs 
* boost_coefficient 53 | 54 | # Get top k values 55 | values, indices = boosted_input.topk( k, dim=-1, sorted=False) 56 | row_indices = torch.arange(inputs.shape[0]).repeat_interleave(k).view(-1,k) 57 | 58 | outputs = torch.zeros_like(inputs).to(device) 59 | outputs = outputs.index_put((row_indices, indices), inputs[row_indices, indices], accumulate=False) 60 | 61 | if self.training: 62 | with torch.no_grad(): 63 | self.duty_cycle = self.updateDC(outputs, self.duty_cycle) 64 | 65 | return outputs.view(out_shape) 66 | 67 | def extra_repr(self): 68 | return 'act_sparsity={}, alpha={}, beta={}, duty_cycle={}'.format( 69 | self.act_sparsity, self.alpha, self.beta, self.duty_cycle 70 | ) -------------------------------------------------------------------------------- /sparselinear/sparselinear.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def small_world_chunker(inputs, outputs, nnz): 10 | """Utility function for small world initialization as presented in the write up Bipartite_small_world_network""" 11 | pair_distance = inputs.view(-1, 1) - outputs 12 | arg = torch.abs(pair_distance) + 1.0 13 | 14 | # lambda search 15 | L, U = 1e-5, 5.0 16 | lamb = 1.0 # initial guess 17 | itr = 1 18 | error_threshold = 10.0 19 | max_itr = 1000 20 | P = arg ** (-lamb) 21 | P_sum = P.sum() 22 | error = abs(P_sum - nnz) 23 | 24 | while error > error_threshold: 25 | assert ( 26 | itr <= max_itr 27 | ), "No solution found; please try different network sizes and sparsity levels" 28 | if P_sum < nnz: 29 | U = lamb 30 | lamb = (lamb + L) / 2.0 31 | elif P_sum > nnz: 32 | L = lamb 33 | lamb = (lamb + U) / 2.0 34 | 35 | P = arg ** (-lamb) 36 | P_sum = P.sum() 37 | error = abs(P_sum - nnz) 38 | itr += 1 39 | return P 40 | 41 | 42 | class GrowConnections(torch.autograd.Function): 43 | """ Custom pytorch function to handle growing connections""" 44 | 45 | @staticmethod 46 | def forward(ctx, inputs, weights, k, indices, features, max_size): 47 | out_features, in_features = features 48 | output_shape = list(inputs.shape) 49 | output_shape[-1] = out_features 50 | if len(output_shape) == 1: 51 | inputs = inputs.view(1, -1) 52 | inputs = inputs.flatten(end_dim=-2) 53 | 54 | # output = torch_sparse.spmm(indices, weights, out_features, in_features, inputs.t()).t() 55 | target = torch.sparse.FloatTensor( 56 | indices, weights, torch.Size([out_features, in_features]), 57 | ).to_dense() 58 | output = torch.mm(target, inputs.t()).t() 59 | 60 | ctx.save_for_backward(inputs, weights, indices) 61 | ctx.in1 = k 62 | ctx.in2 = out_features 63 | ctx.in3 = in_features 64 | ctx.in4 = max_size 65 | 66 | return output 67 | 68 | @staticmethod 69 | def backward(ctx, grad_output): 70 | inputs, weights, indices = ctx.saved_tensors 71 | k = ctx.in1 72 | out_features = ctx.in2 73 | in_features = ctx.in3 74 | max_size = ctx.in4 75 | 76 | device = grad_output.device 77 | p_index = torch.LongTensor([1, 0]) 78 | new_indices = torch.zeros_like(indices).to(device=device) 79 | new_indices[p_index] = indices 80 | 81 | # grad_input = torch_sparse.spmm(new_indices, weights, in_features, out_features, grad_output.t()).t() 82 | target = torch.sparse.FloatTensor( 83 | new_indices, weights, torch.Size([in_features, out_features]), 84 | ).to_dense() 85 | grad_input = torch.mm(target, grad_outputs.t()).t() 86 | 87 | if in_features * out_features <= max_size: 88 | grad_weights = torch.matmul(inputs.t(), 
grad_output) 89 | grad_weights = torch.abs(grad_weights.t()) 90 | mask = torch.ones_like(grad_weights) 91 | mask[indices[0], indices[1]] = 0 92 | 93 | masked_weights = mask * grad_weights 94 | _, lm_indices = torch.topk(masked_weights.reshape(-1), k, sorted=False) 95 | row = lm_indices.floor_divide(in_features) 96 | col = lm_indices.fmod(in_features) 97 | else: 98 | tk = None 99 | m = max_size / in_features 100 | chunks = math.ceil(out_features / m) 101 | 102 | for item in range(chunks): 103 | if item != chunks - 1: 104 | sliced_input = inputs.t()[item * m : (item + 1) * m, :] 105 | grad_m = torch.matmul(sliced_input, grad_output).t() 106 | grad_m_abs = torch.abs(grad_m) 107 | topk_values, topk_indices = torch.topk( 108 | grad_m_abs.view(-1), k, sorted=False, 109 | ) 110 | else: 111 | grad_m = torch.matmul(inputs.t()[item * m :, :], grad_output).t() 112 | grad_m_abs = torch.abs(grad_m) 113 | topk_values, topk_indices = torch.topk( 114 | grad_m_abs.view(-1), k, sorted=False, 115 | ) 116 | 117 | row = ( 118 | topk_indices.floor_divide(in_features) 119 | + torch.ones_like(topk_indices) * item * m 120 | ) 121 | col = topk_indices.fmod(in_features) 122 | indices = torch.stack((row, col)) 123 | 124 | if tk is None: 125 | tk = torch.cat((topk_values, indices), dim=0) 126 | else: 127 | topk_values_prev = tk[0] 128 | concat_values = torch.cat( 129 | (topk_values_prev, topk_values), dim=1, 130 | ).view(-1) 131 | topk_values_2k, topk_indices_2k = torch.topk( 132 | concat_values, k, sorted=False, 133 | ) 134 | 135 | # Get the topk indices from the combination of two indices 136 | topk_prev = topk_indices_2k[topk_indices_2k < k] 137 | topk_values_indices = tk[:][topk_prev] 138 | 139 | topk_curr = topk_indices_2k[topk_indices_2k >= k] 140 | topk_curr = topk_curr % k 141 | 142 | curr_indices = indices[:][topk_curr] 143 | curr_values = topk_values[topk_curr] 144 | 145 | indices_values = torch.cat((curr_indices, curr_values), dim=0) 146 | tk = torch.cat((topk_values_indices, indices_values), dim=1) 147 | row = tk[1] 148 | col = tk[2] 149 | 150 | new_indices = torch.stack((row, col)) 151 | x = torch.cat((indices[:, :-k], new_indices), dim=1) 152 | 153 | if indices.shape[1] > x.shape[1]: 154 | diff = indices.shape[1] - x.shape[1] 155 | new_entries = torch.zeros((2, diff), dtype=torch.long).to(device=device) 156 | x = torch.cat((x, new_entries), dim=1) 157 | 158 | indices.copy_(x) 159 | 160 | return grad_input, None, None, None, None, None 161 | 162 | 163 | class SparseLinear(nn.Module): 164 | """Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 165 | 166 | Args: 167 | in_features: size of each input sample 168 | out_features: size of each output sample 169 | bias: If set to ``False``, the layer will not learn an additive bias. 
170 | Default: ``True`` 171 | sparsity: sparsity of weight matrix 172 | Default: 0.9 173 | connectivity: user defined sparsity matrix 174 | Default: None 175 | small_world: boolean flag to generate small world sparsity 176 | Default: ``False`` 177 | dynamic: boolean flag to dynamically change the network structure 178 | Default: ``False`` 179 | deltaT (int): frequency for growing and pruning update step 180 | Default: 6000 181 | Tend (int): stopping time for growing and pruning algorithm update step 182 | Default: 150000 183 | alpha (float): f-decay parameter for cosine updates 184 | Default: 0.1 185 | max_size (int): maximum number of entries allowed before chunking occurrs 186 | Default: 1e8 187 | 188 | Shape: 189 | - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of 190 | additional dimensions and :math:`H_{in} = \text{in\_features}` 191 | - Output: :math:`(N, *, H_{out})` where all but the last dimension 192 | are the same shape as the input and :math:`H_{out} = \text{out\_features}`. 193 | 194 | Attributes: 195 | weight: the learnable weights of the module of shape 196 | :math:`(\text{out\_features}, \text{in\_features})`. The values are 197 | initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where 198 | :math:`k = \frac{1}{\text{in\_features}}` 199 | bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. 200 | If :attr:`bias` is ``True``, the values are initialized from 201 | :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where 202 | :math:`k = \frac{1}{\text{in\_features}}` 203 | 204 | Examples: 205 | 206 | >>> m = nn.SparseLinear(20, 30) 207 | >>> input = torch.randn(128, 20) 208 | >>> output = m(input) 209 | >>> print(output.size()) 210 | torch.Size([128, 30]) 211 | """ 212 | 213 | def __init__( 214 | self, 215 | in_features, 216 | out_features, 217 | bias=True, 218 | sparsity=0.9, 219 | connectivity=None, 220 | small_world=False, 221 | dynamic=False, 222 | deltaT=6000, 223 | Tend=150000, 224 | alpha=0.1, 225 | max_size=1e8, 226 | ): 227 | assert in_features < 2 ** 31 and out_features < 2 ** 31 and sparsity < 1.0 228 | assert ( 229 | connectivity is None or not small_world 230 | ), "Cannot specify connectivity along with small world sparsity" 231 | if connectivity is not None: 232 | assert isinstance(connectivity, torch.LongTensor) or isinstance( 233 | connectivity, torch.cuda.LongTensor, 234 | ), "Connectivity must be a Long Tensor" 235 | assert ( 236 | connectivity.shape[0] == 2 and connectivity.shape[1] > 0 237 | ), "Input shape for connectivity should be (2,nnz)" 238 | assert ( 239 | connectivity.shape[1] <= in_features * out_features 240 | ), "Nnz can't be bigger than the weight matrix" 241 | super(SparseLinear, self).__init__() 242 | self.in_features = in_features 243 | self.out_features = out_features 244 | self.connectivity = connectivity 245 | self.small_world = small_world 246 | self.dynamic = dynamic 247 | self.max_size = max_size 248 | 249 | # Generate and coalesce indices : Faster to coalesce on GPU 250 | coalesce_device = ( 251 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 252 | ) 253 | 254 | if not small_world: 255 | if connectivity is None: 256 | self.sparsity = sparsity 257 | nnz = round((1.0 - sparsity) * in_features * out_features) 258 | if in_features * out_features <= 10 ** 8: 259 | indices = np.random.choice( 260 | in_features * out_features, nnz, replace=False, 261 | ) 262 | indices = torch.as_tensor(indices, device=coalesce_device) 263 | row_ind = indices.floor_divide(in_features) 
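                    # Each sampled value above is a flat index into the dense
                    # (out_features, in_features) weight matrix: idx // in_features
                    # (previous line) recovers the row and idx % in_features (next line)
                    # recovers the column of a nonzero connection.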
264 | col_ind = indices.fmod(in_features) 265 | else: 266 | warnings.warn( 267 | "Matrix too large to sample non-zero indices without replacement, sparsity will be approximate", 268 | RuntimeWarning, 269 | ) 270 | row_ind = torch.randint( 271 | 0, out_features, (nnz,), device=coalesce_device, 272 | ) 273 | col_ind = torch.randint( 274 | 0, in_features, (nnz,), device=coalesce_device, 275 | ) 276 | indices = torch.stack((row_ind, col_ind)) 277 | else: 278 | # User defined sparsity 279 | nnz = connectivity.shape[1] 280 | self.sparsity = 1.0 - nnz / (out_features * in_features) 281 | connectivity = connectivity.to(device=coalesce_device) 282 | indices = connectivity 283 | else: 284 | # Generate small world sparsity 285 | self.sparsity = sparsity 286 | nnz = round((1.0 - sparsity) * in_features * out_features) 287 | assert nnz > min( 288 | in_features, out_features, 289 | ), "The matrix is too sparse for small-world algorithm; please decrease sparsity" 290 | offset = abs(out_features - in_features) / 2.0 291 | 292 | # Node labels 293 | inputs = torch.arange( 294 | 1 + offset * (out_features > in_features), 295 | in_features + 1 + offset * (out_features > in_features), 296 | device=coalesce_device, 297 | ) 298 | outputs = torch.arange( 299 | 1 + offset * (out_features < in_features), 300 | out_features + 1 + offset * (out_features < in_features), 301 | device=coalesce_device, 302 | ) 303 | 304 | # Creating chunks for small world algorithm 305 | total_data = in_features * out_features # Total params 306 | chunks = math.ceil(total_data / self.max_size) 307 | split_div = max(in_features, out_features) // chunks # Full chunks 308 | split_mod = max(in_features, out_features) % chunks # Remaining chunk 309 | idx = ( 310 | torch.repeat_interleave(torch.Tensor([split_div]), chunks) 311 | .int() 312 | .to(device=coalesce_device) 313 | ) 314 | idx[:split_mod] += 1 315 | idx = torch.cumsum(idx, dim=0) 316 | idx = torch.cat([torch.LongTensor([0]).to(device=coalesce_device), idx]) 317 | 318 | count = 0 319 | 320 | rows = torch.empty(0).long().to(device=coalesce_device) 321 | cols = torch.empty(0).long().to(device=coalesce_device) 322 | 323 | for i in range(chunks): 324 | inputs_ = ( 325 | inputs[idx[i] : idx[i + 1]] 326 | if out_features <= in_features 327 | else inputs 328 | ) 329 | outputs_ = ( 330 | outputs[idx[i] : idx[i + 1]] 331 | if out_features > in_features 332 | else outputs 333 | ) 334 | 335 | y = small_world_chunker(inputs_, outputs_, round(nnz / chunks)) 336 | ref = torch.rand_like(y) 337 | 338 | # Refer to Eq.7 from Bipartite_small_world_network write-up 339 | mask = torch.empty(y.shape, dtype=bool).to(device=coalesce_device) 340 | mask[y < ref] = False 341 | mask[y >= ref] = True 342 | 343 | rows_, cols_ = mask.to_sparse().indices() 344 | 345 | rows = torch.cat([rows, rows_ + idx[i]]) 346 | cols = torch.cat([cols, cols_]) 347 | 348 | indices = torch.stack((cols, rows)) 349 | nnz = indices.shape[1] 350 | 351 | values = torch.empty(nnz, device=coalesce_device) 352 | # indices, values = torch_sparse.coalesce(indices, values, out_features, in_features) 353 | 354 | self.register_buffer("indices", indices.cpu()) 355 | self.weights = nn.Parameter(values.cpu()) 356 | 357 | if bias: 358 | self.bias = nn.Parameter(torch.Tensor(out_features)) 359 | else: 360 | self.register_parameter("bias", None) 361 | 362 | if self.dynamic: 363 | self.deltaT = deltaT 364 | self.Tend = Tend 365 | self.alpha = alpha 366 | self.itr_count = 0 367 | 368 | self.reset_parameters() 369 | 370 | def reset_parameters(self): 371 
| bound = 1 / self.in_features ** 0.5 372 | nn.init.uniform_(self.weights, -bound, bound) 373 | if self.bias is not None: 374 | nn.init.uniform_(self.bias, -bound, bound) 375 | 376 | @property 377 | def weight(self): 378 | """ returns a torch.sparse.FloatTensor view of the underlying weight matrix 379 | This is only for inspection purposes and should not be modified or used in any autograd operations 380 | """ 381 | weight = torch.sparse.FloatTensor( 382 | self.indices, self.weights, (self.out_features, self.in_features), 383 | ) 384 | return weight.coalesce().detach() 385 | 386 | def forward(self, inputs): 387 | if self.training and self.dynamic: 388 | self.itr_count += 1 389 | output_shape = list(inputs.shape) 390 | output_shape[-1] = self.out_features 391 | 392 | # Handle dynamic sparsity 393 | if ( 394 | self.training 395 | and self.dynamic 396 | and self.itr_count < self.Tend 397 | and self.itr_count % self.deltaT == 0 398 | ): 399 | # Drop criterion 400 | f_decay = ( 401 | self.alpha * (1 + math.cos(self.itr_count * math.pi / self.Tend)) / 2 402 | ) 403 | k = int(f_decay * (1 - self.sparsity) * self.weights.view(-1, 1).shape[0]) 404 | n = self.weights.shape[0] 405 | 406 | neg_weights = -1 * torch.abs(self.weights) 407 | _, lm_indices = torch.topk(neg_weights, n - k, largest=False, sorted=False) 408 | 409 | self.indices = torch.index_select(self.indices, 1, lm_indices) 410 | self.weights = nn.Parameter(torch.index_select(self.weights, 0, lm_indices)) 411 | 412 | device = inputs.device 413 | # Growth criterion 414 | new_weights = torch.zeros(k).to(device=device) 415 | self.weights = nn.Parameter(torch.cat((self.weights, new_weights), dim=0)) 416 | 417 | new_indices = torch.zeros((2, k), dtype=torch.long).to(device=device) 418 | self.indices = torch.cat((self.indices, new_indices), dim=1) 419 | output = GrowConnections.apply( 420 | inputs, 421 | self.weights, 422 | k, 423 | self.indices, 424 | (self.out_features, self.in_features), 425 | self.max_size, 426 | ) 427 | else: 428 | if len(output_shape) == 1: 429 | inputs = inputs.view(1, -1) 430 | inputs = inputs.flatten(end_dim=-2) 431 | 432 | # output = torch_sparse.spmm(self.indices, self.weights, self.out_features, self.in_features, inputs.t()).t() 433 | target = torch.sparse.FloatTensor( 434 | self.indices, 435 | self.weights, 436 | torch.Size([self.out_features, self.in_features]), 437 | ).to_dense() 438 | output = torch.mm(target, inputs.t()).t() 439 | 440 | if self.bias is not None: 441 | output += self.bias 442 | 443 | return output.view(output_shape) 444 | 445 | def extra_repr(self): 446 | return "in_features={}, out_features={}, bias={}, sparsity={}, connectivity={}, small_world={}".format( 447 | self.in_features, 448 | self.out_features, 449 | self.bias is not None, 450 | self.sparsity, 451 | self.connectivity, 452 | self.small_world, 453 | ) 454 | 455 | -------------------------------------------------------------------------------- /tutorials/SparseLinearDemo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SparseLinear Demonstration using MNIST \n", 8 | "\n", 9 | "##### Training models consisting of sparsely connected linear layers\n", 10 | "\n", 11 | "## Table of Contents\n", 12 | "\n", 13 | "- [Introduction](#intro)\n", 14 | "- [Setup](#setup)\n", 15 | "- [Time and memory efficiency](#efficiency)\n", 16 | "- [Training with random inputs](#random)\n", 17 | "- [Training on 
MNIST](#mnist)\n", 18 | "- [Training sparse models with user-defined connections](#user)\n", 19 | "- [Training sparse models with dynamic connections](#dynamic)\n", 20 | "- [Training sparse models with small-world connections](#sw)\n", 21 | "- [Utilizing the activation sparsity feature](#activation)\n", 22 | "- [Training very wide and sparse models](#big)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## Introduction \n", 30 | "\n", 31 | "SparseLinear is a PyTorch package that allows a user to create extremely wide and sparse linear layers efficiently. A sparsely connected network is a network where each node is connected to some fraction of available nodes.\n", 32 | "\n", 33 | "The provided package is built on top of [PyTorch Sparse](https://github.com/rusty1s/pytorch_sparse), which provides optimized sparse matrix operations with autograd support in PyTorch.\n", 34 | "\n", 35 | "In this tutorial, we demonstrate its basic usage along with steps to train using the package features. Note that it is advisable to run these on the GPU instead of the CPU owing to much faster training times on the former." 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Setup \n", 43 | "\n", 44 | "We import PyTorch, which contains the (dense)linear module, and load the device." 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 1, 50 | "metadata": { 51 | "tags": [] 52 | }, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "cuda:0\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "import torch\n", 64 | "import torch.nn as nn\n", 65 | "\n", 66 | "import warnings\n", 67 | "warnings.filterwarnings('ignore')\n", 68 | "\n", 69 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 70 | "print(device)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "We create the linear layer and demonstrate some of its built-in attributes. " 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 2, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "data": { 87 | "text/plain": [ 88 | "('in_features=10, out_features=20, bias=True',\n", 89 | " torch.Size([20, 10]),\n", 90 | " torch.Size([20]))" 91 | ] 92 | }, 93 | "execution_count": 2, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "fc1 = nn.Linear(10, 20)\n", 100 | "fc1.extra_repr(), fc1.weight.shape, fc1.bias.shape" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "In a similar manner, we now import the `sparselinear` package. As can be observed, the custom layer's weights and biases can be accessed in the same manner as before. The new layer also returns some extra attributes about which we will discuss later. 
" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "metadata": { 114 | "tags": [] 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "import sparselinear as sl" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 4, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "('in_features=10, out_features=20, bias=True, sparsity=0.9, connectivity=None, small_world=False',\n", 130 | " torch.Size([20, 10]),\n", 131 | " torch.Size([20]))" 132 | ] 133 | }, 134 | "execution_count": 4, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "sl1 = sl.SparseLinear(10,20)\n", 141 | "sl1.extra_repr(), sl1.weight.shape, sl1.bias.shape" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "We now take a look at the two weight matrices." 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "Parameter containing:\n", 160 | "tensor([[ 2.2700e-01, 1.9950e-04, 1.5138e-01, -2.3028e-01, 3.0498e-01,\n", 161 | " -1.1592e-01, 8.9561e-03, -1.9644e-01, -1.8278e-01, -1.2167e-01],\n", 162 | " [ 1.8054e-01, 4.9672e-03, -2.7930e-01, 1.7971e-02, -2.5313e-01,\n", 163 | " -1.6389e-01, 2.8138e-02, 2.3216e-01, 8.5033e-02, 2.6193e-01],\n", 164 | " [-1.3997e-01, -2.0780e-01, -1.3777e-01, 9.5758e-02, -1.1465e-01,\n", 165 | " -3.0299e-01, 2.3639e-01, 2.3740e-01, 3.4879e-02, -2.8988e-01],\n", 166 | " [ 3.4024e-02, -3.4284e-02, -3.1449e-01, -7.3634e-02, 1.0884e-01,\n", 167 | " 3.4649e-02, 2.2210e-01, -2.2692e-01, 1.7318e-01, 1.0567e-01],\n", 168 | " [ 1.8497e-01, 8.6446e-02, -1.3994e-02, -1.8335e-01, 7.1342e-02,\n", 169 | " -5.4367e-02, -1.2261e-01, -1.2711e-01, 1.2817e-01, 3.0136e-01],\n", 170 | " [ 2.7756e-01, -2.6505e-01, 2.1932e-02, 2.2353e-01, -2.0779e-01,\n", 171 | " 2.9041e-01, -2.9108e-01, 2.5556e-02, 2.6355e-02, 9.2430e-02],\n", 172 | " [-6.9308e-02, 1.4349e-01, 2.1799e-01, 9.2573e-02, -1.1946e-01,\n", 173 | " -8.8171e-02, 1.8941e-01, 2.6366e-01, -2.2858e-01, -7.3599e-02],\n", 174 | " [ 2.7514e-01, 5.0453e-02, -2.7328e-01, 1.8520e-01, -1.9852e-01,\n", 175 | " -1.9735e-01, 2.4275e-01, -3.9498e-02, -9.1360e-02, -1.1861e-01],\n", 176 | " [-1.1171e-01, -9.5010e-02, 8.9707e-02, 4.6313e-02, 1.4619e-01,\n", 177 | " 8.1823e-02, 1.7853e-01, -8.7963e-02, -1.1446e-01, 2.6627e-01],\n", 178 | " [ 2.3615e-01, -3.6197e-02, -8.4897e-03, -1.6606e-01, 2.2675e-01,\n", 179 | " 2.4547e-01, -9.0289e-03, 2.4520e-01, -5.2978e-02, 2.1697e-01],\n", 180 | " [-3.0513e-01, 1.3863e-01, 7.8270e-02, 1.4909e-01, 1.9973e-01,\n", 181 | " -3.0507e-01, 2.5586e-01, -1.7424e-01, -1.5309e-01, 5.5867e-04],\n", 182 | " [-7.0133e-02, -1.9032e-02, -1.2506e-01, -1.7848e-01, -2.9393e-02,\n", 183 | " -1.2914e-01, 2.6337e-01, 2.6768e-02, 1.2672e-01, 2.7603e-01],\n", 184 | " [-2.8693e-01, -2.3720e-01, 2.5797e-01, 1.4165e-01, 1.3373e-01,\n", 185 | " 1.8109e-01, -1.6957e-01, -1.7997e-02, 8.7014e-02, -2.4652e-01],\n", 186 | " [-2.5947e-01, -2.8787e-01, -2.5539e-01, 1.6618e-01, 1.8108e-01,\n", 187 | " 2.5726e-01, -8.1458e-02, -5.7938e-02, -1.2392e-01, 2.4684e-01],\n", 188 | " [-1.4422e-01, 2.5416e-01, -3.1530e-01, -3.4516e-02, 2.4788e-03,\n", 189 | " -4.7461e-02, -4.3661e-02, 6.9811e-02, -1.5750e-01, -2.6075e-02],\n", 190 | " [-2.1888e-02, 7.3086e-02, -1.8007e-01, 1.3642e-01, 3.5587e-02,\n", 191 | " -9.6858e-02, -2.9306e-01, 3.6169e-02, 1.5752e-01, 
-1.6869e-01],\n", 192 | " [-1.4439e-01, 1.7146e-01, -2.5298e-01, -8.1204e-02, -2.2055e-01,\n", 193 | " 7.1367e-02, -1.9483e-01, 6.2341e-02, 2.7298e-01, -1.9703e-01],\n", 194 | " [ 2.4394e-01, -1.1722e-01, 1.9469e-01, -1.3386e-01, 2.3983e-02,\n", 195 | " -1.6841e-01, -2.1754e-01, -3.1110e-01, -3.0143e-01, -2.0946e-01],\n", 196 | " [-5.3311e-02, -2.9350e-01, -2.7600e-01, -2.5163e-02, 1.9405e-01,\n", 197 | " 2.2620e-01, -2.3203e-01, 2.2993e-01, 2.7201e-02, -2.3870e-01],\n", 198 | " [-1.0589e-01, 2.1078e-01, 1.7010e-01, 1.1657e-01, 9.0184e-02,\n", 199 | " 2.8292e-01, -9.2838e-02, 6.3635e-02, 6.2464e-02, -3.0613e-01]],\n", 200 | " requires_grad=True)" 201 | ] 202 | }, 203 | "execution_count": 5, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | } 207 | ], 208 | "source": [ 209 | "fc1.weight" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 6, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "tensor(indices=tensor([[ 0, 0, 1, 2, 3, 3, 4, 5, 5, 6, 7, 11, 12, 12,\n", 221 | " 13, 14, 14, 15, 15, 19],\n", 222 | " [ 3, 8, 1, 4, 0, 7, 1, 5, 8, 7, 0, 2, 3, 4,\n", 223 | " 7, 3, 7, 4, 7, 9]]),\n", 224 | " values=tensor([-0.0633, -0.0889, -0.1319, -0.0976, 0.0678, -0.0097,\n", 225 | " -0.1309, -0.0588, -0.2626, -0.1929, 0.2060, -0.2528,\n", 226 | " -0.0253, 0.1744, 0.2165, 0.1699, -0.0991, -0.2581,\n", 227 | " 0.3090, 0.0959]),\n", 228 | " size=(20, 10), nnz=20, layout=torch.sparse_coo)" 229 | ] 230 | }, 231 | "execution_count": 6, 232 | "metadata": {}, 233 | "output_type": "execute_result" 234 | } 235 | ], 236 | "source": [ 237 | "sl1.weight" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "As can be seen, the first weight matrix has 200 non-zero entries while the second one has 20 non-zero entries as specified by `nnz`. The indices tensor keeps track of all the indices where a non-zero entry is present with the corresponding entry in the values tensor providing the entry at that index." 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "## Time and Memory Efficiency \n", 252 | "\n", 253 | "The `SparseLinear` class is ideal for very wide and sparse layers. Since we utilize sparse tensors and only store non-zero values (and their corresponding indices), `SparseLinear` is much more efficient in terms of memory consumption than simply applying a mask over a standard dense weight matrix -- as is often done by researchers and practioners. Since we only perform computations on non-zero values, we see speedups in computation time for large matrices as well. As hardware becomes more well-suited for sparse computations, these speedups will likely increase.\n", 254 | "\n", 255 | "To show this, we create a (20000, 20000) `SparseLinear` layer with 99% sparsity and compare its runtime to that of a standard `Linear` layer. Later in this notebook, we create massive layers that would not be possible with standard `Linear` layers due to memory inefficiencies.\n", 256 | "\n", 257 | "We initialize two layers and define the input." 
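,
        "\n\nTo make the memory claim concrete, a rough back-of-the-envelope estimate (assuming float32 values and int64 index storage): the dense 20000 x 20000 weight matrix holds 4e8 entries, roughly 1.6 GB, while the 99%-sparse layer stores only about 4e6 nonzero values (~16 MB) plus a (2, nnz) index tensor (~64 MB), i.e. on the order of 80 MB in total."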
258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 7, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "sl2 = sl.SparseLinear(20000, 20000, sparsity=.99).cuda()\n", 267 | "\n", 268 | "# Reduce weight dimensions if memory errors are raised\n", 269 | "fc2 = nn.Linear(20000, 20000).cuda()\n", 270 | "\n", 271 | "x = torch.rand(20000, device=device)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "We time the inference steps." 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 8, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "583 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", 291 | "1.85 ms ± 93.1 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" 292 | ] 293 | } 294 | ], 295 | "source": [ 296 | "%timeit y = sl2(x)\n", 297 | "%timeit y = fc2(x)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "We time the training step for SparseLinear." 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 9, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "name": "stdout", 314 | "output_type": "stream", 315 | "text": [ 316 | "789 µs ± 666 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 317 | ] 318 | } 319 | ], 320 | "source": [ 321 | "%%timeit\n", 322 | "y = sl2(x)\n", 323 | "y.sum().backward()" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "We time the training step for Linear." 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 10, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "9.29 ms ± 86 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" 343 | ] 344 | } 345 | ], 346 | "source": [ 347 | "%%timeit\n", 348 | "y = fc2(x)\n", 349 | "y.sum().backward()" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "We delete layers to save GPU memory when running this notebook." 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 11, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "del sl2, fc2" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": {}, 371 | "source": [ 372 | "## Training with random inputs \n", 373 | "\n", 374 | "Next, we demonstrate how to train a two-layer network using the `SparseLinear` module provided in the package. The code has been built upon the PyTorch [tutorial](https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_nn.html) to highlight the parallels between the `nn.Linear` and `sl.SparseLinear` modules." 
375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 12, 380 | "metadata": { 381 | "tags": [] 382 | }, 383 | "outputs": [ 384 | { 385 | "name": "stdout", 386 | "output_type": "stream", 387 | "text": [ 388 | "Dense model loss: 4.417; Sparse model loss: 5.111\n", 389 | "Dense model loss: 0.075; Sparse model loss: 0.031\n", 390 | "Dense model loss: 0.002; Sparse model loss: 0.000\n", 391 | "Dense model loss: 0.000; Sparse model loss: 0.000\n", 392 | "Dense model loss: 0.000; Sparse model loss: 0.000\n" 393 | ] 394 | } 395 | ], 396 | "source": [ 397 | "# N is batch size; D_in is input dimension;\n", 398 | "# H is hidden dimension; D_out is output dimension.\n", 399 | "N, D_in, H, D_out = 64, 200, 1000, 10\n", 400 | "\n", 401 | "# Create random Tensors to hold inputs and outputs.\n", 402 | "x = torch.randn(N, D_in)\n", 403 | "y = torch.randn(N, D_out)\n", 404 | "\n", 405 | "# Use the nn package to define our dense model as a sequence of layers. \n", 406 | "model_dense = torch.nn.Sequential(\n", 407 | " torch.nn.Linear(D_in, H),\n", 408 | " torch.nn.ReLU(),\n", 409 | " torch.nn.Linear(H, D_out),\n", 410 | ")\n", 411 | "\n", 412 | "# Use the sl package to define our sparse model as a sequence of layers.\n", 413 | "# Note that the default sparsity is 90%.\n", 414 | "model_sparse = torch.nn.Sequential(\n", 415 | " sl.SparseLinear(D_in, H),\n", 416 | " torch.nn.ReLU(),\n", 417 | " sl.SparseLinear(H, D_out),\n", 418 | ")\n", 419 | "\n", 420 | "# We will use Mean Squared Error (MSE) as our loss function.\n", 421 | "loss_fn = torch.nn.MSELoss(reduction='sum')\n", 422 | "\n", 423 | "# We define our learning rates.\n", 424 | "# Note that sparse and dense models may require different learning rates.\n", 425 | "learning_rate_dense = 1e-4\n", 426 | "learning_rate_sparse = 1e-3\n", 427 | "\n", 428 | "for t in range(500):\n", 429 | " # Forward pass\n", 430 | " y_pred_dense = model_dense(x)\n", 431 | " y_pred_sparse = model_sparse(x)\n", 432 | "\n", 433 | " # Compute and print loss\n", 434 | " loss_dense = loss_fn(y_pred_dense, y)\n", 435 | " loss_sparse = loss_fn(y_pred_sparse, y)\n", 436 | " if t % 100 == 99:\n", 437 | " print(\"Dense model loss: %.3f; Sparse model loss: %.3f\" %(loss_dense.item(), loss_sparse.item()))\n", 438 | "\n", 439 | " # Zero the gradients before running the backward pass.\n", 440 | " model_dense.zero_grad()\n", 441 | " model_sparse.zero_grad()\n", 442 | "\n", 443 | " # Backward pass\n", 444 | " loss_dense.backward()\n", 445 | " loss_sparse.backward()\n", 446 | "\n", 447 | " # Update the weights using gradient descent\n", 448 | " with torch.no_grad():\n", 449 | " for param in model_dense.parameters():\n", 450 | " param -= learning_rate_dense * param.grad\n", 451 | " \n", 452 | " for param in model_sparse.parameters():\n", 453 | " param -= learning_rate_sparse * param.grad" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 458 | "metadata": {}, 459 | "source": [ 460 | "As we can see, the loss value in both models decreases. Let's now build models using this module and train on the MNIST digit classification task." 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": {}, 466 | "source": [ 467 | "## Training on MNIST \n", 468 | "\n", 469 | "We start by doing the initial imports, generating transforms, creating the dataset along with the dataloader, defining the loss function and some other helper functions." 
470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 13, 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "import time\n", 479 | "import torchvision\n", 480 | "import torchvision.transforms as transforms\n", 481 | "import torch.optim as optim\n", 482 | "import torch.nn.functional as F\n", 483 | "from torch.utils.data import sampler" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 14, 489 | "metadata": { 490 | "tags": [] 491 | }, 492 | "outputs": [], 493 | "source": [ 494 | "tf = transforms.Compose([transforms.ToTensor(),\n", 495 | " transforms.Normalize((0.1307,), (0.3081,))])\n", 496 | "\n", 497 | "batch_size = 64\n", 498 | "trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=tf)\n", 499 | "testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=tf)\n", 500 | "train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)\n", 501 | "test_dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=True)" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 15, 507 | "metadata": { 508 | "tags": [] 509 | }, 510 | "outputs": [], 511 | "source": [ 512 | "def train_model(model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs=20):\n", 513 | " since = time.time()\n", 514 | " for epoch in range(num_epochs):\n", 515 | " cum_loss, total, correct = 0, 0, 0\n", 516 | " model.train()\n", 517 | " \n", 518 | " # Training epoch\n", 519 | " for i, (images, labels) in enumerate(train_dataloader, 0):\n", 520 | " images = images.to(device)\n", 521 | " labels = labels.to(device)\n", 522 | "\n", 523 | " # Forward pass & statistics\n", 524 | " out = model(images)\n", 525 | " predicted = out.argmax(dim=1)\n", 526 | " correct += (predicted == labels).sum().item()\n", 527 | " total += labels.size(0)\n", 528 | " loss = criterion(out, labels)\n", 529 | " cum_loss += loss.item()\n", 530 | "\n", 531 | " # Backwards pass & update\n", 532 | " loss.backward()\n", 533 | " optimizer.step()\n", 534 | " optimizer.zero_grad()\n", 535 | " \n", 536 | " epoch_loss = images.shape[0] * cum_loss / total\n", 537 | " epoch_acc = 100 * (correct / total)\n", 538 | " print('Epoch %d' % (epoch + 1))\n", 539 | " print('Training Loss: {:.4f}; Training Acc: {:.4f}'.format(epoch_loss, epoch_acc))\n", 540 | " \n", 541 | " cum_loss, total, correct = 0, 0, 0\n", 542 | " model.eval()\n", 543 | " \n", 544 | " # Test epoch\n", 545 | " for i, (images, labels) in enumerate(test_dataloader, 0):\n", 546 | " images = images.to(device)\n", 547 | " labels = labels.to(device)\n", 548 | "\n", 549 | " # Forward pass & statistics\n", 550 | " out = model(images)\n", 551 | " predicted = out.argmax(dim=1)\n", 552 | " correct += (predicted == labels).sum().item()\n", 553 | " total += labels.size(0)\n", 554 | " loss = criterion(out, labels)\n", 555 | " cum_loss += loss.item()\n", 556 | " \n", 557 | " epoch_loss = images.shape[0] * cum_loss / total\n", 558 | " epoch_acc = 100 * (correct / total)\n", 559 | " \n", 560 | " print('Test Loss: {:.4f}; Test Acc: {:.4f}'.format(epoch_loss, epoch_acc))\n", 561 | " print('------------')\n", 562 | " \n", 563 | " time_elapsed = time.time() - since\n", 564 | " print('\\nTraining complete in {:.0f}m {:.0f}s'.format(\n", 565 | " time_elapsed // 60, time_elapsed % 60))" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 16, 571 | 
"metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "def flatten(x):\n", 575 | "\tN = x.shape[0]\n", 576 | "\treturn x.view(N, -1)\n", 577 | "\n", 578 | "class Flatten(nn.Module):\n", 579 | " def forward(self, x):\n", 580 | " return flatten(x)" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 17, 586 | "metadata": {}, 587 | "outputs": [], 588 | "source": [ 589 | "criterion = nn.CrossEntropyLoss()" 590 | ] 591 | }, 592 | { 593 | "cell_type": "markdown", 594 | "metadata": {}, 595 | "source": [ 596 | "### Training a dense model\n", 597 | "\n", 598 | "We start off with training a two-layer fully connected network. " 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": 18, 604 | "metadata": {}, 605 | "outputs": [], 606 | "source": [ 607 | "model = nn.Sequential(\n", 608 | "\tFlatten(),\n", 609 | "\tnn.Linear(784, 2000),\n", 610 | " nn.LayerNorm(2000),\n", 611 | "\tnn.ReLU(),\n", 612 | " nn.Linear(2000, 10),\n", 613 | ")\n", 614 | "model = model.to(device)" 615 | ] 616 | }, 617 | { 618 | "cell_type": "markdown", 619 | "metadata": {}, 620 | "source": [ 621 | "After we set everything up, we declare the optimizer and start training the dense model. We use SGD as the optimizer since we found its behavior to be slightly better than that of others. However, one is free to choose any optimizer as long as there exists an implementation for it to handle sparse tensors. " 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": 19, 627 | "metadata": { 628 | "tags": [] 629 | }, 630 | "outputs": [ 631 | { 632 | "name": "stdout", 633 | "output_type": "stream", 634 | "text": [ 635 | "Epoch 1\n", 636 | "Training Loss: 0.2062; Training Acc: 93.5916\n", 637 | "Test Loss: 0.1250; Test Acc: 95.9736\n", 638 | "------------\n", 639 | "Epoch 2\n", 640 | "Training Loss: 0.0778; Training Acc: 97.6387\n", 641 | "Test Loss: 0.0949; Test Acc: 96.9351\n", 642 | "------------\n", 643 | "Epoch 3\n", 644 | "Training Loss: 0.0491; Training Acc: 98.4775\n", 645 | "Test Loss: 0.0785; Test Acc: 97.5160\n", 646 | "------------\n", 647 | "Epoch 4\n", 648 | "Training Loss: 0.0311; Training Acc: 99.0678\n", 649 | "Test Loss: 0.0664; Test Acc: 97.9667\n", 650 | "------------\n", 651 | "Epoch 5\n", 652 | "Training Loss: 0.0201; Training Acc: 99.4614\n", 653 | "Test Loss: 0.0704; Test Acc: 97.7564\n", 654 | "------------\n", 655 | "Epoch 6\n", 656 | "Training Loss: 0.0131; Training Acc: 99.7065\n", 657 | "Test Loss: 0.0549; Test Acc: 98.3173\n", 658 | "------------\n", 659 | "Epoch 7\n", 660 | "Training Loss: 0.0081; Training Acc: 99.8749\n", 661 | "Test Loss: 0.0592; Test Acc: 98.2372\n", 662 | "------------\n", 663 | "Epoch 8\n", 664 | "Training Loss: 0.0057; Training Acc: 99.9350\n", 665 | "Test Loss: 0.0536; Test Acc: 98.3674\n", 666 | "------------\n", 667 | "Epoch 9\n", 668 | "Training Loss: 0.0035; Training Acc: 99.9850\n", 669 | "Test Loss: 0.0551; Test Acc: 98.3173\n", 670 | "------------\n", 671 | "Epoch 10\n", 672 | "Training Loss: 0.0025; Training Acc: 99.9983\n", 673 | "Test Loss: 0.0518; Test Acc: 98.4976\n", 674 | "------------\n", 675 | "Epoch 11\n", 676 | "Training Loss: 0.0020; Training Acc: 100.0000\n", 677 | "Test Loss: 0.0526; Test Acc: 98.3974\n", 678 | "------------\n", 679 | "Epoch 12\n", 680 | "Training Loss: 0.0017; Training Acc: 99.9983\n", 681 | "Test Loss: 0.0530; Test Acc: 98.4976\n", 682 | "------------\n", 683 | "Epoch 13\n", 684 | "Training Loss: 0.0015; Training Acc: 100.0000\n", 685 | "Test Loss: 0.0531; Test Acc: 
98.5377\n", 686 | "------------\n", 687 | "Epoch 14\n", 688 | "Training Loss: 0.0013; Training Acc: 100.0000\n", 689 | "Test Loss: 0.0532; Test Acc: 98.5076\n", 690 | "------------\n", 691 | "Epoch 15\n", 692 | "Training Loss: 0.0012; Training Acc: 100.0000\n", 693 | "Test Loss: 0.0544; Test Acc: 98.4675\n", 694 | "------------\n", 695 | "Epoch 16\n", 696 | "Training Loss: 0.0011; Training Acc: 100.0000\n", 697 | "Test Loss: 0.0539; Test Acc: 98.4776\n", 698 | "------------\n", 699 | "Epoch 17\n", 700 | "Training Loss: 0.0010; Training Acc: 100.0000\n", 701 | "Test Loss: 0.0537; Test Acc: 98.5276\n", 702 | "------------\n", 703 | "Epoch 18\n", 704 | "Training Loss: 0.0009; Training Acc: 100.0000\n", 705 | "Test Loss: 0.0546; Test Acc: 98.4976\n", 706 | "------------\n", 707 | "Epoch 19\n", 708 | "Training Loss: 0.0009; Training Acc: 100.0000\n", 709 | "Test Loss: 0.0545; Test Acc: 98.5076\n", 710 | "------------\n", 711 | "Epoch 20\n", 712 | "Training Loss: 0.0008; Training Acc: 100.0000\n", 713 | "Test Loss: 0.0544; Test Acc: 98.4776\n", 714 | "------------\n", 715 | "\n", 716 | "Training complete in 4m 56s\n" 717 | ] 718 | } 719 | ], 720 | "source": [ 721 | "learning_rate = 1e-2\n", 722 | "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)\n", 723 | "\n", 724 | "#Perform the training \n", 725 | "train_model(model, optimizer, criterion, train_dataloader, test_dataloader)" 726 | ] 727 | }, 728 | { 729 | "cell_type": "markdown", 730 | "metadata": {}, 731 | "source": [ 732 | "### Training a Sparse Model\n", 733 | "\n", 734 | "(with the default configuration)\n", 735 | "\n", 736 | "In the same way that we declared a dense model, we now declare a sparse model with the same number of input and output features but far fewer parameters. " 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": 20, 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [ 745 | "sparse_model = nn.Sequential(\n", 746 | "\tFlatten(),\n", 747 | "\tsl.SparseLinear(784, 2000),\n", 748 | " nn.LayerNorm(2000),\n", 749 | "\tnn.ReLU(),\n", 750 | " sl.SparseLinear(2000, 10)\n", 751 | ")\n", 752 | "sparse_model = sparse_model.to(device)" 753 | ] 754 | }, 755 | { 756 | "cell_type": "markdown", 757 | "metadata": {}, 758 | "source": [ 759 | "We now train this model. Note that the learning rate is an order of magnitude higher. This is something we have found to be a rule of thumb while training these models. 
" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 21, 765 | "metadata": { 766 | "tags": [] 767 | }, 768 | "outputs": [ 769 | { 770 | "name": "stdout", 771 | "output_type": "stream", 772 | "text": [ 773 | "Epoch 1\n", 774 | "Training Loss: 0.2103; Training Acc: 93.6383\n", 775 | "Test Loss: 0.1184; Test Acc: 96.3141\n", 776 | "------------\n", 777 | "Epoch 2\n", 778 | "Training Loss: 0.0896; Training Acc: 97.2269\n", 779 | "Test Loss: 0.1022; Test Acc: 96.7949\n", 780 | "------------\n", 781 | "Epoch 3\n", 782 | "Training Loss: 0.0604; Training Acc: 98.0340\n", 783 | "Test Loss: 0.0880; Test Acc: 97.4159\n", 784 | "------------\n", 785 | "Epoch 4\n", 786 | "Training Loss: 0.0424; Training Acc: 98.6860\n", 787 | "Test Loss: 0.0799; Test Acc: 97.6262\n", 788 | "------------\n", 789 | "Epoch 5\n", 790 | "Training Loss: 0.0313; Training Acc: 99.0061\n", 791 | "Test Loss: 0.0782; Test Acc: 97.7764\n", 792 | "------------\n", 793 | "Epoch 6\n", 794 | "Training Loss: 0.0228; Training Acc: 99.2696\n", 795 | "Test Loss: 0.0843; Test Acc: 97.6963\n", 796 | "------------\n", 797 | "Epoch 7\n", 798 | "Training Loss: 0.0163; Training Acc: 99.5147\n", 799 | "Test Loss: 0.0861; Test Acc: 97.8466\n", 800 | "------------\n", 801 | "Epoch 8\n", 802 | "Training Loss: 0.0117; Training Acc: 99.6565\n", 803 | "Test Loss: 0.0917; Test Acc: 97.7364\n", 804 | "------------\n", 805 | "Epoch 9\n", 806 | "Training Loss: 0.0074; Training Acc: 99.8332\n", 807 | "Test Loss: 0.0913; Test Acc: 97.8666\n", 808 | "------------\n", 809 | "Epoch 10\n", 810 | "Training Loss: 0.0044; Training Acc: 99.9200\n", 811 | "Test Loss: 0.0819; Test Acc: 97.9667\n", 812 | "------------\n", 813 | "Epoch 11\n", 814 | "Training Loss: 0.0023; Training Acc: 99.9800\n", 815 | "Test Loss: 0.0852; Test Acc: 98.0970\n", 816 | "------------\n", 817 | "Epoch 12\n", 818 | "Training Loss: 0.0012; Training Acc: 99.9950\n", 819 | "Test Loss: 0.0894; Test Acc: 98.0469\n", 820 | "------------\n", 821 | "Epoch 13\n", 822 | "Training Loss: 0.0008; Training Acc: 99.9967\n", 823 | "Test Loss: 0.0864; Test Acc: 98.1871\n", 824 | "------------\n", 825 | "Epoch 14\n", 826 | "Training Loss: 0.0006; Training Acc: 100.0000\n", 827 | "Test Loss: 0.0874; Test Acc: 98.1470\n", 828 | "------------\n", 829 | "Epoch 15\n", 830 | "Training Loss: 0.0005; Training Acc: 100.0000\n", 831 | "Test Loss: 0.0884; Test Acc: 98.1270\n", 832 | "------------\n", 833 | "Epoch 16\n", 834 | "Training Loss: 0.0004; Training Acc: 100.0000\n", 835 | "Test Loss: 0.0891; Test Acc: 98.1571\n", 836 | "------------\n", 837 | "Epoch 17\n", 838 | "Training Loss: 0.0004; Training Acc: 100.0000\n", 839 | "Test Loss: 0.0901; Test Acc: 98.1470\n", 840 | "------------\n", 841 | "Epoch 18\n", 842 | "Training Loss: 0.0004; Training Acc: 100.0000\n", 843 | "Test Loss: 0.0910; Test Acc: 98.1070\n", 844 | "------------\n", 845 | "Epoch 19\n", 846 | "Training Loss: 0.0003; Training Acc: 100.0000\n", 847 | "Test Loss: 0.0916; Test Acc: 98.0970\n", 848 | "------------\n", 849 | "Epoch 20\n", 850 | "Training Loss: 0.0003; Training Acc: 100.0000\n", 851 | "Test Loss: 0.0918; Test Acc: 98.1571\n", 852 | "------------\n", 853 | "\n", 854 | "Training complete in 5m 12s\n" 855 | ] 856 | } 857 | ], 858 | "source": [ 859 | "learning_rate = 1e-1\n", 860 | "optimizer = optim.SGD(sparse_model.parameters(), lr=learning_rate, momentum=0.9)\n", 861 | "\n", 862 | "train_model(sparse_model, optimizer, criterion, train_dataloader, test_dataloader)" 863 | ] 864 | }, 865 | { 866 | "cell_type": 
"markdown", 867 | "metadata": {}, 868 | "source": [ 869 | "As can be seen, the two models perform comparably.\n", 870 | "\n", 871 | "\n", 872 | "However, while the dense model has a total of 1590010 parameters, the sparse model only has 160810 parameters. This translates to **~89.8%** parameter reduction in the sparse model! \n", 873 | "\n", 874 | "We display the weight parameter counts of the two layers below. " 875 | ] 876 | }, 877 | { 878 | "cell_type": "code", 879 | "execution_count": 22, 880 | "metadata": {}, 881 | "outputs": [ 882 | { 883 | "data": { 884 | "text/plain": [ 885 | "(torch.Size([156800]),\n", 886 | " torch.Size([2000, 784]),\n", 887 | " torch.Size([10, 2000]),\n", 888 | " torch.Size([2000]))" 889 | ] 890 | }, 891 | "execution_count": 22, 892 | "metadata": {}, 893 | "output_type": "execute_result" 894 | } 895 | ], 896 | "source": [ 897 | "sparse_model[1].weights.shape, model[1].weight.shape, model[4].weight.shape, sparse_model[4].weights.shape" 898 | ] 899 | }, 900 | { 901 | "cell_type": "markdown", 902 | "metadata": {}, 903 | "source": [ 904 | "## Training sparse models with user-defined connections \n", 905 | "\n", 906 | "Instead of using the random set of connections created during initialization between the input and output neurons, one can choose to define one's own connections to the sparse linear layer by providing an input long tensor of shape (2,`nnz`) specifying connections from input to output neurons using the `connectivity` argument. " 907 | ] 908 | }, 909 | { 910 | "cell_type": "markdown", 911 | "metadata": {}, 912 | "source": [ 913 | "Below we create a connectivity matrix where the input layer is connected to random entries in the output layer. Of course, this is just a small demonstration and one can experiment here with different connectivity matrices. " 914 | ] 915 | }, 916 | { 917 | "cell_type": "code", 918 | "execution_count": 23, 919 | "metadata": {}, 920 | "outputs": [], 921 | "source": [ 922 | "num_connections = 200\n", 923 | "col = torch.arange(784).repeat_interleave(num_connections).view(1,-1).long()\n", 924 | "row = torch.randint(low=0, high=2000, size=(784*num_connections,)).view(1,-1).long()\n", 925 | "connections = torch.cat((row, col), dim=0)" 926 | ] 927 | }, 928 | { 929 | "cell_type": "markdown", 930 | "metadata": {}, 931 | "source": [ 932 | "We provide our connectivity matrix as an input to the `SparseLinear` module and follow the same training procedure as before. 
" 933 | ] 934 | }, 935 | { 936 | "cell_type": "code", 937 | "execution_count": 24, 938 | "metadata": {}, 939 | "outputs": [], 940 | "source": [ 941 | "sparse_model_user = nn.Sequential(\n", 942 | "\tFlatten(),\n", 943 | "\tsl.SparseLinear(784, 2000, connectivity=connections),\n", 944 | " nn.LayerNorm(2000),\n", 945 | "\tnn.ReLU(),\n", 946 | " sl.SparseLinear(2000, 10)\n", 947 | ")\n", 948 | "sparse_model_user = sparse_model_user.to(device)" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": 25, 954 | "metadata": {}, 955 | "outputs": [ 956 | { 957 | "name": "stdout", 958 | "output_type": "stream", 959 | "text": [ 960 | "Epoch 1\n", 961 | "Training Loss: 0.2136; Training Acc: 93.6283\n", 962 | "Test Loss: 0.1089; Test Acc: 96.8249\n", 963 | "------------\n", 964 | "Epoch 2\n", 965 | "Training Loss: 0.0902; Training Acc: 97.3302\n", 966 | "Test Loss: 0.1025; Test Acc: 96.8049\n", 967 | "------------\n", 968 | "Epoch 3\n", 969 | "Training Loss: 0.0619; Training Acc: 98.0606\n", 970 | "Test Loss: 0.0877; Test Acc: 97.3658\n", 971 | "------------\n", 972 | "Epoch 4\n", 973 | "Training Loss: 0.0437; Training Acc: 98.5859\n", 974 | "Test Loss: 0.0814; Test Acc: 97.5160\n", 975 | "------------\n", 976 | "Epoch 5\n", 977 | "Training Loss: 0.0340; Training Acc: 98.8661\n", 978 | "Test Loss: 0.0854; Test Acc: 97.5461\n", 979 | "------------\n", 980 | "Epoch 6\n", 981 | "Training Loss: 0.0239; Training Acc: 99.2496\n", 982 | "Test Loss: 0.0815; Test Acc: 97.6062\n", 983 | "------------\n", 984 | "Epoch 7\n", 985 | "Training Loss: 0.0165; Training Acc: 99.4997\n", 986 | "Test Loss: 0.0901; Test Acc: 97.5060\n", 987 | "------------\n", 988 | "Epoch 8\n", 989 | "Training Loss: 0.0122; Training Acc: 99.6298\n", 990 | "Test Loss: 0.0873; Test Acc: 97.7063\n", 991 | "------------\n", 992 | "Epoch 9\n", 993 | "Training Loss: 0.0080; Training Acc: 99.8016\n", 994 | "Test Loss: 0.0917; Test Acc: 97.6963\n", 995 | "------------\n", 996 | "Epoch 10\n", 997 | "Training Loss: 0.0041; Training Acc: 99.9450\n", 998 | "Test Loss: 0.0837; Test Acc: 97.8766\n", 999 | "------------\n", 1000 | "Epoch 11\n", 1001 | "Training Loss: 0.0020; Training Acc: 99.9833\n", 1002 | "Test Loss: 0.0864; Test Acc: 97.8566\n", 1003 | "------------\n", 1004 | "Epoch 12\n", 1005 | "Training Loss: 0.0012; Training Acc: 99.9967\n", 1006 | "Test Loss: 0.0852; Test Acc: 97.9267\n", 1007 | "------------\n", 1008 | "Epoch 13\n", 1009 | "Training Loss: 0.0008; Training Acc: 100.0000\n", 1010 | "Test Loss: 0.0849; Test Acc: 97.9167\n", 1011 | "------------\n", 1012 | "Epoch 14\n", 1013 | "Training Loss: 0.0006; Training Acc: 100.0000\n", 1014 | "Test Loss: 0.0871; Test Acc: 97.9167\n", 1015 | "------------\n", 1016 | "Epoch 15\n", 1017 | "Training Loss: 0.0005; Training Acc: 100.0000\n", 1018 | "Test Loss: 0.0881; Test Acc: 97.8766\n", 1019 | "------------\n", 1020 | "Epoch 16\n", 1021 | "Training Loss: 0.0005; Training Acc: 100.0000\n", 1022 | "Test Loss: 0.0884; Test Acc: 97.8966\n", 1023 | "------------\n", 1024 | "Epoch 17\n", 1025 | "Training Loss: 0.0004; Training Acc: 100.0000\n", 1026 | "Test Loss: 0.0896; Test Acc: 97.8866\n", 1027 | "------------\n", 1028 | "Epoch 18\n", 1029 | "Training Loss: 0.0004; Training Acc: 100.0000\n", 1030 | "Test Loss: 0.0903; Test Acc: 97.8966\n", 1031 | "------------\n", 1032 | "Epoch 19\n", 1033 | "Training Loss: 0.0004; Training Acc: 100.0000\n", 1034 | "Test Loss: 0.0905; Test Acc: 97.8966\n", 1035 | "------------\n", 1036 | "Epoch 20\n", 1037 | "Training Loss: 0.0003; Training 
Acc: 100.0000\n", 1038 | "Test Loss: 0.0918; Test Acc: 97.8766\n", 1039 | "------------\n", 1040 | "\n", 1041 | "Training complete in 5m 19s\n" 1042 | ] 1043 | } 1044 | ], 1045 | "source": [ 1046 | "learning_rate = 1e-1\n", 1047 | "optimizer = optim.SGD(sparse_model_user.parameters(), lr=learning_rate, momentum=0.9)\n", 1048 | "\n", 1049 | "train_model(sparse_model_user, optimizer, criterion, train_dataloader, test_dataloader)" 1050 | ] 1051 | }, 1052 | { 1053 | "cell_type": "markdown", 1054 | "metadata": {}, 1055 | "source": [ 1056 | "## Training sparse model with dynamic connections \n", 1057 | "\n", 1058 | "The default `SparseLinear` model creates a random set of connections during initialization between the input and output neurons. An improvement over this strategy is to prune some non-required connections and grow (hopefully)required ones. We implement the [Rigging the Lottery](https://arxiv.org/pdf/1911.11134.pdf) algorithm to achieve this. Specifying `dynamic` to be `True` alters the layer connections dynamically while training." 1059 | ] 1060 | }, 1061 | { 1062 | "cell_type": "code", 1063 | "execution_count": 26, 1064 | "metadata": { 1065 | "tags": [] 1066 | }, 1067 | "outputs": [], 1068 | "source": [ 1069 | "sparse_model_dynamic = nn.Sequential(\n", 1070 | "\tFlatten(),\n", 1071 | "\tsl.SparseLinear(784, 2000, dynamic=True),\n", 1072 | " nn.LayerNorm(2000),\n", 1073 | "\tnn.ReLU(),\n", 1074 | " sl.SparseLinear(2000, 10, dynamic=True)\n", 1075 | ")\n", 1076 | "sparse_model_dynamic = sparse_model_dynamic.to(device)" 1077 | ] 1078 | }, 1079 | { 1080 | "cell_type": "code", 1081 | "execution_count": 27, 1082 | "metadata": { 1083 | "tags": [] 1084 | }, 1085 | "outputs": [ 1086 | { 1087 | "name": "stdout", 1088 | "output_type": "stream", 1089 | "text": [ 1090 | "Epoch 1\n", 1091 | "Training Loss: 0.2137; Training Acc: 93.5616\n", 1092 | "Test Loss: 0.1246; Test Acc: 96.0737\n", 1093 | "------------\n", 1094 | "Epoch 2\n", 1095 | "Training Loss: 0.0915; Training Acc: 97.2269\n", 1096 | "Test Loss: 0.0815; Test Acc: 97.4960\n", 1097 | "------------\n", 1098 | "Epoch 3\n", 1099 | "Training Loss: 0.0613; Training Acc: 98.1057\n", 1100 | "Test Loss: 0.0907; Test Acc: 97.2155\n", 1101 | "------------\n", 1102 | "Epoch 4\n", 1103 | "Training Loss: 0.0452; Training Acc: 98.5676\n", 1104 | "Test Loss: 0.0839; Test Acc: 97.3658\n", 1105 | "------------\n", 1106 | "Epoch 5\n", 1107 | "Training Loss: 0.0336; Training Acc: 98.9695\n", 1108 | "Test Loss: 0.0789; Test Acc: 97.5761\n", 1109 | "------------\n", 1110 | "Epoch 6\n", 1111 | "Training Loss: 0.0219; Training Acc: 99.3647\n", 1112 | "Test Loss: 0.0681; Test Acc: 97.8365\n", 1113 | "------------\n", 1114 | "Epoch 7\n", 1115 | "Training Loss: 0.0128; Training Acc: 99.7182\n", 1116 | "Test Loss: 0.0670; Test Acc: 97.9267\n", 1117 | "------------\n", 1118 | "Epoch 8\n", 1119 | "Training Loss: 0.0116; Training Acc: 99.7665\n", 1120 | "Test Loss: 0.0672; Test Acc: 97.9267\n", 1121 | "------------\n", 1122 | "Epoch 9\n", 1123 | "Training Loss: 0.0110; Training Acc: 99.7816\n", 1124 | "Test Loss: 0.0667; Test Acc: 97.9367\n", 1125 | "------------\n", 1126 | "Epoch 10\n", 1127 | "Training Loss: 0.0106; Training Acc: 99.7932\n", 1128 | "Test Loss: 0.0669; Test Acc: 97.9567\n", 1129 | "------------\n", 1130 | "Epoch 11\n", 1131 | "Training Loss: 0.0102; Training Acc: 99.8032\n", 1132 | "Test Loss: 0.0672; Test Acc: 97.9667\n", 1133 | "------------\n", 1134 | "Epoch 12\n", 1135 | "Training Loss: 0.0100; Training Acc: 99.8299\n", 1136 | "Test Loss: 
0.0673; Test Acc: 98.0068\n", 1137 | "------------\n", 1138 | "Epoch 13\n", 1139 | "Training Loss: 0.0097; Training Acc: 99.8199\n", 1140 | "Test Loss: 0.0674; Test Acc: 97.9968\n", 1141 | "------------\n", 1142 | "Epoch 14\n", 1143 | "Training Loss: 0.0095; Training Acc: 99.8332\n", 1144 | "Test Loss: 0.0677; Test Acc: 98.0068\n", 1145 | "------------\n", 1146 | "Epoch 15\n", 1147 | "Training Loss: 0.0093; Training Acc: 99.8432\n", 1148 | "Test Loss: 0.0677; Test Acc: 98.0068\n", 1149 | "------------\n", 1150 | "Epoch 16\n", 1151 | "Training Loss: 0.0092; Training Acc: 99.8466\n", 1152 | "Test Loss: 0.0679; Test Acc: 97.9868\n", 1153 | "------------\n", 1154 | "Epoch 17\n", 1155 | "Training Loss: 0.0090; Training Acc: 99.8483\n", 1156 | "Test Loss: 0.0685; Test Acc: 98.0469\n", 1157 | "------------\n", 1158 | "Epoch 18\n", 1159 | "Training Loss: 0.0089; Training Acc: 99.8466\n", 1160 | "Test Loss: 0.0687; Test Acc: 97.9768\n", 1161 | "------------\n", 1162 | "Epoch 19\n", 1163 | "Training Loss: 0.0088; Training Acc: 99.8599\n", 1164 | "Test Loss: 0.0685; Test Acc: 98.0168\n", 1165 | "------------\n", 1166 | "Epoch 20\n", 1167 | "Training Loss: 0.0087; Training Acc: 99.8533\n", 1168 | "Test Loss: 0.0689; Test Acc: 98.0469\n", 1169 | "------------\n", 1170 | "\n", 1171 | "Training complete in 5m 15s\n" 1172 | ] 1173 | } 1174 | ], 1175 | "source": [ 1176 | "learning_rate = 5e-3\n", 1177 | "optimizer = optim.SGD(sparse_model_dynamic.parameters(), lr=learning_rate, momentum=0.9)\n", 1178 | "train_model(sparse_model_dynamic, optimizer, criterion, train_dataloader, test_dataloader)" 1179 | ] 1180 | }, 1181 | { 1182 | "cell_type": "markdown", 1183 | "metadata": {}, 1184 | "source": [ 1185 | "## Training sparse model with small-world connections \n", 1186 | "\n", 1187 | "Some sparsity patterns tend to perform better than others. Small-world sparsity provides a network that is mostly locally connected with a few global, long-range connections scattered in. See [here](https://en.wikipedia.org/wiki/Small-world_network). We implement an initialization strategy to incorporate small-world sparsity in the model. To specify, set `small_world` to `True`. 
" 1188 | ] 1189 | }, 1190 | { 1191 | "cell_type": "code", 1192 | "execution_count": 28, 1193 | "metadata": { 1194 | "tags": [] 1195 | }, 1196 | "outputs": [], 1197 | "source": [ 1198 | "sparse_model_sw = nn.Sequential(\n", 1199 | "\tFlatten(),\n", 1200 | "\tsl.SparseLinear(784, 2000, small_world=True),\n", 1201 | " nn.LayerNorm(2000),\n", 1202 | "\tnn.ReLU(),\n", 1203 | " sl.SparseLinear(2000, 10, small_world=True)\n", 1204 | ")\n", 1205 | "sparse_model_sw = sparse_model_sw.to(device)" 1206 | ] 1207 | }, 1208 | { 1209 | "cell_type": "code", 1210 | "execution_count": 29, 1211 | "metadata": { 1212 | "tags": [] 1213 | }, 1214 | "outputs": [ 1215 | { 1216 | "name": "stdout", 1217 | "output_type": "stream", 1218 | "text": [ 1219 | "Epoch 1\n", 1220 | "Training Loss: 0.2043; Training Acc: 93.7817\n", 1221 | "Test Loss: 0.1040; Test Acc: 96.7748\n", 1222 | "------------\n", 1223 | "Epoch 2\n", 1224 | "Training Loss: 0.0856; Training Acc: 97.3202\n", 1225 | "Test Loss: 0.0853; Test Acc: 97.2756\n", 1226 | "------------\n", 1227 | "Epoch 3\n", 1228 | "Training Loss: 0.0573; Training Acc: 98.2057\n", 1229 | "Test Loss: 0.0816; Test Acc: 97.5661\n", 1230 | "------------\n", 1231 | "Epoch 4\n", 1232 | "Training Loss: 0.0404; Training Acc: 98.7176\n", 1233 | "Test Loss: 0.0786; Test Acc: 97.5761\n", 1234 | "------------\n", 1235 | "Epoch 5\n", 1236 | "Training Loss: 0.0297; Training Acc: 99.0211\n", 1237 | "Test Loss: 0.0741; Test Acc: 97.7464\n", 1238 | "------------\n", 1239 | "Epoch 6\n", 1240 | "Training Loss: 0.0211; Training Acc: 99.3246\n", 1241 | "Test Loss: 0.0740; Test Acc: 97.9267\n", 1242 | "------------\n", 1243 | "Epoch 7\n", 1244 | "Training Loss: 0.0162; Training Acc: 99.4597\n", 1245 | "Test Loss: 0.0808; Test Acc: 97.8666\n", 1246 | "------------\n", 1247 | "Epoch 8\n", 1248 | "Training Loss: 0.0108; Training Acc: 99.7132\n", 1249 | "Test Loss: 0.0900; Test Acc: 97.6462\n", 1250 | "------------\n", 1251 | "Epoch 9\n", 1252 | "Training Loss: 0.0090; Training Acc: 99.7148\n", 1253 | "Test Loss: 0.0858; Test Acc: 97.8666\n", 1254 | "------------\n", 1255 | "Epoch 10\n", 1256 | "Training Loss: 0.0055; Training Acc: 99.8666\n", 1257 | "Test Loss: 0.0830; Test Acc: 98.0469\n", 1258 | "------------\n", 1259 | "Epoch 11\n", 1260 | "Training Loss: 0.0029; Training Acc: 99.9516\n", 1261 | "Test Loss: 0.0785; Test Acc: 98.0068\n", 1262 | "------------\n", 1263 | "Epoch 12\n", 1264 | "Training Loss: 0.0013; Training Acc: 99.9950\n", 1265 | "Test Loss: 0.0808; Test Acc: 98.2472\n", 1266 | "------------\n", 1267 | "Epoch 13\n", 1268 | "Training Loss: 0.0007; Training Acc: 100.0000\n", 1269 | "Test Loss: 0.0796; Test Acc: 98.2071\n", 1270 | "------------\n", 1271 | "Epoch 14\n", 1272 | "Training Loss: 0.0005; Training Acc: 100.0000\n", 1273 | "Test Loss: 0.0811; Test Acc: 98.1571\n", 1274 | "------------\n", 1275 | "Epoch 15\n", 1276 | "Training Loss: 0.0005; Training Acc: 100.0000\n", 1277 | "Test Loss: 0.0823; Test Acc: 98.1871\n", 1278 | "------------\n", 1279 | "Epoch 16\n", 1280 | "Training Loss: 0.0004; Training Acc: 100.0000\n", 1281 | "Test Loss: 0.0827; Test Acc: 98.1871\n", 1282 | "------------\n", 1283 | "Epoch 17\n", 1284 | "Training Loss: 0.0004; Training Acc: 100.0000\n", 1285 | "Test Loss: 0.0841; Test Acc: 98.2071\n", 1286 | "------------\n", 1287 | "Epoch 18\n", 1288 | "Training Loss: 0.0003; Training Acc: 100.0000\n", 1289 | "Test Loss: 0.0844; Test Acc: 98.1671\n", 1290 | "------------\n", 1291 | "Epoch 19\n", 1292 | "Training Loss: 0.0003; Training Acc: 100.0000\n", 1293 | "Test 
Loss: 0.0849; Test Acc: 98.1871\n", 1294 | "------------\n", 1295 | "Epoch 20\n", 1296 | "Training Loss: 0.0003; Training Acc: 100.0000\n", 1297 | "Test Loss: 0.0854; Test Acc: 98.2171\n", 1298 | "------------\n", 1299 | "\n", 1300 | "Training complete in 5m 21s\n" 1301 | ] 1302 | } 1303 | ], 1304 | "source": [ 1305 | "learning_rate = 1e-1\n", 1306 | "optimizer = optim.SGD(sparse_model_sw.parameters(), lr=learning_rate, momentum=0.9)\n", 1307 | "train_model(sparse_model_sw, optimizer, criterion, train_dataloader, test_dataloader)" 1308 | ] 1309 | }, 1310 | { 1311 | "cell_type": "markdown", 1312 | "metadata": {}, 1313 | "source": [ 1314 | "## Utilizing the activation sparsity feature \n", 1315 | "\n", 1316 | "The `SparseLinear` layer is constructed for parameter sparsity; however, we make no stipulations on the sparsity (or density) of the activations. We include an option for sparse activations using the K-Winners strategy. This paper describes a potential method ([k-winners](https://arxiv.org/pdf/1903.11257.pdf) layer) which we use to train both linear and sparse linear models." 1317 | ] 1318 | }, 1319 | { 1320 | "cell_type": "code", 1321 | "execution_count": 30, 1322 | "metadata": {}, 1323 | "outputs": [], 1324 | "source": [ 1325 | "import activationsparsity as asy" 1326 | ] 1327 | }, 1328 | { 1329 | "cell_type": "markdown", 1330 | "metadata": {}, 1331 | "source": [ 1332 | "Below we train a linear model using this activation sparsity feature. By default, we set `act_sparsity=0.65` (which means `k=(1-0.65)*2000`) for the layer below. " 1333 | ] 1334 | }, 1335 | { 1336 | "cell_type": "code", 1337 | "execution_count": 31, 1338 | "metadata": {}, 1339 | "outputs": [], 1340 | "source": [ 1341 | "model_asy = nn.Sequential(\n", 1342 | "\tFlatten(),\n", 1343 | " nn.Linear(784, 2000),\n", 1344 | " nn.LayerNorm(2000),\n", 1345 | " asy.ActivationSparsity(),\n", 1346 | " nn.Linear(2000,10)\n", 1347 | ")\n", 1348 | "model_asy = model_asy.to(device)" 1349 | ] 1350 | }, 1351 | { 1352 | "cell_type": "code", 1353 | "execution_count": 32, 1354 | "metadata": {}, 1355 | "outputs": [ 1356 | { 1357 | "name": "stdout", 1358 | "output_type": "stream", 1359 | "text": [ 1360 | "Epoch 1\n", 1361 | "Training Loss: 0.2241; Training Acc: 93.3548\n", 1362 | "Test Loss: 0.1153; Test Acc: 96.5845\n", 1363 | "------------\n", 1364 | "Epoch 2\n", 1365 | "Training Loss: 0.0901; Training Acc: 97.3102\n", 1366 | "Test Loss: 0.0894; Test Acc: 97.1855\n", 1367 | "------------\n", 1368 | "Epoch 3\n", 1369 | "Training Loss: 0.0590; Training Acc: 98.3825\n", 1370 | "Test Loss: 0.0785; Test Acc: 97.6462\n", 1371 | "------------\n", 1372 | "Epoch 4\n", 1373 | "Training Loss: 0.0415; Training Acc: 98.8877\n", 1374 | "Test Loss: 0.0671; Test Acc: 97.8866\n", 1375 | "------------\n", 1376 | "Epoch 5\n", 1377 | "Training Loss: 0.0303; Training Acc: 99.2529\n", 1378 | "Test Loss: 0.0616; Test Acc: 98.0569\n", 1379 | "------------\n", 1380 | "Epoch 6\n", 1381 | "Training Loss: 0.0222; Training Acc: 99.5114\n", 1382 | "Test Loss: 0.0602; Test Acc: 98.1270\n", 1383 | "------------\n", 1384 | "Epoch 7\n", 1385 | "Training Loss: 0.0168; Training Acc: 99.6748\n", 1386 | "Test Loss: 0.0611; Test Acc: 98.0569\n", 1387 | "------------\n", 1388 | "Epoch 8\n", 1389 | "Training Loss: 0.0128; Training Acc: 99.8366\n", 1390 | "Test Loss: 0.0604; Test Acc: 98.0970\n", 1391 | "------------\n", 1392 | "Epoch 9\n", 1393 | "Training Loss: 0.0102; Training Acc: 99.8849\n", 1394 | "Test Loss: 0.0569; Test Acc: 98.2472\n", 1395 | "------------\n", 1396 | 
"Epoch 10\n", 1397 | "Training Loss: 0.0079; Training Acc: 99.9400\n", 1398 | "Test Loss: 0.0581; Test Acc: 98.1971\n", 1399 | "------------\n", 1400 | "Epoch 11\n", 1401 | "Training Loss: 0.0065; Training Acc: 99.9633\n", 1402 | "Test Loss: 0.0559; Test Acc: 98.3173\n", 1403 | "------------\n", 1404 | "Epoch 12\n", 1405 | "Training Loss: 0.0053; Training Acc: 99.9733\n", 1406 | "Test Loss: 0.0538; Test Acc: 98.3874\n", 1407 | "------------\n", 1408 | "Epoch 13\n", 1409 | "Training Loss: 0.0045; Training Acc: 99.9917\n", 1410 | "Test Loss: 0.0543; Test Acc: 98.3474\n", 1411 | "------------\n", 1412 | "Epoch 14\n", 1413 | "Training Loss: 0.0040; Training Acc: 99.9950\n", 1414 | "Test Loss: 0.0576; Test Acc: 98.3073\n", 1415 | "------------\n", 1416 | "Epoch 15\n", 1417 | "Training Loss: 0.0036; Training Acc: 99.9900\n", 1418 | "Test Loss: 0.0555; Test Acc: 98.3273\n", 1419 | "------------\n", 1420 | "Epoch 16\n", 1421 | "Training Loss: 0.0032; Training Acc: 99.9983\n", 1422 | "Test Loss: 0.0556; Test Acc: 98.3974\n", 1423 | "------------\n", 1424 | "Epoch 17\n", 1425 | "Training Loss: 0.0029; Training Acc: 99.9967\n", 1426 | "Test Loss: 0.0581; Test Acc: 98.3574\n", 1427 | "------------\n", 1428 | "Epoch 18\n", 1429 | "Training Loss: 0.0026; Training Acc: 100.0000\n", 1430 | "Test Loss: 0.0577; Test Acc: 98.2472\n", 1431 | "------------\n", 1432 | "Epoch 19\n", 1433 | "Training Loss: 0.0024; Training Acc: 100.0000\n", 1434 | "Test Loss: 0.0568; Test Acc: 98.2372\n", 1435 | "------------\n", 1436 | "Epoch 20\n", 1437 | "Training Loss: 0.0022; Training Acc: 99.9983\n", 1438 | "Test Loss: 0.0561; Test Acc: 98.3574\n", 1439 | "------------\n", 1440 | "\n", 1441 | "Training complete in 6m 5s\n" 1442 | ] 1443 | } 1444 | ], 1445 | "source": [ 1446 | "learning_rate = 5e-3\n", 1447 | "optimizer = optim.SGD(model_asy.parameters(), lr=learning_rate, momentum=0.9)\n", 1448 | "\n", 1449 | "#Perform the training \n", 1450 | "train_model(model_asy, optimizer, criterion, train_dataloader, test_dataloader)" 1451 | ] 1452 | }, 1453 | { 1454 | "cell_type": "markdown", 1455 | "metadata": {}, 1456 | "source": [ 1457 | "Now we train another model which uses the sparse linear module along with this activation. As mentioned before, the learning rate is an order of magnitude higher than the linear module. 
" 1458 | ] 1459 | }, 1460 | { 1461 | "cell_type": "code", 1462 | "execution_count": 33, 1463 | "metadata": {}, 1464 | "outputs": [], 1465 | "source": [ 1466 | "model_asy_sparse = nn.Sequential(\n", 1467 | "\tFlatten(),\n", 1468 | "\tsl.SparseLinear(784, 2000),\n", 1469 | " nn.LayerNorm(2000),\n", 1470 | " asy.ActivationSparsity(),\n", 1471 | " sl.SparseLinear(2000, 10),\n", 1472 | ")\n", 1473 | "model_asy_sparse = model_asy_sparse.to(device)" 1474 | ] 1475 | }, 1476 | { 1477 | "cell_type": "code", 1478 | "execution_count": 34, 1479 | "metadata": {}, 1480 | "outputs": [ 1481 | { 1482 | "name": "stdout", 1483 | "output_type": "stream", 1484 | "text": [ 1485 | "Epoch 1\n", 1486 | "Training Loss: 0.2218; Training Acc: 93.2080\n", 1487 | "Test Loss: 0.1237; Test Acc: 96.1138\n", 1488 | "------------\n", 1489 | "Epoch 2\n", 1490 | "Training Loss: 0.0960; Training Acc: 97.0318\n", 1491 | "Test Loss: 0.0965; Test Acc: 97.0954\n", 1492 | "------------\n", 1493 | "Epoch 3\n", 1494 | "Training Loss: 0.0670; Training Acc: 97.9172\n", 1495 | "Test Loss: 0.0930; Test Acc: 97.1855\n", 1496 | "------------\n", 1497 | "Epoch 4\n", 1498 | "Training Loss: 0.0501; Training Acc: 98.4075\n", 1499 | "Test Loss: 0.0898; Test Acc: 97.2556\n", 1500 | "------------\n", 1501 | "Epoch 5\n", 1502 | "Training Loss: 0.0391; Training Acc: 98.7777\n", 1503 | "Test Loss: 0.0747; Test Acc: 97.7163\n", 1504 | "------------\n", 1505 | "Epoch 6\n", 1506 | "Training Loss: 0.0303; Training Acc: 99.0261\n", 1507 | "Test Loss: 0.0824; Test Acc: 97.5661\n", 1508 | "------------\n", 1509 | "Epoch 7\n", 1510 | "Training Loss: 0.0234; Training Acc: 99.2596\n", 1511 | "Test Loss: 0.0777; Test Acc: 97.7764\n", 1512 | "------------\n", 1513 | "Epoch 8\n", 1514 | "Training Loss: 0.0173; Training Acc: 99.4981\n", 1515 | "Test Loss: 0.0848; Test Acc: 97.4960\n", 1516 | "------------\n", 1517 | "Epoch 9\n", 1518 | "Training Loss: 0.0138; Training Acc: 99.6398\n", 1519 | "Test Loss: 0.0780; Test Acc: 97.8766\n", 1520 | "------------\n", 1521 | "Epoch 10\n", 1522 | "Training Loss: 0.0096; Training Acc: 99.7699\n", 1523 | "Test Loss: 0.0840; Test Acc: 97.7764\n", 1524 | "------------\n", 1525 | "Epoch 11\n", 1526 | "Training Loss: 0.0083; Training Acc: 99.8066\n", 1527 | "Test Loss: 0.0860; Test Acc: 97.8365\n", 1528 | "------------\n", 1529 | "Epoch 12\n", 1530 | "Training Loss: 0.0068; Training Acc: 99.8616\n", 1531 | "Test Loss: 0.0851; Test Acc: 97.8866\n", 1532 | "------------\n", 1533 | "Epoch 13\n", 1534 | "Training Loss: 0.0049; Training Acc: 99.9200\n", 1535 | "Test Loss: 0.0787; Test Acc: 98.0369\n", 1536 | "------------\n", 1537 | "Epoch 14\n", 1538 | "Training Loss: 0.0033; Training Acc: 99.9666\n", 1539 | "Test Loss: 0.0841; Test Acc: 98.0168\n", 1540 | "------------\n", 1541 | "Epoch 15\n", 1542 | "Training Loss: 0.0025; Training Acc: 99.9783\n", 1543 | "Test Loss: 0.0845; Test Acc: 97.8866\n", 1544 | "------------\n", 1545 | "Epoch 16\n", 1546 | "Training Loss: 0.0022; Training Acc: 99.9850\n", 1547 | "Test Loss: 0.0857; Test Acc: 98.0469\n", 1548 | "------------\n", 1549 | "Epoch 17\n", 1550 | "Training Loss: 0.0017; Training Acc: 99.9933\n", 1551 | "Test Loss: 0.0850; Test Acc: 98.1270\n", 1552 | "------------\n", 1553 | "Epoch 18\n", 1554 | "Training Loss: 0.0013; Training Acc: 99.9967\n", 1555 | "Test Loss: 0.0890; Test Acc: 98.0268\n", 1556 | "------------\n", 1557 | "Epoch 19\n", 1558 | "Training Loss: 0.0012; Training Acc: 99.9967\n", 1559 | "Test Loss: 0.0891; Test Acc: 98.1370\n", 1560 | "------------\n", 1561 | "Epoch 
20\n", 1562 | "Training Loss: 0.0010; Training Acc: 99.9983\n", 1563 | "Test Loss: 0.0882; Test Acc: 98.0669\n", 1564 | "------------\n", 1565 | "\n", 1566 | "Training complete in 6m 24s\n" 1567 | ] 1568 | } 1569 | ], 1570 | "source": [ 1571 | "learning_rate = 5e-2\n", 1572 | "optimizer = optim.SGD(model_asy_sparse.parameters(), lr=learning_rate, momentum=0.9)\n", 1573 | "\n", 1574 | "#Perform the training \n", 1575 | "train_model(model_asy_sparse, optimizer, criterion, train_dataloader, test_dataloader)" 1576 | ] 1577 | }, 1578 | { 1579 | "cell_type": "markdown", 1580 | "metadata": {}, 1581 | "source": [ 1582 | "## Training very wide and sparse models \n", 1583 | "\n", 1584 | "The main advantage of utilizing sparse tensors is that it enables us to train very wide models. Below we demonstrate an example of such a model. Of course, it is just a demonstration and the key take away is that we can build these huge models for more complex tasks where the benefits would be more viable. " 1585 | ] 1586 | }, 1587 | { 1588 | "cell_type": "code", 1589 | "execution_count": 39, 1590 | "metadata": {}, 1591 | "outputs": [], 1592 | "source": [ 1593 | "import torch.nn as nn\n", 1594 | "import torch.nn.functional as F\n", 1595 | "\n", 1596 | "class Net(nn.Module):\n", 1597 | " def __init__(self):\n", 1598 | " super(Net, self).__init__()\n", 1599 | " self.sc1 = sl.SparseLinear(10 * 28 * 28, 50000, sparsity=0.999)\n", 1600 | " self.sc2 = sl.SparseLinear(50000, 50000, sparsity=0.999)\n", 1601 | " self.sc3 = sl.SparseLinear(50000, 50000, sparsity=0.999)\n", 1602 | " self.sc4 = sl.SparseLinear(50000, 50000, sparsity=0.999)\n", 1603 | " self.sc5 = sl.SparseLinear(50000, 50000, sparsity=0.999)\n", 1604 | " \n", 1605 | " self.input_scaling = nn.Parameter(torch.ones(10 * 28 * 28))\n", 1606 | " self.input_shifting = nn.Parameter(torch.zeros(10 * 28 * 28))\n", 1607 | " self.ln1 = nn.LayerNorm(50000)\n", 1608 | " self.ln2 = nn.LayerNorm(50000)\n", 1609 | " self.ln3 = nn.LayerNorm(50000)\n", 1610 | " self.ln4 = nn.LayerNorm(50000)\n", 1611 | " \n", 1612 | " def forward(self, x):\n", 1613 | " x = x.view(-1, 28 * 28)\n", 1614 | " x = torch.repeat_interleave(x, 10, dim=1)\n", 1615 | " x = self.input_scaling * x + self.input_shifting\n", 1616 | " x = F.relu(self.ln1(self.sc1(x)))\n", 1617 | " x = F.relu(self.ln2(self.sc2(x)))\n", 1618 | " x = F.relu(self.ln3(self.sc3(x)))\n", 1619 | " x = F.relu(self.ln4(self.sc4(x)))\n", 1620 | " x = self.sc5(x)\n", 1621 | " x = x.view(x.shape[0], -1, 10).sum(dim=1) # sum 5000 outputs per class\n", 1622 | " return x\n", 1623 | "\n", 1624 | "sparse_big = Net().to(device)" 1625 | ] 1626 | }, 1627 | { 1628 | "cell_type": "code", 1629 | "execution_count": 40, 1630 | "metadata": { 1631 | "tags": [] 1632 | }, 1633 | "outputs": [ 1634 | { 1635 | "name": "stdout", 1636 | "output_type": "stream", 1637 | "text": [ 1638 | "Epoch 1\n", 1639 | "Training Loss: 0.1993; Training Acc: 93.9168\n", 1640 | "Test Loss: 0.1317; Test Acc: 95.9034\n", 1641 | "------------\n", 1642 | "Epoch 2\n", 1643 | "Training Loss: 0.0675; Training Acc: 98.0023\n", 1644 | "Test Loss: 0.0806; Test Acc: 97.5160\n", 1645 | "------------\n", 1646 | "Epoch 3\n", 1647 | "Training Loss: 0.0318; Training Acc: 99.2663\n", 1648 | "Test Loss: 0.0723; Test Acc: 97.6863\n", 1649 | "------------\n", 1650 | "Epoch 4\n", 1651 | "Training Loss: 0.0167; Training Acc: 99.7532\n", 1652 | "Test Loss: 0.0630; Test Acc: 97.9968\n", 1653 | "------------\n", 1654 | "Epoch 5\n", 1655 | "Training Loss: 0.0088; Training Acc: 99.9483\n", 1656 | "Test 
Loss: 0.0609; Test Acc: 98.0970\n", 1657 | "------------\n", 1658 | "Epoch 6\n", 1659 | "Training Loss: 0.0056; Training Acc: 99.9917\n", 1660 | "Test Loss: 0.0584; Test Acc: 98.2272\n", 1661 | "------------\n", 1662 | "Epoch 7\n", 1663 | "Training Loss: 0.0041; Training Acc: 100.0000\n", 1664 | "Test Loss: 0.0587; Test Acc: 98.1771\n", 1665 | "------------\n", 1666 | "Epoch 8\n", 1667 | "Training Loss: 0.0033; Training Acc: 100.0000\n", 1668 | "Test Loss: 0.0583; Test Acc: 98.1671\n", 1669 | "------------\n", 1670 | "Epoch 9\n", 1671 | "Training Loss: 0.0028; Training Acc: 100.0000\n", 1672 | "Test Loss: 0.0587; Test Acc: 98.1370\n", 1673 | "------------\n", 1674 | "Epoch 10\n", 1675 | "Training Loss: 0.0024; Training Acc: 100.0000\n", 1676 | "Test Loss: 0.0581; Test Acc: 98.1671\n", 1677 | "------------\n", 1678 | "Epoch 11\n", 1679 | "Training Loss: 0.0021; Training Acc: 100.0000\n", 1680 | "Test Loss: 0.0580; Test Acc: 98.1771\n", 1681 | "------------\n", 1682 | "Epoch 12\n", 1683 | "Training Loss: 0.0019; Training Acc: 100.0000\n", 1684 | "Test Loss: 0.0577; Test Acc: 98.2272\n", 1685 | "------------\n", 1686 | "Epoch 13\n", 1687 | "Training Loss: 0.0017; Training Acc: 100.0000\n", 1688 | "Test Loss: 0.0577; Test Acc: 98.2372\n", 1689 | "------------\n", 1690 | "Epoch 14\n", 1691 | "Training Loss: 0.0016; Training Acc: 100.0000\n", 1692 | "Test Loss: 0.0586; Test Acc: 98.1771\n", 1693 | "------------\n", 1694 | "Epoch 15\n", 1695 | "Training Loss: 0.0014; Training Acc: 100.0000\n", 1696 | "Test Loss: 0.0582; Test Acc: 98.1971\n", 1697 | "------------\n", 1698 | "Epoch 16\n", 1699 | "Training Loss: 0.0013; Training Acc: 100.0000\n", 1700 | "Test Loss: 0.0582; Test Acc: 98.2372\n", 1701 | "------------\n", 1702 | "Epoch 17\n", 1703 | "Training Loss: 0.0012; Training Acc: 100.0000\n", 1704 | "Test Loss: 0.0582; Test Acc: 98.2071\n", 1705 | "------------\n", 1706 | "Epoch 18\n", 1707 | "Training Loss: 0.0012; Training Acc: 100.0000\n", 1708 | "Test Loss: 0.0584; Test Acc: 98.1971\n", 1709 | "------------\n", 1710 | "Epoch 19\n", 1711 | "Training Loss: 0.0011; Training Acc: 100.0000\n", 1712 | "Test Loss: 0.0582; Test Acc: 98.2272\n", 1713 | "------------\n", 1714 | "Epoch 20\n", 1715 | "Training Loss: 0.0010; Training Acc: 100.0000\n", 1716 | "Test Loss: 0.0583; Test Acc: 98.1871\n", 1717 | "------------\n", 1718 | "\n", 1719 | "Training complete in 39m 29s\n" 1720 | ] 1721 | } 1722 | ], 1723 | "source": [ 1724 | "learning_rate = 5e-5\n", 1725 | "optimizer = optim.SGD(sparse_big.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)\n", 1726 | "train_model(sparse_big, optimizer, criterion, train_dataloader, test_dataloader)" 1727 | ] 1728 | }, 1729 | { 1730 | "cell_type": "markdown", 1731 | "metadata": {}, 1732 | "source": [ 1733 | "In conclusion, we demonstrated the `SparseLinear` layer. From a user's perspective, it is very similar to PyTorch's `Linear` layer. We also showed extra features namely user-defined sparsity, dynamic sparsity, small-world connectivity, and activation sparsity. \n", 1734 | "\n", 1735 | "Our experiments showed that even with a huge reduction in parameters, we were able to achieve a performance similar to that of massively parameterised layers. \n", 1736 | "\n", 1737 | "We hope this excites and enables people to build highly scalable sparse networks!" 
1738 | ] 1739 | }, 1740 | { 1741 | "cell_type": "markdown", 1742 | "metadata": {}, 1743 | "source": [ 1744 | "![Alt Text](https://media.giphy.com/media/L0O3TQpp0WnSXmxV8p/giphy.gif)" 1745 | ] 1746 | } 1747 | ], 1748 | "metadata": { 1749 | "kernelspec": { 1750 | "display_name": "Environment (conda_pytorch_p36)", 1751 | "language": "python", 1752 | "name": "conda_pytorch_p36" 1753 | }, 1754 | "language_info": { 1755 | "codemirror_mode": { 1756 | "name": "ipython", 1757 | "version": 3 1758 | }, 1759 | "file_extension": ".py", 1760 | "mimetype": "text/x-python", 1761 | "name": "python", 1762 | "nbconvert_exporter": "python", 1763 | "pygments_lexer": "ipython3", 1764 | "version": "3.6.10" 1765 | } 1766 | }, 1767 | "nbformat": 4, 1768 | "nbformat_minor": 2 1769 | } 1770 | --------------------------------------------------------------------------------