├── .DS_Store ├── models ├── .DS_Store ├── init_utils.py ├── initializers.py ├── gcn_conv.py ├── fagcn_conv.py ├── gat_conv.py └── model.py ├── sparselearning ├── .DS_Store ├── __init__.py ├── sparse_sgd.py ├── models.py └── core.py ├── README.md ├── run_base.sh ├── run.sh ├── run_wf.sh ├── run_multi.sh ├── run_waf.sh └── main_stgnn.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CausalLearning/CGP/HEAD/.DS_Store -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CausalLearning/CGP/HEAD/models/.DS_Store -------------------------------------------------------------------------------- /sparselearning/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CausalLearning/CGP/HEAD/sparselearning/.DS_Store -------------------------------------------------------------------------------- /sparselearning/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.getLogger(__name__).addHandler(logging.NullHandler()) 3 | -------------------------------------------------------------------------------- /models/init_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | 6 | def weights_init(m): 7 | # print('=> weights init') 8 | if isinstance(m, nn.Conv2d): 9 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 10 | # nn.init.normal_(m.weight, 0, 0.1) 11 | if m.bias is not None: 12 | m.bias.data.zero_() 13 | elif isinstance(m, nn.Linear): 14 | # nn.init.xavier_normal(m.weight) 15 | nn.init.normal_(m.weight, 0, 0.01) 16 | nn.init.constant_(m.bias, 0) 17 | elif isinstance(m, nn.BatchNorm2d): 18 | # Note that BN's running_var/mean are 19 | # already initialized to 1 and 0 respectively. 20 | if m.weight is not None: 21 | m.weight.data.fill_(1.0) 22 | if m.bias is not None: 23 | m.bias.data.zero_() -------------------------------------------------------------------------------- /models/initializers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | 8 | 9 | def binary(w): 10 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): 11 | torch.nn.init.kaiming_normal_(w.weight) 12 | sigma = w.weight.data.std() 13 | w.weight.data = torch.sign(w.weight.data) * sigma 14 | 15 | 16 | def kaiming_normal(w): 17 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): 18 | torch.nn.init.kaiming_normal_(w.weight) 19 | 20 | 21 | def kaiming_uniform(w): 22 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): 23 | torch.nn.init.kaiming_uniform_(w.weight) 24 | 25 | 26 | def orthogonal(w): 27 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): 28 | torch.nn.init.orthogonal_(w.weight) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Comprehensive Graph Gradual Pruning for Sparse Training in Graph Neural Networks 2 | 3 | Open-sourced implementation for TNNLS 2023. 4 | 5 | 6 | 7 |

## Abstract

8 | 9 | 1) We propose a graph gradual pruning framework, namely 10 | CGP, to reduce the training and inference computing costs 11 | of GNN models while preserving their accuracy. 12 | 13 | 2) We comprehensively sparsify the elements of GNNs, 14 | including graph structures, the node feature dimension, and 15 | model parameters, to significantly improve the efficiency 16 | of GNN models. 17 | 18 | 3) Experimental results on various GNN models and datasets 19 | consistently validate the effectiveness and efficiency of 20 | our proposed CGP. 21 | 22 | 23 | 24 |

## Python Dependencies

25 | 26 | Our proposed CGP is implemented in Python 3.7, and the major libraries include: 27 | 28 | * [PyTorch](https://pytorch.org/) = 1.11.0+cu113 29 | * [PyG](https://pytorch-geometric.readthedocs.io/en/latest/) torch-geometric=2.2.0 30 | 31 | More dependencies are provided in **requirements.txt**. 32 | 33 |
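A quick way to set up and sanity-check the environment (assuming `requirements.txt` sits in the repository root and pins the versions above; adjust the PyTorch wheel to your own CUDA setup):

```sh
# Install the pinned dependencies, then print the installed PyTorch / PyG versions
# to confirm they match the ones listed above.
pip install -r requirements.txt
python -c "import torch, torch_geometric; print(torch.__version__, torch_geometric.__version__)"
```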

## To Run

34 | 35 | Once the requirements are fulfilled, use one of the provided scripts to run (most scripts take the CUDA device id as their first argument, passed to `--cuda`): 36 | 37 | `sh run_base.sh 0` for the dense baselines, or `sh run.sh 0`, `sh run_wf.sh 0`, `sh run_multi.sh`, `sh run_waf.sh 0` for the sparse-training sweeps (see the example invocation below). 38 | 39 |
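For reference, each script loops over models and datasets and ultimately calls `main_stgnn.py`. A single dense-baseline run, with the flag values taken from `run_base.sh` and GPU 0 substituted for the script's `$1` argument, looks like this (a sketch; the exact behavior depends on `main_stgnn.py`'s argument parsing, which is not shown in this listing):

```sh
python main_stgnn.py --method GraNet \
    --optimizer adam \
    --sparse-init ERK \
    --init-density 1.0 \
    --l2 0.0005 \
    --lr 0.01 \
    --cuda 0 \
    --epochs 200 \
    --model gcn \
    --data cora
```

The sparse-training scripts (`run.sh`, `run_wf.sh`, `run_multi.sh`, `run_waf.sh`) additionally pass `--sparse` together with `--weight_sparse`, `--adj_sparse`, and/or `--feature_sparse`, plus the corresponding `--final-density`, `--final-density_adj`, `--final-density_feature`, `--prune-rate`, `--update-frequency`, `--final-prune-epoch`, and `--growth_schedule` settings.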

## Datasets

40 | 41 | All datasets used in this paper can be downloaded from [PyG](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html). 42 | -------------------------------------------------------------------------------- /run_base.sh: -------------------------------------------------------------------------------- 1 | for model in gcn gat sgc appnp 2 | do 3 | for data in cora citeseer pubmed Cornell Texas Wisconsin Actor CS Physics Computers Photo WikiCS ogbn-arxiv 4 | do 5 | python main_stgnn.py --method GraNet \ 6 | --optimizer adam \ 7 | --sparse-init ERK \ 8 | --init-density 1.0 \ 9 | --l2 0.0005 \ 10 | --lr 0.01 \ 11 | --cuda $1 \ 12 | --epochs 200 \ 13 | --model $model \ 14 | --data $data 15 | done 16 | done 17 | 18 | 19 | # cora citeseer 20 | 21 | 22 | 23 | # --model: gcn, gat, sgc, appnp, gcnii (5) 24 | # --data: cora, citeseer, citeseer, Cornell, Texas, Wisconsin, Actor 25 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit 26 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17) 27 | # --weight_sparse or --feature_sparse --sparse (7) 28 | # --sparse: base or sparse train (2) 29 | 30 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4) 31 | # --growth_schedule: gradient, momentum, random (3) 32 | # --sparse_init: uniform, ERK (2) 33 | 34 | # --prune-rate : regenration rate : 0.1, 0.2, 0.3 (3) 35 | # --update-frequency 10 20 30 (3) 36 | # --final-prune-epoch 50 100 150 (3) 37 | 38 | # --init-density: weight init density: 1, (dense to sparse) 39 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4) 40 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 41 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 42 | 43 | 44 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000 45 | 46 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160 47 | 48 | 49 | # python main_stgnn.py --method GraNet \ 50 | # --prune-rate 0.5 \ 51 | # --optimizer adam \ 52 | # --sparse-init ERK \ 53 | # --init-density 0.5 \ 54 | # --final-density 0.1 \ 55 | # --update-frequency 10 \ 56 | # --l2 0.0005 \ 57 | # --lr 0.01 \ 58 | # --epochs 200 \ 59 | # --model gcn \ 60 | # --data cora \ 61 | # --final-prune-epoch 100 -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | for model in gcn gat sgc appnp 2 | do 3 | for data in cora citeseer pubmed Cornell Texas Wisconsin Actor CS Physics Computers Photo WikiCS ogbn-arxiv 4 | do 5 | for fde in 0.5 0.1 0.01 0.0001 6 | do 7 | for fda in 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 8 | do 9 | for pr in 0.1 0.2 0.3 10 | do 11 | for uf in 10 20 30 12 | do 13 | for fpe in 50 100 150 14 | do 15 | python main_stgnn.py --method GraNet \ 16 | --prune-rate $pr \ 17 | --optimizer adam \ 18 | --sparse-init ERK \ 19 | --init-density 1.0 \ 20 | --final-density $fde \ 21 | --final-density_adj $fda \ 22 | --final-density_feature 0.5 \ 23 | --update-frequency $uf \ 24 | --l2 0.0005 \ 25 | --lr 0.01 \ 26 | --cuda $1 \ 27 | --epochs 200 \ 28 | --model $model \ 29 | --data $data \ 30 | --final-prune-epoch $fpe \ 31 | --growth_schedule momentum \ 32 | --adj_sparse \ 33 | --weight_sparse \ 34 | --sparse 35 | done 36 | done 37 | done 38 | done 39 | done 40 | done 41 | done 42 | 43 | 44 | 45 | 46 | 47 | # --model: gcn, gat, sgc, appnp, gcnii (5) 48 | # --data: cora, citeseer, citeseer, Cornell, Texas, Wisconsin, Actor 49 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit 50 | # 
--data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17) 51 | # --weight_sparse or --feature_sparse --sparse (7) 52 | # --sparse: base or sparse train (2) 53 | 54 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4) 55 | # --growth_schedule: gradient, momentum, random (3) 56 | # --sparse_init: uniform, ERK (2) 57 | 58 | # --prune-rate : regenration rate : 0.1, 0.2, 0.3 (3) 59 | # --update-frequency 10 20 30 (3) 60 | # --final-prune-epoch 50 100 150 (3) 61 | 62 | # --init-density: weight init density: 1, (dense to sparse) 63 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4) 64 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 65 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 66 | 67 | 68 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000 69 | 70 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160 71 | 72 | 73 | # python main_stgnn.py --method GraNet \ 74 | # --prune-rate 0.5 \ 75 | # --optimizer adam \ 76 | # --sparse-init ERK \ 77 | # --init-density 0.5 \ 78 | # --final-density 0.1 \ 79 | # --update-frequency 10 \ 80 | # --l2 0.0005 \ 81 | # --lr 0.01 \ 82 | # --epochs 200 \ 83 | # --model gcn \ 84 | # --data cora \ 85 | # --final-prune-epoch 100 -------------------------------------------------------------------------------- /run_wf.sh: -------------------------------------------------------------------------------- 1 | for model in gcn gat sgc appnp 2 | do 3 | for data in cora citeseer pubmed Cornell Texas Wisconsin Actor CS Physics Computers Photo ogbn-arxiv 4 | do 5 | for fde in 0.8 0.5 0.1 0.01 0.001 6 | do 7 | for fdf in 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 8 | do 9 | for pr in 0.1 0.2 0.3 10 | do 11 | for uf in 10 20 30 12 | do 13 | for fpe in 50 100 150 14 | do 15 | python main_stgnn.py --method GraNet \ 16 | --prune-rate $pr \ 17 | --optimizer adam \ 18 | --sparse-init ERK \ 19 | --init-density 1.0 \ 20 | --final-density $fde \ 21 | --final-density_adj 1.0 \ 22 | --final-density_feature $fdf \ 23 | --update-frequency $uf \ 24 | --l2 0.0005 \ 25 | --lr 0.01 \ 26 | --cuda $1 \ 27 | --epochs 200 \ 28 | --model $model \ 29 | --data $data \ 30 | --final-prune-epoch $fpe \ 31 | --growth_schedule momentum \ 32 | --feature_sparse \ 33 | --weight_sparse \ 34 | --sparse 35 | done 36 | done 37 | done 38 | done 39 | done 40 | done 41 | done 42 | 43 | 44 | # cora citeseer 45 | 46 | 47 | 48 | # --model: gcn, gat, sgc, appnp, gcnii (5) 49 | # --data: cora, citeseer, citeseer, Cornell, Texas, Wisconsin, Actor 50 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit 51 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17) 52 | # --weight_sparse or --feature_sparse --sparse (7) 53 | # --sparse: base or sparse train (2) 54 | 55 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4) 56 | # --growth_schedule: gradient, momentum, random (3) 57 | # --sparse_init: uniform, ERK (2) 58 | 59 | # --prune-rate : regenration rate : 0.1, 0.2, 0.3 (3) 60 | # --update-frequency 10 20 30 (3) 61 | # --final-prune-epoch 50 100 150 (3) 62 | 63 | # --init-density: weight init density: 1, (dense to sparse) 64 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4) 65 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 66 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 67 | 68 | 69 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000 70 | 71 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160 72 | 73 | 74 | # python main_stgnn.py --method GraNet \ 
75 | # --prune-rate 0.5 \ 76 | # --optimizer adam \ 77 | # --sparse-init ERK \ 78 | # --init-density 0.5 \ 79 | # --final-density 0.1 \ 80 | # --update-frequency 10 \ 81 | # --l2 0.0005 \ 82 | # --lr 0.01 \ 83 | # --epochs 200 \ 84 | # --model gcn \ 85 | # --data cora \ 86 | # --final-prune-epoch 100 -------------------------------------------------------------------------------- /models/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_scatter import scatter_add 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.utils import add_remaining_self_loops 6 | from torch_geometric.nn.inits import glorot, zeros 7 | import pdb 8 | 9 | class GCNConv(MessagePassing): 10 | 11 | def __init__(self, in_channels, out_channels, improved=False, cached=False, 12 | bias=True, normalize=True, **kwargs): 13 | super(GCNConv, self).__init__(aggr='add', **kwargs) 14 | 15 | self.in_channels = in_channels 16 | self.out_channels = out_channels 17 | self.improved = improved 18 | self.cached = cached 19 | self.normalize = normalize 20 | 21 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 22 | 23 | if bias: 24 | self.bias = Parameter(torch.Tensor(out_channels)) 25 | else: 26 | self.register_parameter('bias', None) 27 | 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | glorot(self.weight) 32 | zeros(self.bias) 33 | self.cached_result = None 34 | self.cached_num_edges = None 35 | 36 | @staticmethod 37 | def norm(edge_index, num_nodes, edge_weight=None, improved=False, 38 | dtype=None): 39 | 40 | if edge_weight is None: 41 | edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, 42 | device=edge_index.device) 43 | 44 | 45 | fill_value = 1 if not improved else 2 46 | edge_index, edge_weight = add_remaining_self_loops( 47 | edge_index, edge_weight, fill_value, num_nodes) 48 | 49 | row, col = edge_index 50 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 51 | deg_inv_sqrt = deg.pow(-0.5) 52 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 53 | 54 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 55 | 56 | def forward(self, x, edge_index, edge_weight=None): 57 | """""" 58 | x = torch.matmul(x, self.weight) 59 | 60 | if self.cached and self.cached_result is not None: 61 | if edge_index.size(1) != self.cached_num_edges: 62 | raise RuntimeError( 63 | 'Cached {} number of edges, but found {}. 
Please ' 64 | 'disable the caching behavior of this layer by removing ' 65 | 'the `cached=True` argument in its constructor.'.format( 66 | self.cached_num_edges, edge_index.size(1))) 67 | 68 | if not self.cached or self.cached_result is None: 69 | self.cached_num_edges = edge_index.size(1) 70 | if self.normalize: 71 | edge_index, norm = self.norm(edge_index, x.size(self.node_dim), 72 | edge_weight, self.improved, 73 | x.dtype) 74 | else: 75 | norm = edge_weight 76 | self.cached_result = edge_index, norm 77 | 78 | edge_index, norm = self.cached_result 79 | 80 | return self.propagate(edge_index, x=x, norm=norm) 81 | 82 | def message(self, x_j, norm): 83 | 84 | return norm.view(-1, 1) * x_j if norm is not None else x_j 85 | 86 | def update(self, aggr_out): 87 | if self.bias is not None: 88 | aggr_out = aggr_out + self.bias 89 | return aggr_out 90 | 91 | def __repr__(self): 92 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 93 | self.out_channels) 94 | -------------------------------------------------------------------------------- /run_multi.sh: -------------------------------------------------------------------------------- 1 | wei_arr=( 2 | 0.0783 3 | 0.0620 4 | 0.0493 5 | 0.0394 6 | 0.0315 7 | 0.0252 8 | 0.0201 9 | 0.0161 10 | 0.0129 11 | 0.0103) 12 | 13 | adj_arr=( 14 | 0.5656 15 | 0.5372 16 | 0.5102 17 | 0.4843 18 | 0.4598 19 | 0.4368 20 | 0.4149 21 | 0.3941 22 | 0.3737 23 | 0.3547) 24 | 25 | 26 | for model in gcn 27 | do 28 | for data in cora citeseer pubmed Cornell Texas Wisconsin Computers Photo 29 | do 30 | for i in ${!wei_arr[@]} 31 | do 32 | for pr in 0.1 0.2 0.3 33 | do 34 | for uf in 10 20 30 35 | do 36 | for fpe in 50 100 150 37 | do 38 | python main_stgnn.py --method GraNet \ 39 | --prune-rate $pr \ 40 | --optimizer adam \ 41 | --sparse-init ERK \ 42 | --init-density 1.0 \ 43 | --final-density ${wei_arr[$i]} \ 44 | --final-density_adj ${adj_arr[$i]} \ 45 | --final-density_feature 1.0 \ 46 | --update-frequency $uf \ 47 | --l2 0.0005 \ 48 | --lr 0.01 \ 49 | --epochs 200 \ 50 | --model $model \ 51 | --data $data \ 52 | --final-prune-epoch $fpe \ 53 | --growth_schedule momentum \ 54 | --adj_sparse \ 55 | --weight_sparse \ 56 | --sparse 57 | done 58 | done 59 | done 60 | done 61 | done 62 | done 63 | 64 | 65 | 66 | 67 | 68 | # --model: gcn, gat, sgc, appnp, gcnii (5) 69 | # --data: cora, citeseer, citeseer, Cornell, Texas, Wisconsin, Actor 70 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit 71 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17) 72 | # --weight_sparse or --feature_sparse --sparse (7) 73 | # --sparse: base or sparse train (2) 74 | 75 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4) 76 | # --growth_schedule: gradient, momentum, random (3) 77 | # --sparse_init: uniform, ERK (2) 78 | 79 | # --prune-rate : regenration rate : 0.1, 0.2, 0.3 (3) 80 | # --update-frequency 10 20 30 (3) 81 | # --final-prune-epoch 50 100 150 (3) 82 | 83 | # --init-density: weight init density: 1, (dense to sparse) 84 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4) 85 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 86 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 87 | 88 | 89 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000 90 | 91 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160 92 | 93 | 94 | # python main_stgnn.py --method GraNet \ 95 | # --prune-rate 0.5 \ 96 | # --optimizer adam \ 97 | # --sparse-init ERK \ 98 | # --init-density 0.5 \ 99 | # 
--final-density 0.1 \ 100 | # --update-frequency 10 \ 101 | # --l2 0.0005 \ 102 | # --lr 0.01 \ 103 | # --epochs 200 \ 104 | # --model gcn \ 105 | # --data cora \ 106 | # --final-prune-epoch 100 -------------------------------------------------------------------------------- /run_waf.sh: -------------------------------------------------------------------------------- 1 | for model in gcn gat sgc appnp 2 | do 3 | for data in cora citeseer pubmed Cornell Texas Wisconsin Actor CS Physics Computers Photo ogbn-arxiv 4 | do 5 | for fde in 0.5 0.1 0.01 6 | do 7 | for fdf in 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 8 | do 9 | for fda in 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 10 | do 11 | for pr in 0.1 0.2 0.3 12 | do 13 | for uf in 10 20 30 14 | do 15 | for fpe in 50 100 150 16 | do 17 | python main_stgnn.py --method GraNet \ 18 | --prune-rate $pr \ 19 | --optimizer adam \ 20 | --sparse-init ERK \ 21 | --init-density 1.0 \ 22 | --final-density $fde \ 23 | --final-density_adj $fda \ 24 | --final-density_feature $fdf \ 25 | --update-frequency $uf \ 26 | --l2 0.0005 \ 27 | --lr 0.01 \ 28 | --cuda $1 \ 29 | --epochs 200 \ 30 | --model $model \ 31 | --data $data \ 32 | --final-prune-epoch $fpe \ 33 | --growth_schedule momentum \ 34 | --feature_sparse \ 35 | --weight_sparse \ 36 | --adj_sparse \ 37 | --sparse 38 | done 39 | done 40 | done 41 | done 42 | done 43 | done 44 | done 45 | done 46 | 47 | 48 | # cora citeseer 49 | 50 | 51 | 52 | # --model: gcn, gat, sgc, appnp, gcnii (5) 53 | # --data: cora, citeseer, citeseer, Cornell, Texas, Wisconsin, Actor 54 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit 55 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17) 56 | # --weight_sparse or --feature_sparse --sparse (7) 57 | # --sparse: base or sparse train (2) 58 | 59 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4) 60 | # --growth_schedule: gradient, momentum, random (3) 61 | # --sparse_init: uniform, ERK (2) 62 | 63 | # --prune-rate : regenration rate : 0.1, 0.2, 0.3 (3) 64 | # --update-frequency 10 20 30 (3) 65 | # --final-prune-epoch 50 100 150 (3) 66 | 67 | # --init-density: weight init density: 1, (dense to sparse) 68 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4) 69 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 70 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10) 71 | 72 | 73 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000 74 | 75 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160 76 | 77 | 78 | # python main_stgnn.py --method GraNet \ 79 | # --prune-rate 0.5 \ 80 | # --optimizer adam \ 81 | # --sparse-init ERK \ 82 | # --init-density 0.5 \ 83 | # --final-density 0.1 \ 84 | # --update-frequency 10 \ 85 | # --l2 0.0005 \ 86 | # --lr 0.01 \ 87 | # --epochs 200 \ 88 | # --model gcn \ 89 | # --data cora \ 90 | # --final-prune-epoch 100 -------------------------------------------------------------------------------- /models/fagcn_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | from torch_sparse import SparseTensor 6 | 7 | from torch_geometric.nn.conv import MessagePassing 8 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 9 | from torch_geometric.nn.dense.linear import Linear 10 | from torch_geometric.typing import Adj, OptTensor 11 | 12 | 13 | class FAConv(MessagePassing): 14 | 15 | _cached_edge_index: Optional[Tuple[Tensor, Tensor]] 
16 | _cached_adj_t: Optional[SparseTensor] 17 | _alpha: OptTensor 18 | 19 | def __init__(self, channels: int, eps: float = 0.1, dropout: float = 0.0, 20 | cached: bool = False, add_self_loops: bool = True, 21 | normalize: bool = True, **kwargs): 22 | 23 | kwargs.setdefault('aggr', 'add') 24 | super(FAConv, self).__init__(**kwargs) 25 | 26 | self.channels = channels 27 | self.eps = eps 28 | self.dropout = dropout 29 | self.cached = cached 30 | self.add_self_loops = add_self_loops 31 | self.normalize = normalize 32 | 33 | self._cached_edge_index = None 34 | self._cached_adj_t = None 35 | self._alpha = None 36 | 37 | self.att_l = Linear(channels, 1, bias=False) 38 | self.att_r = Linear(channels, 1, bias=False) 39 | 40 | self.reset_parameters() 41 | 42 | def reset_parameters(self): 43 | self.att_l.reset_parameters() 44 | self.att_r.reset_parameters() 45 | self._cached_edge_index = None 46 | self._cached_adj_t = None 47 | 48 | 49 | def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj, edge_weight: OptTensor = None, return_attention_weights=None): 50 | if self.normalize: 51 | if isinstance(edge_index, Tensor): 52 | #assert edge_weight is None 53 | cache = self._cached_edge_index 54 | if cache is None: 55 | edge_index, edge_weight = gcn_norm( # yapf: disable 56 | edge_index, None, x.size(self.node_dim), False, 57 | self.add_self_loops, dtype=x.dtype) 58 | if self.cached: 59 | self._cached_edge_index = (edge_index, edge_weight) 60 | else: 61 | edge_index, edge_weight = cache[0], cache[1] 62 | 63 | elif isinstance(edge_index, SparseTensor): 64 | assert not edge_index.has_value() 65 | cache = self._cached_adj_t 66 | if cache is None: 67 | edge_index = gcn_norm( # yapf: disable 68 | edge_index, None, x.size(self.node_dim), False, 69 | self.add_self_loops, dtype=x.dtype) 70 | if self.cached: 71 | self._cached_adj_t = edge_index 72 | else: 73 | edge_index = cache 74 | else: 75 | if isinstance(edge_index, Tensor): 76 | assert edge_weight is not None 77 | elif isinstance(edge_index, SparseTensor): 78 | assert edge_index.has_value() 79 | 80 | alpha_l = self.att_l(x) 81 | alpha_r = self.att_r(x) 82 | 83 | # propagate_type: (x: Tensor, alpha: PairTensor, edge_weight: OptTensor) # noqa 84 | out = self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r), 85 | edge_weight=edge_weight, size=None) 86 | 87 | alpha = self._alpha 88 | self._alpha = None 89 | 90 | if self.eps != 0.0: 91 | out += self.eps * x_0 92 | 93 | if isinstance(return_attention_weights, bool): 94 | assert alpha is not None 95 | if isinstance(edge_index, Tensor): 96 | return out, (edge_index, alpha) 97 | elif isinstance(edge_index, SparseTensor): 98 | return out, edge_index.set_value(alpha, layout='coo') 99 | else: 100 | return out 101 | 102 | 103 | def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor, 104 | edge_weight: OptTensor) -> Tensor: 105 | assert edge_weight is not None 106 | alpha = (alpha_j + alpha_i).tanh().squeeze(-1) 107 | self._alpha = alpha 108 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 109 | return x_j * (alpha * edge_weight).view(-1, 1) 110 | 111 | def __repr__(self) -> str: 112 | return f'{self.__class__.__name__}({self.channels}, eps={self.eps})' -------------------------------------------------------------------------------- /sparselearning/sparse_sgd.py: -------------------------------------------------------------------------------- 1 | from torch.optim.optimizer import Optimizer, required 2 | import torch 3 | import numpy as np 4 | class sparse_SGD(Optimizer): 5 | r"""Implements 
sparse stochastic gradient descent (optionally with momentum), according to the pytorch version 1.5.1. 6 | 7 | Nesterov momentum is based on the formula from 8 | `On the importance of initialization and momentum in deep learning`__. 9 | 10 | Args: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float): learning rate 14 | momentum (float, optional): momentum factor (default: 0) 15 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 16 | dampening (float, optional): dampening for momentum (default: 0) 17 | nesterov (bool, optional): enables Nesterov momentum (default: False) 18 | 19 | Example: 20 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 21 | >>> optimizer.zero_grad() 22 | >>> loss_fn(model(input), target).backward() 23 | >>> optimizer.step() 24 | 25 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 26 | 27 | .. note:: 28 | The implementation of SGD with Momentum/Nesterov subtly differs from 29 | Sutskever et. al. and implementations in some other frameworks. 30 | 31 | Considering the specific case of Momentum, the update can be written as 32 | 33 | .. math:: 34 | v = \rho * v + g \\ 35 | p = p - lr * v 36 | 37 | where p, g, v and :math:`\rho` denote the parameters, gradient, 38 | velocity, and momentum respectively. 39 | 40 | This is in contrast to Sutskever et. al. and 41 | other frameworks which employ an update of the form 42 | 43 | .. math:: 44 | v = \rho * v + lr * g \\ 45 | p = p - v 46 | 47 | The Nesterov version is analogously modified. 48 | """ 49 | 50 | def __init__(self, params, lr=required, momentum=0, dampening=0, 51 | weight_decay=0, nesterov=False): 52 | if lr is not required and lr < 0.0: 53 | raise ValueError("Invalid learning rate: {}".format(lr)) 54 | if momentum < 0.0: 55 | raise ValueError("Invalid momentum value: {}".format(momentum)) 56 | if weight_decay < 0.0: 57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 58 | 59 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 60 | weight_decay=weight_decay, nesterov=nesterov) 61 | if nesterov and (momentum <= 0 or dampening != 0): 62 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 63 | super(sparse_SGD, self).__init__(params, defaults) 64 | 65 | def __setstate__(self, state): 66 | super(sparse_SGD, self).__setstate__(state) 67 | for group in self.param_groups: 68 | group.setdefault('nesterov', False) 69 | 70 | @torch.no_grad() 71 | def step(self, closure=None, nonzero_masks=None, new_masks=None, gamma=None, epoch=None): 72 | """Performs a single optimization step. 73 | 74 | Arguments: 75 | closure (callable, optional): A closure that reevaluates the model 76 | and returns the loss. 
77 | """ 78 | loss = None 79 | if closure is not None: 80 | with torch.enable_grad(): 81 | loss = closure() 82 | 83 | if epoch <= 100: 84 | for group in self.param_groups: 85 | weight_decay = group['weight_decay'] 86 | momentum = group['momentum'] 87 | dampening = group['dampening'] 88 | nesterov = group['nesterov'] 89 | 90 | for p in group['params']: 91 | if p.grad is None: 92 | continue 93 | d_p = p.grad 94 | if weight_decay != 0: 95 | d_p = d_p.add(p, alpha=weight_decay) 96 | if momentum != 0: 97 | param_state = self.state[p] 98 | if 'momentum_buffer' not in param_state: 99 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 100 | else: 101 | buf = param_state['momentum_buffer'] 102 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 103 | if nesterov: 104 | d_p = d_p.add(buf, alpha=momentum) 105 | else: 106 | d_p = buf 107 | 108 | p.add_(d_p, alpha=-group['lr']) 109 | else: 110 | for group in self.param_groups: 111 | weight_decay = group['weight_decay'] 112 | momentum = group['momentum'] 113 | dampening = group['dampening'] 114 | nesterov = group['nesterov'] 115 | 116 | for i, p in enumerate(group['params']): 117 | if p.grad is None: 118 | continue 119 | 120 | sparse_layer_flag = False 121 | for key in nonzero_masks.keys(): 122 | if i == float(key.split('_')[-1]): 123 | nonzero_mask = nonzero_masks[key] 124 | new_mask = new_masks[key] 125 | sparse_layer_flag = True 126 | 127 | d_p = p.grad 128 | if weight_decay != 0: 129 | d_p = d_p.add(p, alpha=weight_decay) 130 | if momentum != 0: 131 | param_state = self.state[p] 132 | if 'momentum_buffer' not in param_state: 133 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 134 | else: 135 | buf = param_state['momentum_buffer'] 136 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 137 | if nesterov: 138 | d_p = d_p.add(buf, alpha=momentum) 139 | else: 140 | d_p = buf 141 | 142 | p.add_(d_p, alpha=-group['lr']) 143 | 144 | if sparse_layer_flag: 145 | p.add_(d_p * nonzero_mask, alpha=-group['lr']) 146 | p.add_(d_p * new_mask, alpha=-gamma) 147 | 148 | else: 149 | p.add_(d_p, alpha=-group['lr']) 150 | 151 | return loss -------------------------------------------------------------------------------- /models/gat_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Optional 2 | from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, OptTensor) 3 | import torch 4 | from torch import Tensor 5 | import torch.nn.functional as F 6 | from torch.nn import Parameter, Linear 7 | from torch_sparse import SparseTensor, set_diag 8 | from torch_geometric.nn.conv import MessagePassing 9 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax 10 | from torch_geometric.nn.inits import glorot, zeros 11 | import pdb 12 | 13 | class GATConv(MessagePassing): 14 | 15 | _alpha: OptTensor 16 | 17 | def __init__(self, in_channels: Union[int, Tuple[int, int]], 18 | out_channels: int, heads: int = 1, concat: bool = True, 19 | negative_slope: float = 0.2, dropout: float = 0.0, 20 | add_self_loops: bool = True, bias: bool = True, **kwargs): 21 | kwargs.setdefault('aggr', 'add') 22 | super(GATConv, self).__init__(node_dim=0, **kwargs) 23 | 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.heads = heads 27 | self.concat = concat 28 | self.negative_slope = negative_slope 29 | self.dropout = dropout 30 | self.add_self_loops = add_self_loops 31 | 32 | if isinstance(in_channels, int): 33 | self.lin_l = 
Linear(in_channels, heads * out_channels, bias=False) 34 | self.lin_r = self.lin_l 35 | else: 36 | self.lin_l = Linear(in_channels[0], heads * out_channels, False) 37 | self.lin_r = Linear(in_channels[1], heads * out_channels, False) 38 | 39 | self.att_l = Parameter(torch.Tensor(1, heads, out_channels)) 40 | self.att_r = Parameter(torch.Tensor(1, heads, out_channels)) 41 | 42 | if bias and concat: 43 | self.bias = Parameter(torch.Tensor(heads * out_channels)) 44 | elif bias and not concat: 45 | self.bias = Parameter(torch.Tensor(out_channels)) 46 | else: 47 | self.register_parameter('bias', None) 48 | 49 | self._alpha = None 50 | 51 | self.reset_parameters() 52 | 53 | def reset_parameters(self): 54 | glorot(self.lin_l.weight) 55 | glorot(self.lin_r.weight) 56 | glorot(self.att_l) 57 | glorot(self.att_r) 58 | zeros(self.bias) 59 | 60 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 61 | size: Size = None, return_attention_weights=None, edge_weight=None): 62 | # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa 63 | # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa 64 | # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa 65 | # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa 66 | r""" 67 | Args: 68 | return_attention_weights (bool, optional): If set to :obj:`True`, 69 | will additionally return the tuple 70 | :obj:`(edge_index, attention_weights)`, holding the computed 71 | attention weights for each edge. (default: :obj:`None`) 72 | """ 73 | H, C = self.heads, self.out_channels # 4, 256 74 | 75 | x_l: OptTensor = None 76 | x_r: OptTensor = None 77 | alpha_l: OptTensor = None 78 | alpha_r: OptTensor = None 79 | if isinstance(x, Tensor): 80 | assert x.dim() == 2, 'Static graphs not supported in `GATConv`.' 81 | x_l = x_r = self.lin_l(x).view(-1, H, C) 82 | alpha_l = (x_l * self.att_l).sum(dim=-1) 83 | alpha_r = (x_r * self.att_r).sum(dim=-1) 84 | else: 85 | x_l, x_r = x[0], x[1] 86 | assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.' 
87 | x_l = self.lin_l(x_l).view(-1, H, C) 88 | alpha_l = (x_l * self.att_l).sum(dim=-1) 89 | if x_r is not None: 90 | x_r = self.lin_r(x_r).view(-1, H, C) 91 | alpha_r = (x_r * self.att_r).sum(dim=-1) 92 | 93 | assert x_l is not None 94 | assert alpha_l is not None 95 | 96 | if self.add_self_loops: 97 | if isinstance(edge_index, Tensor): 98 | num_nodes = x_l.size(0) 99 | if x_r is not None: 100 | num_nodes = min(num_nodes, x_r.size(0)) 101 | if size is not None: 102 | num_nodes = min(size[0], size[1]) 103 | edge_index, _ = remove_self_loops(edge_index) 104 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) 105 | if edge_weight is not None: 106 | loop_weight = torch.full((num_nodes, ), 1, 107 | dtype=edge_weight.dtype, 108 | device=edge_weight.device) 109 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 110 | elif isinstance(edge_index, SparseTensor): 111 | edge_index = set_diag(edge_index) 112 | 113 | # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) 114 | out = self.propagate(edge_index, 115 | x=(x_l, x_r), 116 | alpha=(alpha_l, alpha_r), 117 | size=size, 118 | edge_weight=edge_weight) 119 | 120 | alpha = self._alpha 121 | self._alpha = None 122 | 123 | if self.concat: 124 | out = out.view(-1, self.heads * self.out_channels) 125 | else: 126 | out = out.mean(dim=1) 127 | 128 | if self.bias is not None: 129 | out += self.bias 130 | 131 | if isinstance(return_attention_weights, bool): 132 | assert alpha is not None 133 | if isinstance(edge_index, Tensor): 134 | return out, (edge_index, alpha) 135 | elif isinstance(edge_index, SparseTensor): 136 | return out, edge_index.set_value(alpha, layout='coo') 137 | else: 138 | return out 139 | 140 | def message(self, 141 | x_j: Tensor, 142 | alpha_j: Tensor, 143 | alpha_i: OptTensor, 144 | index: Tensor, 145 | ptr: OptTensor, 146 | size_i: Optional[int], 147 | edge_weight: Tensor) -> Tensor: 148 | alpha = alpha_j if alpha_i is None else alpha_j + alpha_i 149 | alpha = F.leaky_relu(alpha, self.negative_slope) 150 | alpha = softmax(alpha, index, ptr, size_i) 151 | self._alpha = alpha 152 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 153 | if edge_weight is None: 154 | return x_j * alpha.unsqueeze(-1) 155 | else: 156 | return x_j * alpha.unsqueeze(-1) * edge_weight.expand(alpha.shape[1], alpha.shape[0]).t().unsqueeze(-1) 157 | 158 | def __repr__(self): 159 | return '{}({}, {}, heads={})'.format(self.__class__.__name__, 160 | self.in_channels, 161 | self.out_channels, 162 | self.heads) 163 | -------------------------------------------------------------------------------- /sparselearning/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | class SparseSpeedupBench(object): 9 | """Class to benchmark speedups for convolutional layers. 10 | 11 | Basic usage: 12 | 1. Assing a single SparseSpeedupBench instance to class (and sub-classes with conv layers). 13 | 2. Instead of forwarding input through normal convolutional layers, we pass them through the bench: 14 | self.bench = SparseSpeedupBench() 15 | self.conv_layer1 = nn.Conv2(3, 96, 3) 16 | 17 | if self.bench is not None: 18 | outputs = self.bench.forward(self.conv_layer1, inputs, layer_id='conv_layer1') 19 | else: 20 | outputs = self.conv_layer1(inputs) 21 | 3. Speedups of the convolutional layer will be aggregated and print every 1000 mini-batches. 
22 | """ 23 | def __init__(self): 24 | self.layer_timings = {} 25 | self.layer_timings_channel_sparse = {} 26 | self.layer_timings_sparse = {} 27 | self.iter_idx = 0 28 | self.layer_0_idx = None 29 | self.total_timings = [] 30 | self.total_timings_channel_sparse = [] 31 | self.total_timings_sparse = [] 32 | 33 | def get_density(self, x): 34 | return (x.data!=0.0).sum().item()/x.numel() 35 | 36 | def print_weights(self, w, layer): 37 | # w dims: out, in, k1, k2 38 | #outers = [] 39 | #for outer in range(w.shape[0]): 40 | # inners = [] 41 | # for inner in range(w.shape[1]): 42 | # n = np.prod(w.shape[2:]) 43 | # density = (w[outer, inner, :, :] != 0.0).sum().item() / n 44 | # #print(density, w[outer, inner]) 45 | # inners.append(density) 46 | # outers.append([np.mean(inners), np.std(inner)]) 47 | #print(outers) 48 | #print(w.shape, (w!=0.0).sum().item()/w.numel()) 49 | pass 50 | 51 | def forward(self, layer, x, layer_id): 52 | if self.layer_0_idx is None: self.layer_0_idx = layer_id 53 | if layer_id == self.layer_0_idx: self.iter_idx += 1 54 | self.print_weights(layer.weight.data, layer) 55 | 56 | # calc input sparsity 57 | sparse_channels_in = ((x.data != 0.0).sum([2, 3]) == 0.0).sum().item() 58 | num_channels_in = x.shape[1] 59 | batch_size = x.shape[0] 60 | channel_sparsity_input = sparse_channels_in/float(num_channels_in*batch_size) 61 | input_sparsity = self.get_density(x) 62 | 63 | # bench dense layer 64 | start = torch.cuda.Event(enable_timing=True) 65 | end = torch.cuda.Event(enable_timing=True) 66 | start.record() 67 | x = layer(x) 68 | end.record() 69 | start.synchronize() 70 | end.synchronize() 71 | time_taken_s = start.elapsed_time(end)/1000.0 72 | 73 | # calc weight sparsity 74 | num_channels = layer.weight.shape[1] 75 | sparse_channels = ((layer.weight.data != 0.0).sum([0, 2, 3]) == 0.0).sum().item() 76 | channel_sparsity_weight = sparse_channels/float(num_channels) 77 | weight_sparsity = self.get_density(layer.weight) 78 | 79 | # store sparse and dense timings 80 | if layer_id not in self.layer_timings: 81 | self.layer_timings[layer_id] = [] 82 | self.layer_timings_channel_sparse[layer_id] = [] 83 | self.layer_timings_sparse[layer_id] = [] 84 | self.layer_timings[layer_id].append(time_taken_s) 85 | self.layer_timings_channel_sparse[layer_id].append(time_taken_s*(1.0-channel_sparsity_weight)*(1.0-channel_sparsity_input)) 86 | self.layer_timings_sparse[layer_id].append(time_taken_s*input_sparsity*weight_sparsity) 87 | 88 | if self.iter_idx % 1000 == 0: 89 | self.print_layer_timings() 90 | self.iter_idx += 1 91 | 92 | return x 93 | 94 | def print_layer_timings(self): 95 | total_time_dense = 0.0 96 | total_time_sparse = 0.0 97 | total_time_channel_sparse = 0.0 98 | print('\n') 99 | for layer_id in self.layer_timings: 100 | t_dense = np.mean(self.layer_timings[layer_id]) 101 | t_channel_sparse = np.mean(self.layer_timings_channel_sparse[layer_id]) 102 | t_sparse = np.mean(self.layer_timings_sparse[layer_id]) 103 | total_time_dense += t_dense 104 | total_time_sparse += t_sparse 105 | total_time_channel_sparse += t_channel_sparse 106 | 107 | print('Layer {0}: Dense {1:.6f} Channel Sparse {2:.6f} vs Full Sparse {3:.6f}'.format(layer_id, t_dense, t_channel_sparse, t_sparse)) 108 | self.total_timings.append(total_time_dense) 109 | self.total_timings_sparse.append(total_time_sparse) 110 | self.total_timings_channel_sparse.append(total_time_channel_sparse) 111 | 112 | print('Speedups for this segment:') 113 | print('Dense took {0:.4f}s. Channel Sparse took {1:.4f}s. 
Speedup of {2:.4f}x'.format(total_time_dense, total_time_channel_sparse, total_time_dense/total_time_channel_sparse)) 114 | print('Dense took {0:.4f}s. Sparse took {1:.4f}s. Speedup of {2:.4f}x'.format(total_time_dense, total_time_sparse, total_time_dense/total_time_sparse)) 115 | print('\n') 116 | 117 | total_dense = np.sum(self.total_timings) 118 | total_sparse = np.sum(self.total_timings_sparse) 119 | total_channel_sparse = np.sum(self.total_timings_channel_sparse) 120 | print('Speedups for entire training:') 121 | print('Dense took {0:.4f}s. Channel Sparse took {1:.4f}s. Speedup of {2:.4f}x'.format(total_dense, total_channel_sparse, total_dense/total_channel_sparse)) 122 | print('Dense took {0:.4f}s. Sparse took {1:.4f}s. Speedup of {2:.4f}x'.format(total_dense, total_sparse, total_dense/total_sparse)) 123 | print('\n') 124 | 125 | # clear timings 126 | for layer_id in list(self.layer_timings.keys()): 127 | self.layer_timings.pop(layer_id) 128 | self.layer_timings_channel_sparse.pop(layer_id) 129 | self.layer_timings_sparse.pop(layer_id) 130 | 131 | 132 | 133 | class AlexNet(nn.Module): 134 | """AlexNet with batch normalization and without pooling. 135 | 136 | This is an adapted version of AlexNet as taken from 137 | SNIP: Single-shot Network Pruning based on Connection Sensitivity, 138 | https://arxiv.org/abs/1810.02340 139 | 140 | There are two different version of AlexNet: 141 | AlexNet-s (small): Has hidden layers with size 1024 142 | AlexNet-b (big): Has hidden layers with size 2048 143 | 144 | Based on https://github.com/mi-lad/snip/blob/master/train.py 145 | by Milad Alizadeh. 146 | """ 147 | 148 | def __init__(self, config='s', num_classes=1000, save_features=False, bench_model=False): 149 | super(AlexNet, self).__init__() 150 | self.save_features = save_features 151 | self.feats = [] 152 | self.densities = [] 153 | self.bench = None if not bench_model else SparseSpeedupBench() 154 | 155 | factor = 1 if config=='s' else 2 156 | self.features = nn.Sequential( 157 | nn.Conv2d(3, 96, kernel_size=11, stride=2, padding=2, bias=True), 158 | nn.BatchNorm2d(96), 159 | nn.ReLU(inplace=True), 160 | nn.Conv2d(96, 256, kernel_size=5, stride=2, padding=2, bias=True), 161 | nn.BatchNorm2d(256), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(256, 384, kernel_size=3, stride=2, padding=1, bias=True), 164 | nn.BatchNorm2d(384), 165 | nn.ReLU(inplace=True), 166 | nn.Conv2d(384, 384, kernel_size=3, stride=2, padding=1, bias=True), 167 | nn.BatchNorm2d(384), 168 | nn.ReLU(inplace=True), 169 | nn.Conv2d(384, 256, kernel_size=3, stride=2, padding=1, bias=True), 170 | nn.BatchNorm2d(256), 171 | nn.ReLU(inplace=True), 172 | ) 173 | self.classifier = nn.Sequential( 174 | nn.Linear(256, 1024*factor), 175 | nn.BatchNorm1d(1024*factor), 176 | nn.ReLU(inplace=True), 177 | nn.Linear(1024*factor, 1024*factor), 178 | nn.BatchNorm1d(1024*factor), 179 | nn.ReLU(inplace=True), 180 | nn.Linear(1024*factor, num_classes), 181 | ) 182 | 183 | def forward(self, x): 184 | for layer_id, layer in enumerate(self.features): 185 | if self.bench is not None and isinstance(layer, nn.Conv2d): 186 | x = self.bench.forward(layer, x, layer_id) 187 | else: 188 | x = layer(x) 189 | 190 | if self.save_features: 191 | if isinstance(layer, nn.ReLU): 192 | self.feats.append(x.clone().detach()) 193 | if isinstance(layer, nn.Conv2d): 194 | self.densities.append((layer.weight.data != 0.0).sum().item()/layer.weight.numel()) 195 | 196 | x = x.view(x.size(0), -1) 197 | x = self.classifier(x) 198 | return F.log_softmax(x, dim=1) 199 | 200 | 
class LeNet_300_100(nn.Module): 201 | """Simple NN with hidden layers [300, 100] 202 | 203 | Based on https://github.com/mi-lad/snip/blob/master/train.py 204 | by Milad Alizadeh. 205 | """ 206 | def __init__(self, save_features=None, bench_model=False): 207 | super(LeNet_300_100, self).__init__() 208 | self.fc1 = nn.Linear(28*28, 300, bias=True) 209 | self.fc2 = nn.Linear(300, 100, bias=True) 210 | self.fc3 = nn.Linear(100, 10, bias=True) 211 | self.mask = None 212 | 213 | def forward(self, x): 214 | x0 = x.view(-1, 28*28) 215 | x1 = F.relu(self.fc1(x0)) 216 | x2 = F.relu(self.fc2(x1)) 217 | x3 = self.fc3(x2) 218 | return F.log_softmax(x3, dim=1) 219 | 220 | class MLP_CIFAR10(nn.Module): 221 | def __init__(self, save_features=None, bench_model=False): 222 | super(MLP_CIFAR10, self).__init__() 223 | 224 | self.fc1 = nn.Linear(3*32*32, 1024) 225 | self.fc2 = nn.Linear(1024, 512) 226 | self.fc3 = nn.Linear(512, 10) 227 | 228 | def forward(self, x): 229 | x0 = F.relu(self.fc1(x.view(-1, 3*32*32))) 230 | x1 = F.relu(self.fc2(x0)) 231 | return F.log_softmax(self.fc3(x1), dim=1) 232 | 233 | 234 | class LeNet_5_Caffe(nn.Module): 235 | """LeNet-5 without padding in the first layer. 236 | This is based on Caffe's implementation of Lenet-5 and is slightly different 237 | from the vanilla LeNet-5. Note that the first layer does NOT have padding 238 | and therefore intermediate shapes do not match the official LeNet-5. 239 | 240 | Based on https://github.com/mi-lad/snip/blob/master/train.py 241 | by Milad Alizadeh. 242 | """ 243 | 244 | def __init__(self, save_features=None, bench_model=False): 245 | super().__init__() 246 | self.conv1 = nn.Conv2d(1, 20, 5, padding=0, bias=True) 247 | self.conv2 = nn.Conv2d(20, 50, 5, bias=True) 248 | self.fc3 = nn.Linear(50 * 4 * 4, 500) 249 | self.fc4 = nn.Linear(500, 10) 250 | 251 | def forward(self, x): 252 | x = F.relu(self.conv1(x)) 253 | x = F.max_pool2d(x, 2) 254 | x = F.relu(self.conv2(x)) 255 | x = F.max_pool2d(x, 2) 256 | x = F.relu(self.fc3(x.view(-1, 50 * 4 * 4))) 257 | x = F.log_softmax(self.fc4(x), dim=1) 258 | 259 | return x 260 | 261 | 262 | VGG_CONFIGS = { 263 | # M for MaxPool, Number for channels 264 | 'like': [ 265 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 266 | 512, 512, 512, 'M' 267 | ], 268 | 'D': [ 269 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 270 | 512, 512, 512, 'M' 271 | ], 272 | 'C': [ 273 | 64, 64, 'M', 128, 128, 'M', 256, 256, (1, 256), 'M', 512, 512, (1, 512), 'M', 274 | 512, 512, (1, 512), 'M' # tuples indicate (kernel size, output channels) 275 | ] 276 | } 277 | 278 | 279 | class VGG16(nn.Module): 280 | """ 281 | This is a base class to generate three VGG variants used in SNIP paper: 282 | 1. VGG-C (16 layers) 283 | 2. VGG-D (16 layers) 284 | 3. VGG-like 285 | 286 | Some of the differences: 287 | * Reduced size of FC lis ayers to 512 288 | * Adjusted flattening to match CIFAR-10 shapes 289 | * Replaced dropout layers with BatchNorm 290 | 291 | Based on https://github.com/mi-lad/snip/blob/master/train.py 292 | by Milad Alizadeh. 
293 | """ 294 | 295 | def __init__(self, config, num_classes=10, save_features=False, bench_model=False): 296 | super().__init__() 297 | 298 | self.features = self.make_layers(VGG_CONFIGS[config], batch_norm=True) 299 | self.feats = [] 300 | self.densities = [] 301 | self.save_features = save_features 302 | self.bench = None if not bench_model else SparseSpeedupBench() 303 | 304 | if config == 'C' or config == 'D': 305 | self.classifier = nn.Sequential( 306 | nn.Linear((512 if config == 'D' else 2048), 512), # 512 * 7 * 7 in the original VGG 307 | nn.ReLU(True), 308 | nn.BatchNorm1d(512), # instead of dropout 309 | nn.Linear(512, 512), 310 | nn.ReLU(True), 311 | nn.BatchNorm1d(512), # instead of dropout 312 | nn.Linear(512, num_classes), 313 | ) 314 | else: 315 | self.classifier = nn.Sequential( 316 | nn.Linear(512, 512), # 512 * 7 * 7 in the original VGG 317 | nn.ReLU(True), 318 | nn.BatchNorm1d(512), # instead of dropout 319 | nn.Linear(512, num_classes), 320 | ) 321 | 322 | @staticmethod 323 | def make_layers(config, batch_norm=False): 324 | layers = [] 325 | in_channels = 3 326 | for v in config: 327 | if v == 'M': 328 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 329 | else: 330 | kernel_size = 3 331 | if isinstance(v, tuple): 332 | kernel_size, v = v 333 | conv2d = nn.Conv2d(in_channels, v, kernel_size=kernel_size, padding=1) 334 | if batch_norm: 335 | layers += [ 336 | conv2d, 337 | nn.BatchNorm2d(v), 338 | nn.ReLU(inplace=True) 339 | ] 340 | else: 341 | layers += [conv2d, nn.ReLU(inplace=True)] 342 | in_channels = v 343 | return nn.Sequential(*layers) 344 | 345 | def forward(self, x): 346 | for layer_id, layer in enumerate(self.features): 347 | if self.bench is not None and isinstance(layer, nn.Conv2d): 348 | x = self.bench.forward(layer, x, layer_id) 349 | else: 350 | x = layer(x) 351 | 352 | if self.save_features: 353 | if isinstance(layer, nn.ReLU): 354 | self.feats.append(x.clone().detach()) 355 | self.densities.append((x.data != 0.0).sum().item()/x.numel()) 356 | 357 | x = x.view(x.size(0), -1) 358 | x = self.classifier(x) 359 | x = F.log_softmax(x, dim=1) 360 | return x 361 | 362 | class VGG16_Srelu(nn.Module): 363 | """ 364 | This is a base class to generate three VGG variants used in SNIP paper: 365 | 1. VGG-C (16 layers) 366 | 2. VGG-D (16 layers) 367 | 3. VGG-like 368 | 369 | Some of the differences: 370 | * Reduced size of FC layers to 512 371 | * Adjusted flattening to match CIFAR-10 shapes 372 | * Replaced dropout layers with BatchNorm 373 | 374 | Based on https://github.com/mi-lad/snip/blob/master/train.py 375 | by Milad Alizadeh. 
376 | """ 377 | 378 | def __init__(self, config, num_classes=10, save_features=False, bench_model=False): 379 | super().__init__() 380 | 381 | self.features = self.make_layers(VGG_CONFIGS[config], batch_norm=True) 382 | self.feats = [] 383 | self.densities = [] 384 | self.save_features = save_features 385 | self.bench = None if not bench_model else SparseSpeedupBench() 386 | 387 | if config == 'C' or config == 'D': 388 | self.classifier = nn.Sequential( 389 | nn.Linear((512 if config == 'D' else 2048), 512), # 512 * 7 * 7 in the original VGG 390 | nn.ReLU(True), 391 | nn.BatchNorm1d(512), # instead of dropout 392 | nn.Linear(512, 512), 393 | nn.ReLU(True), 394 | nn.BatchNorm1d(512), # instead of dropout 395 | nn.Linear(512, num_classes), 396 | ) 397 | else: 398 | self.classifier = nn.Sequential( 399 | nn.Linear(512, 512), # 512 * 7 * 7 in the original VGG 400 | nn.ReLU(True), 401 | nn.BatchNorm1d(512), # instead of dropout 402 | nn.Linear(512, num_classes), 403 | ) 404 | 405 | @staticmethod 406 | def make_layers(config, batch_norm=False): 407 | layers = [] 408 | in_channels = 3 409 | for v in config: 410 | if v == 'M': 411 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 412 | else: 413 | kernel_size = 3 414 | if isinstance(v, tuple): 415 | kernel_size, v = v 416 | conv2d = nn.Conv2d(in_channels, v, kernel_size=kernel_size, padding=1) 417 | if batch_norm: 418 | layers += [ 419 | conv2d, 420 | nn.BatchNorm2d(v), 421 | nn.ReLU(inplace=True) 422 | ] 423 | else: 424 | layers += [conv2d, nn.ReLU(inplace=True)] 425 | in_channels = v 426 | return nn.Sequential(*layers) 427 | 428 | def forward(self, x): 429 | for layer_id, layer in enumerate(self.features): 430 | if self.bench is not None and isinstance(layer, nn.Conv2d): 431 | x = self.bench.forward(layer, x, layer_id) 432 | else: 433 | x = layer(x) 434 | 435 | if self.save_features: 436 | if isinstance(layer, nn.ReLU): 437 | self.feats.append(x.clone().detach()) 438 | self.densities.append((x.data != 0.0).sum().item()/x.numel()) 439 | 440 | x = x.view(x.size(0), -1) 441 | x = self.classifier(x) 442 | x = F.log_softmax(x, dim=1) 443 | return x 444 | 445 | class WideResNet(nn.Module): 446 | """Wide Residual Network with varying depth and width. 
447 | 448 | For more info, see the paper: Wide Residual Networks by Sergey Zagoruyko, Nikos Komodakis 449 | https://arxiv.org/abs/1605.07146 450 | """ 451 | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.3, save_features=False, bench_model=False): 452 | super(WideResNet, self).__init__() 453 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 454 | assert((depth - 4) % 6 == 0) 455 | n = (depth - 4) / 6 456 | block = BasicBlock 457 | # 1st conv before any network block 458 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 459 | padding=1, bias=False) 460 | self.bench = None if not bench_model else SparseSpeedupBench() 461 | # 1st block 462 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, save_features=save_features, bench=self.bench) 463 | # 2nd block 464 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate, save_features=save_features, bench=self.bench) 465 | # 3rd block 466 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate, save_features=save_features, bench=self.bench) 467 | # global average pooling and classifier 468 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 469 | self.relu = nn.ReLU(inplace=True) 470 | self.fc = nn.Linear(nChannels[3], num_classes) 471 | self.nChannels = nChannels[3] 472 | self.feats = [] 473 | self.densities = [] 474 | self.save_features = save_features 475 | 476 | for m in self.modules(): 477 | if isinstance(m, nn.Conv2d): 478 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 479 | m.weight.data.normal_(0, math.sqrt(2. / n)) 480 | elif isinstance(m, nn.BatchNorm2d): 481 | m.weight.data.fill_(1) 482 | m.bias.data.zero_() 483 | elif isinstance(m, nn.Linear): 484 | m.bias.data.zero_() 485 | 486 | def forward(self, x): 487 | if self.bench is not None: 488 | out = self.bench.forward(self.conv1, x, 'conv1') 489 | else: 490 | out = self.conv1(x) 491 | 492 | out = self.block1(out) 493 | out = self.block2(out) 494 | out = self.block3(out) 495 | 496 | if self.save_features: 497 | # this is a mess, but I do not have time to refactor it now 498 | self.feats += self.block1.feats 499 | self.densities += self.block1.densities 500 | del self.block1.feats[:] 501 | del self.block1.densities[:] 502 | self.feats += self.block2.feats 503 | self.densities += self.block2.densities 504 | del self.block2.feats[:] 505 | del self.block2.densities[:] 506 | self.feats += self.block3.feats 507 | self.densities += self.block3.densities 508 | del self.block3.feats[:] 509 | del self.block3.densities[:] 510 | 511 | out = self.relu(self.bn1(out)) 512 | out = F.avg_pool2d(out, 8) 513 | out = out.view(-1, self.nChannels) 514 | out = self.fc(out) 515 | return F.log_softmax(out, dim=1) 516 | 517 | 518 | class BasicBlock(nn.Module): 519 | """Wide Residual Network basic block 520 | 521 | For more info, see the paper: Wide Residual Networks by Sergey Zagoruyko, Nikos Komodakis 522 | https://arxiv.org/abs/1605.07146 523 | """ 524 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, save_features=False, bench=None): 525 | super(BasicBlock, self).__init__() 526 | self.bn1 = nn.BatchNorm2d(in_planes) 527 | self.relu1 = nn.ReLU(inplace=True) 528 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 529 | padding=1, bias=False) 530 | self.bn2 = nn.BatchNorm2d(out_planes) 531 | self.relu2 = nn.ReLU(inplace=True) 532 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 533 | padding=1, bias=False) 534 | 
self.droprate = dropRate 535 | self.equalInOut = (in_planes == out_planes) 536 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 537 | padding=0, bias=False) or None 538 | self.feats = [] 539 | self.densities = [] 540 | self.save_features = save_features 541 | self.bench = bench 542 | self.in_planes = in_planes 543 | 544 | def forward(self, x): 545 | conv_layers = [] 546 | if not self.equalInOut: 547 | x = self.relu1(self.bn1(x)) 548 | if self.save_features: 549 | self.feats.append(x.clone().detach()) 550 | self.densities.append((x.data != 0.0).sum().item()/x.numel()) 551 | else: 552 | out = self.relu1(self.bn1(x)) 553 | if self.save_features: 554 | self.feats.append(out.clone().detach()) 555 | self.densities.append((out.data != 0.0).sum().item()/out.numel()) 556 | if self.bench: 557 | out0 = self.bench.forward(self.conv1, (out if self.equalInOut else x), str(self.in_planes) + '.conv1') 558 | else: 559 | out0 = self.conv1(out if self.equalInOut else x) 560 | 561 | out = self.relu2(self.bn2(out0)) 562 | if self.save_features: 563 | self.feats.append(out.clone().detach()) 564 | self.densities.append((out.data != 0.0).sum().item()/out.numel()) 565 | if self.droprate > 0: 566 | out = F.dropout(out, p=self.droprate, training=self.training) 567 | if self.bench: 568 | out = self.bench.forward(self.conv2, out, str(self.in_planes) + '.conv2') 569 | else: 570 | out = self.conv2(out) 571 | 572 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 573 | 574 | class NetworkBlock(nn.Module): 575 | """Wide Residual Network network block which holds basic blocks. 576 | 577 | For more info, see the paper: Wide Residual Networks by Sergey Zagoruyko, Nikos Komodakis 578 | https://arxiv.org/abs/1605.07146 579 | """ 580 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, save_features=False, bench=None): 581 | super(NetworkBlock, self).__init__() 582 | self.feats = [] 583 | self.densities = [] 584 | self.save_features = save_features 585 | self.bench = bench 586 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 587 | 588 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 589 | layers = [] 590 | for i in range(int(nb_layers)): 591 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, save_features=self.save_features, bench=self.bench)) 592 | return nn.Sequential(*layers) 593 | 594 | def forward(self, x): 595 | for layer in self.layer: 596 | x = layer(x) 597 | if self.save_features: 598 | self.feats += layer.feats 599 | self.densities += layer.densities 600 | del layer.feats[:] 601 | del layer.densities[:] 602 | return x 603 | 604 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Linear 4 | import torch.nn.functional as F 5 | from torch_geometric.nn import ChebConv, GCNConv, SGConv, APPNP, GCN2Conv, JumpingKnowledge, MessagePassing # noqa 6 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 7 | from models.gat_conv import GATConv 8 | from models.fagcn_conv import FAConv 9 | 10 | from torch_geometric.utils import to_scipy_sparse_matrix 11 | import torch_sparse 12 | from torch_sparse import SparseTensor, matmul 13 | import scipy.sparse 14 | import numpy as np 15 | 16 | 17 | 18 | 
class GCNNet(torch.nn.Module): 19 | def __init__(self, dataset, args): 20 | super(GCNNet, self).__init__() 21 | 22 | self.args = args 23 | self.conv1 = GCNConv(dataset.num_features, args.dim, cached=False, add_self_loops = True, normalize = True) 24 | self.conv2 = GCNConv(args.dim, dataset.num_classes, cached=False, add_self_loops = True, normalize = True) 25 | 26 | if args.adj_sparse: 27 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32))) 28 | 29 | if args.feature_sparse: 30 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32)) 31 | 32 | def forward(self, data, data_mask=None): 33 | 34 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 35 | 36 | if self.args.adj_sparse: 37 | edge_mask = torch.abs(self.edge_weight_train) > 0 38 | row, col = data.edge_index 39 | row, col= row[edge_mask], col[edge_mask] 40 | edge_index = torch.stack([row, col], dim=0) 41 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask] 42 | 43 | if self.args.feature_sparse: 44 | x_mask = (torch.abs(self.x_weight) > 0).float() 45 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask) 46 | x = torch.mul(x,x_weight) 47 | 48 | 49 | x = F.relu(self.conv1(x, edge_index, edge_weight=edge_weight)) 50 | x = F.dropout(x, training=self.training) 51 | x = self.conv2(x, edge_index, edge_weight=edge_weight) 52 | return F.log_softmax(x, dim=1) 53 | 54 | 55 | class GATNet(torch.nn.Module): 56 | def __init__(self, dataset, args): 57 | super(GATNet, self).__init__() 58 | 59 | self.args = args 60 | self.conv1 = GATConv(dataset.num_features, 8, heads=8) 61 | self.conv2 = GATConv(8 * 8, dataset.num_classes, heads=1, concat=False) 62 | 63 | if args.adj_sparse: 64 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32))) 65 | 66 | if args.feature_sparse: 67 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32)) 68 | 69 | 70 | def forward(self, data, data_mask=None): 71 | 72 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 73 | 74 | if self.args.adj_sparse: 75 | edge_mask = torch.abs(self.edge_weight_train) > 0 76 | row, col = data.edge_index 77 | row, col= row[edge_mask], col[edge_mask] 78 | edge_index = torch.stack([row, col], dim=0) 79 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask] 80 | 81 | if self.args.feature_sparse: 82 | x_mask = (torch.abs(self.x_weight) > 0).float() 83 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask) 84 | x = torch.mul(x,x_weight) 85 | 86 | x = F.dropout(x, p=0.5, training=self.training) 87 | x = F.elu(self.conv1(x, edge_index)) 88 | x = F.dropout(x, p=0.5, training=self.training) 89 | x = self.conv2(x, edge_index, edge_weight=edge_weight) 90 | return x.log_softmax(dim=-1) 91 | 92 | 93 | class SGCNet(torch.nn.Module): 94 | def __init__(self, dataset, args): 95 | super().__init__() 96 | 97 | self.args = args 98 | self.conv1 = SGConv(dataset.num_features, dataset.num_classes, K=2, 99 | cached=True) 100 | if args.adj_sparse: 101 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32))) 102 | 103 | if args.feature_sparse: 104 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32)) 105 | 106 | def forward(self, data): 107 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 108 | 109 | if self.args.adj_sparse: 110 | edge_mask = torch.abs(self.edge_weight_train) > 0 111 | row, col = data.edge_index 112 | row, col= row[edge_mask], 
col[edge_mask] 113 | edge_index = torch.stack([row, col], dim=0) 114 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask] 115 | 116 | if self.args.feature_sparse: 117 | x_mask = (torch.abs(self.x_weight) > 0).float() 118 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask) 119 | x = torch.mul(x,x_weight) 120 | 121 | x = self.conv1(x, edge_index, edge_weight= edge_weight) 122 | return F.log_softmax(x, dim=1) 123 | 124 | 125 | class APPNPNet(torch.nn.Module): 126 | def __init__(self, dataset, args): 127 | super().__init__() 128 | self.args = args 129 | self.lin1 = Linear(dataset.num_features, args.dim) 130 | self.lin2 = Linear(args.dim, dataset.num_classes) 131 | self.prop1 = APPNP(10, 0.1) 132 | 133 | if args.adj_sparse: 134 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32))) 135 | 136 | if args.feature_sparse: 137 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32)) 138 | 139 | 140 | 141 | def reset_parameters(self): 142 | self.lin1.reset_parameters() 143 | self.lin2.reset_parameters() 144 | 145 | def forward(self, data): 146 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 147 | 148 | if self.args.adj_sparse: 149 | edge_mask = torch.abs(self.edge_weight_train) > 0 150 | row, col = data.edge_index 151 | row, col= row[edge_mask], col[edge_mask] 152 | edge_index = torch.stack([row, col], dim=0) 153 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask] 154 | 155 | if self.args.feature_sparse: 156 | x_mask = (torch.abs(self.x_weight) > 0).float() 157 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask) 158 | x = torch.mul(x,x_weight) 159 | 160 | x = F.relu(self.lin1(x)) 161 | x = F.dropout(x, training=self.training) 162 | x = self.lin2(x) 163 | x = self.prop1(x, edge_index, edge_weight=edge_weight) 164 | return F.log_softmax(x, dim=1) 165 | 166 | class GCNIINet(torch.nn.Module): 167 | def __init__(self, dataset, args,): 168 | super().__init__() 169 | # alpha =0.1 170 | # theta =0.5 171 | self.args = args 172 | self.lins = torch.nn.ModuleList() 173 | self.lins.append(Linear(dataset.num_features, args.dim)) 174 | self.lins.append(Linear(args.dim, dataset.num_classes)) 175 | 176 | self.convs = torch.nn.ModuleList() 177 | for layer in range(20): 178 | self.convs.append( 179 | GCN2Conv(args.dim, 0.1, 0.5, layer + 1, 180 | shared_weights =True, normalize=True)) 181 | 182 | if args.adj_sparse: 183 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32))) 184 | 185 | if args.feature_sparse: 186 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32)) 187 | 188 | 189 | 190 | def forward(self, data): 191 | 192 | 193 | x, edge_index = data.x, data.edge_index 194 | 195 | if self.args.feature_sparse: 196 | x_mask = (torch.abs(self.x_weight) > 0).float() 197 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask) 198 | x = torch.mul(x,x_weight) 199 | 200 | x = F.dropout(x, training=self.training) 201 | x = x_0 = self.lins[0](x).relu() 202 | 203 | if self.args.adj_sparse: 204 | edge_mask = torch.abs(self.edge_weight_train) > 0 205 | row, col = data.edge_index 206 | row, col= row[edge_mask], col[edge_mask] 207 | edge_index = torch.stack([row, col], dim=0) 208 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask] 209 | 210 | 211 | 212 | 213 | 214 | for conv in self.convs: 215 | x = F.dropout(x, training=self.training) 216 | x = conv(x, x_0, edge_index, edge_weight= edge_weight) 217 | x = x.relu() 218 | 219 | x = F.dropout(x, 
training=self.training) 220 | x = self.lins[1](x) 221 | 222 | return x.log_softmax(dim=-1) 223 | 224 | 225 | class MLP(nn.Module): 226 | """ 227 | 228 | """ 229 | def __init__(self, dataset, args): 230 | super(MLP,self).__init__() 231 | 232 | self.num_layers=2 233 | self.dropout_rate=0.5 234 | 235 | self.lins=nn.ModuleList() 236 | self.lins.append(nn.Linear(dataset.num_features, args.dim)) 237 | for i in range(self.num_layers-2): 238 | self.lins.append(nn.Linear(args.dim, args.dim)) 239 | self.lins.append(nn.Linear(args.dim,dataset.num_classes)) 240 | 241 | self.bns=nn.ModuleList() 242 | for i in range(self.num_layers-1): 243 | self.bns.append(nn.BatchNorm1d(args.dim)) 244 | self.bns.append(nn.BatchNorm1d(dataset.num_classes)) 245 | 246 | def forward(self, data): 247 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 248 | 249 | for i in range(self.num_layers-1): 250 | x=self.lins[i](x) 251 | x=self.bns[i](x) 252 | x=self.lins[self.num_layers-1](x) 253 | 254 | return F.log_softmax(x, dim=1) 255 | 256 | 257 | class LINK(nn.Module): 258 | """ logistic regression on adjacency matrix """ 259 | 260 | def __init__(self, dataset, args): 261 | super(LINK, self).__init__() 262 | 263 | self.W = nn.Linear(dataset.x.size(0), dataset.num_classes) 264 | 265 | def reset_parameters(self): 266 | self.W.reset_parameters() 267 | 268 | def forward(self, data): 269 | N = data.x.size(0) 270 | edge_index = data.edge_index 271 | if isinstance(edge_index, torch.Tensor): 272 | row, col = edge_index 273 | A = SparseTensor(row=row, col=col, sparse_sizes=(N, N)).to_torch_sparse_coo_tensor() 274 | elif isinstance(edge_index, SparseTensor): 275 | A = edge_index.to_torch_sparse_coo_tensor() 276 | logits = self.W(A) 277 | return F.log_softmax(logits, dim=1) 278 | 279 | 280 | class FAGCN(nn.Module): 281 | def __init__(self, dataset, args): 282 | super(FAGCN, self).__init__() 283 | self.eps = args.fagcn_eps 284 | self.layer_num = args.fagcn_layer_num 285 | self.dropout = args.fagcn_dropout 286 | 287 | self.layers = nn.ModuleList() 288 | for _ in range(self.layer_num): 289 | self.layers.append(FAConv(args.dim, self.eps, self.dropout)) 290 | 291 | self.t1 = nn.Linear(dataset.num_features, args.dim) 292 | self.t2 = nn.Linear(args.dim, dataset.num_classes) 293 | self.reset_parameters() 294 | 295 | def reset_parameters(self): 296 | nn.init.xavier_normal_(self.t1.weight, gain=1.414) 297 | nn.init.xavier_normal_(self.t2.weight, gain=1.414) 298 | 299 | def forward(self, data): 300 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 301 | h = F.dropout(x, p=self.dropout, training=self.training) 302 | h = torch.relu(self.t1(h)) 303 | h = F.dropout(h, p=self.dropout, training=self.training) 304 | raw = h 305 | for i in range(self.layer_num): 306 | h = self.layers[i](h,raw,edge_index) 307 | h = self.t2(h) 308 | 309 | return F.log_softmax(h, dim=1) 310 | 311 | 312 | class GPR_prop(MessagePassing): 313 | ''' 314 | GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN 315 | propagation class for GPR_GNN 316 | ''' 317 | 318 | def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs): 319 | super(GPR_prop, self).__init__(aggr='add', **kwargs) 320 | self.K = K 321 | self.Init = Init 322 | self.alpha = alpha 323 | 324 | assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS'] 325 | if Init == 'SGC': 326 | # SGC-like 327 | TEMP = 0.0*np.ones(K+1) 328 | TEMP[alpha] = 1.0 329 | elif Init == 'PPR': 330 | # PPR-like 331 | TEMP = alpha*(1-alpha)**np.arange(K+1) 332 | TEMP[-1] = (1-alpha)**K 
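# NOTE (editor): worked example with assumed values alpha=0.1, K=2: alpha*(1-alpha)**[0,1,2]
# gives [0.10, 0.09, 0.081]; overwriting the last entry with (1-alpha)**K = 0.81 yields
# TEMP = [0.10, 0.09, 0.81], so the K+1 propagation coefficients sum to 1.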
333 | elif Init == 'NPPR': 334 | # Negative PPR 335 | TEMP = (alpha)**np.arange(K+1) 336 | TEMP = TEMP/np.sum(np.abs(TEMP)) 337 | elif Init == 'Random': 338 | # Random 339 | bound = np.sqrt(3/(K+1)) 340 | TEMP = np.random.uniform(-bound, bound, K+1) 341 | TEMP = TEMP/np.sum(np.abs(TEMP)) 342 | elif Init == 'WS': 343 | # Specify Gamma 344 | TEMP = Gamma 345 | 346 | self.temp = nn.Parameter(torch.tensor(TEMP)) 347 | 348 | def reset_parameters(self): 349 | nn.init.zeros_(self.temp) 350 | for k in range(self.K+1): 351 | self.temp.data[k] = self.alpha*(1-self.alpha)**k 352 | self.temp.data[-1] = (1-self.alpha)**self.K 353 | 354 | def forward(self, x, edge_index, edge_weight=None): 355 | if isinstance(edge_index, torch.Tensor): 356 | edge_index, norm = gcn_norm( 357 | edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype) 358 | elif isinstance(edge_index, SparseTensor): 359 | edge_index = gcn_norm( 360 | edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype) 361 | norm = None 362 | 363 | hidden = x*(self.temp[0]) 364 | for k in range(self.K): 365 | x = self.propagate(edge_index, x=x, norm=norm) 366 | gamma = self.temp[k+1] 367 | hidden = hidden + gamma*x 368 | return hidden 369 | 370 | def message(self, x_j, norm): 371 | return norm.view(-1, 1) * x_j 372 | 373 | def __repr__(self): 374 | return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K, 375 | self.temp) 376 | 377 | 378 | class GPRGNN(nn.Module): 379 | """GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN""" 380 | 381 | def __init__(self, dataset, args): 382 | super(GPRGNN, self).__init__() 383 | 384 | Init='PPR' 385 | dprate=.5 386 | dropout=.5 387 | K=args.gprgnn_k 388 | alpha= args.gprgnn_alpha 389 | Gamma=None 390 | ppnp='GPR_prop' 391 | self.lin1 = nn.Linear(dataset.num_features, args.dim) 392 | self.lin2 = nn.Linear(args.dim, dataset.num_classes) 393 | 394 | if ppnp == 'PPNP': 395 | self.prop1 = APPNP(K, alpha) 396 | elif ppnp == 'GPR_prop': 397 | self.prop1 = GPR_prop(K, alpha, Init, Gamma) 398 | 399 | self.Init = Init 400 | self.dprate = dprate 401 | self.dropout = dropout 402 | 403 | def reset_parameters(self): 404 | self.lin1.reset_parameters() 405 | self.lin2.reset_parameters() 406 | self.prop1.reset_parameters() 407 | 408 | def forward(self, data): 409 | x, edge_index = data.x, data.edge_index 410 | 411 | x = F.dropout(x, p=self.dropout, training=self.training) 412 | x = F.relu(self.lin1(x)) 413 | x = F.dropout(x, p=self.dropout, training=self.training) 414 | x = self.lin2(x) 415 | 416 | if self.dprate == 0.0: 417 | x = self.prop1(x, edge_index) 418 | return F.log_softmax(x, dim=1) 419 | else: 420 | x = F.dropout(x, p=self.dprate, training=self.training) 421 | x = self.prop1(x, edge_index) 422 | return F.log_softmax(x, dim=1) 423 | 424 | class MixHopLayer(nn.Module): 425 | """ Our MixHop layer """ 426 | def __init__(self, in_channels, out_channels, hops=2): 427 | super(MixHopLayer, self).__init__() 428 | self.hops = hops 429 | self.lins = nn.ModuleList() 430 | for hop in range(self.hops+1): 431 | lin = nn.Linear(in_channels, out_channels) 432 | self.lins.append(lin) 433 | 434 | def reset_parameters(self): 435 | for lin in self.lins: 436 | lin.reset_parameters() 437 | 438 | def forward(self, x, adj_t): 439 | xs = [self.lins[0](x) ] 440 | for j in range(1,self.hops+1): 441 | # less runtime efficient but usually more memory efficient to mult weight matrix first 442 | x_j = self.lins[j](x) 443 | for hop in range(j): 444 | x_j = matmul(adj_t, x_j) 445 | xs += [x_j] 446 | return torch.cat(xs, 
dim=1) 447 | 448 | class MixHop(nn.Module): 449 | """ our implementation of MixHop 450 | some assumptions: the powers of the adjacency are [0, 1, ..., hops], 451 | with every power in between 452 | each concatenated layer has the same dimension --- hidden_channels 453 | """ 454 | def __init__(self, dataset, args): 455 | super(MixHop, self).__init__() 456 | 457 | num_layers= args.mixhop_layer_num 458 | dropout=args.mixhop_dropout 459 | hops=args.mixhop_hop 460 | 461 | self.convs = nn.ModuleList() 462 | self.convs.append(MixHopLayer(dataset.num_features, args.dim, hops=hops)) 463 | 464 | self.bns = nn.ModuleList() 465 | self.bns.append(nn.BatchNorm1d(args.dim*(hops+1))) 466 | for _ in range(num_layers - 2): 467 | self.convs.append( 468 | MixHopLayer(args.dim*(hops+1), args.dim, hops=hops)) 469 | self.bns.append(nn.BatchNorm1d(args.dim*(hops+1))) 470 | 471 | self.convs.append( 472 | MixHopLayer(args.dim*(hops+1), dataset.num_classes, hops=hops)) 473 | 474 | # note: uses linear projection instead of paper's attention output 475 | self.final_project = nn.Linear(dataset.num_classes*(hops+1), dataset.num_classes) 476 | 477 | self.dropout = dropout 478 | self.activation = F.relu 479 | 480 | def reset_parameters(self): 481 | for conv in self.convs: 482 | conv.reset_parameters() 483 | for bn in self.bns: 484 | bn.reset_parameters() 485 | self.final_project.reset_parameters() 486 | 487 | 488 | def forward(self, data): 489 | x, edge_index = data.x, data.edge_index 490 | n = data.x.size(0) 491 | edge_weight = None 492 | 493 | if isinstance(edge_index, torch.Tensor): 494 | edge_index, edge_weight = gcn_norm( 495 | edge_index, edge_weight, n, False, 496 | dtype=x.dtype) 497 | row, col = edge_index 498 | adj_t = SparseTensor(row=col, col=row, value=edge_weight, sparse_sizes=(n, n)) 499 | elif isinstance(edge_index, SparseTensor): 500 | edge_index = gcn_norm( 501 | edge_index, edge_weight, n, False, 502 | dtype=x.dtype) 503 | edge_weight=None 504 | adj_t = edge_index 505 | 506 | for i, conv in enumerate(self.convs[:-1]): 507 | x = conv(x, adj_t) 508 | x = self.bns[i](x) 509 | x = self.activation(x) 510 | x = F.dropout(x, p=self.dropout, training=self.training) 511 | x = self.convs[-1](x, adj_t) 512 | 513 | x = self.final_project(x) 514 | return x 515 | 516 | 517 | class HGCN(nn.Module): 518 | def __init__(self,dataset, args): 519 | super(HGCN, self).__init__() 520 | self.args = args 521 | self.lin1 = nn.Linear(dataset.num_features, args.dim) 522 | self.lin = nn.Linear(args.dim*5,dataset.num_classes) 523 | 524 | def forward(self,data): 525 | 526 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 527 | 528 | edge_index, edge_weight = gcn_norm(edge_index, edge_weight) 529 | 530 | adj_coo = to_scipy_sparse_matrix(edge_index) 531 | adj_row = adj_coo.row 532 | adj_col = adj_coo.col 533 | adj_value = edge_weight.detach().cpu().numpy() 534 | # print(adj_value) 535 | # raise exception("pause") 536 | adj_size = adj_coo.shape 537 | edge_index = torch_sparse.SparseTensor(sparse_sizes=[adj_size[0], adj_size[1]], row=torch.tensor(adj_row, dtype=torch.long), 538 | col=torch.tensor(adj_col, dtype=torch.long), 539 | value=torch.tensor(adj_value, dtype=torch.float32)).to(self.args.device) 540 | temp = self.lin1(x) 541 | temp = F.relu(temp) 542 | temp1 = torch_sparse.matmul(edge_index,temp) 543 | temp1 =torch.cat((temp,temp1),dim=1) 544 | temp2 = torch_sparse.matmul(edge_index,temp1) 545 | temp = torch.cat((temp,temp1,temp2),dim=1) 546 | temp = F.dropout(temp,p=self.args.h2gcn_dropout) 547 | ans = 
self.lin(temp) 548 | 549 | return F.log_softmax(ans, dim=1) 550 | 551 | 552 | class FAGCNNet(nn.Module): 553 | def __init__(self, dataset, args): 554 | super(FAGCNNet, self).__init__() 555 | self.eps = 0.3 556 | self.layer_num = 2 557 | self.dropout = 0.6 558 | self.args = args 559 | 560 | self.layers = nn.ModuleList() 561 | for _ in range(self.layer_num): 562 | self.layers.append(FAConv(args.dim, self.eps, self.dropout)) 563 | 564 | self.t1 = nn.Linear(dataset.num_features, args.dim) 565 | self.t2 = nn.Linear(args.dim, dataset.num_classes) 566 | self.reset_parameters() 567 | 568 | if args.adj_sparse: 569 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32))) 570 | 571 | def reset_parameters(self): 572 | nn.init.xavier_normal_(self.t1.weight, gain=1.414) 573 | nn.init.xavier_normal_(self.t2.weight, gain=1.414) 574 | 575 | def forward(self, data): 576 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 577 | 578 | if self.args.adj_sparse: 579 | edge_mask = torch.abs(self.edge_weight_train) > 0 580 | row, col = data.edge_index 581 | row, col= row[edge_mask], col[edge_mask] 582 | edge_index = torch.stack([row, col], dim=0) 583 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask] 584 | 585 | 586 | h = F.dropout(x, p=self.dropout, training=self.training) 587 | h = torch.relu(self.t1(h)) 588 | h = F.dropout(h, p=self.dropout, training=self.training) 589 | raw = h 590 | for i in range(self.layer_num): 591 | h = self.layers[i](h,raw,edge_index, edge_weight) 592 | h = self.t2(h) 593 | 594 | return F.log_softmax(h, dim=1) 595 | 596 | 597 | 598 | class HGCNNet(nn.Module): 599 | def __init__(self,dataset, args): 600 | super(HGCNNet, self).__init__() 601 | self.args = args 602 | self.lin1 = nn.Linear(dataset.num_features, args.dim) 603 | self.lin = nn.Linear(args.dim*5,dataset.num_classes) 604 | 605 | if args.adj_sparse: 606 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32))) 607 | 608 | 609 | def forward(self,data): 610 | 611 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 612 | 613 | if self.args.adj_sparse: 614 | edge_mask = torch.abs(self.edge_weight_train) > 0 615 | row, col = data.edge_index 616 | row, col= row[edge_mask], col[edge_mask] 617 | edge_index = torch.stack([row, col], dim=0) 618 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask] 619 | 620 | edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes = x.size(0)) 621 | 622 | adj_coo = to_scipy_sparse_matrix(edge_index) 623 | adj_row = adj_coo.row 624 | adj_col = adj_coo.col 625 | adj_value = edge_weight.detach().cpu().numpy() 626 | # print(adj_value) 627 | # raise exception("pause") 628 | adj_size = adj_coo.shape 629 | edge_index = torch_sparse.SparseTensor(sparse_sizes=[adj_size[0], adj_size[1]], row=torch.tensor(adj_row, dtype=torch.long), 630 | col=torch.tensor(adj_col, dtype=torch.long), 631 | value=torch.tensor(adj_value, dtype=torch.float32)).to(self.args.device) 632 | temp = self.lin1(x) 633 | temp = F.relu(temp) 634 | temp1 = torch_sparse.matmul(edge_index,temp) 635 | temp1 =torch.cat((temp,temp1),dim=1) 636 | temp2 = torch_sparse.matmul(edge_index,temp1) 637 | temp = torch.cat((temp,temp1,temp2),dim=1) 638 | temp = F.dropout(temp,p=0.5) 639 | ans = self.lin(temp) 640 | 641 | return F.log_softmax(ans, dim=1) 642 | 643 | 644 | 645 | 646 | class GCNmasker(torch.nn.Module): 647 | 648 | def __init__(self, dataset, args): 649 | super(GCNmasker, self).__init__() 650 | 651 | 
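# NOTE (editor, not in the original source): GCNmasker is a two-layer GCN encoder followed by
# an edge scorer -- every edge (u, v) receives a score in (0, 1), computed either as
# sigmoid(<h_u, h_v>) when score_function == 'inner_product' or by an MLP applied to the
# concatenation [h_u || h_v] when it is 'concat_mlp' (see inner_product_score and
# concat_mlp_score below).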
self.conv1 = GCNConv(dataset.num_features, args.masker_dim, cached=False) 652 | self.conv2 = GCNConv(args.masker_dim, args.masker_dim, cached=False) 653 | self.mlp = nn.Linear(args.masker_dim * 2, 1) 654 | self.score_function = args.score_function 655 | self.sigmoid = nn.Sigmoid() 656 | 657 | def forward(self, data): 658 | 659 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 660 | x = F.relu(self.conv1(x, edge_index, edge_weight)) 661 | x = F.dropout(x, training=self.training) 662 | x = self.conv2(x, edge_index, edge_weight) 663 | 664 | if self.score_function == 'inner_product': 665 | link_score = self.inner_product_score(x, edge_index) 666 | elif self.score_function == 'concat_mlp': 667 | link_score = self.concat_mlp_score(x, edge_index) 668 | else: 669 | assert False 670 | 671 | return link_score 672 | 673 | def inner_product_score(self, x, edge_index): 674 | 675 | row, col = edge_index 676 | link_score = torch.sum(x[row] * x[col], dim=1) 677 | #print("max:{:.2f} min:{:.2f} mean:{:.2f}".format(link_score.max(), link_score.min(), link_score.mean())) 678 | link_score = self.sigmoid(link_score).view(-1) 679 | return link_score 680 | 681 | def concat_mlp_score(self, x, edge_index): 682 | 683 | row, col = edge_index 684 | link_score = torch.cat((x[row], x[col]), dim=1) 685 | link_score = self.mlp(link_score) 686 | # weight = self.mlp.weight 687 | # print("max:{:.2f} min:{:.2f} mean:{:.2f}".format(link_score.max(), link_score.min(), link_score.mean())) 688 | link_score = self.sigmoid(link_score).view(-1) 689 | return link_score -------------------------------------------------------------------------------- /main_stgnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import os.path as osp 5 | import shutil 6 | import time 7 | import argparse 8 | import logging 9 | import hashlib 10 | import copy 11 | import csv 12 | import numpy as np 13 | import random 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | import torch.backends.cudnn as cudnn 19 | import torch_geometric.transforms as T 20 | 21 | from torch_geometric.datasets import Planetoid, Reddit, WebKB, Actor, WikipediaNetwork, Coauthor, Amazon, Flickr, WikiCS, Yelp 22 | from torch_geometric.loader import NeighborLoader, RandomNodeSampler 23 | from torch_geometric.transforms import RandomNodeSplit 24 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 25 | 26 | 27 | from ogb.nodeproppred import Evaluator 28 | from ogb.nodeproppred import PygNodePropPredDataset 29 | from torch_scatter import scatter 30 | from torch_geometric.utils import to_undirected, add_self_loops 31 | 32 | import sparselearning 33 | from models import initializers 34 | from sparselearning.core import Masking, CosineDecay, LinearDecay 35 | from models.model import GCNNet, SGCNet, APPNPNet, GCNIINet, GATNet, MLP, FAGCN, HGCN, LINK, GPRGNN, MixHop, FAGCNNet, HGCNNet 36 | 37 | import warnings 38 | warnings.filterwarnings("ignore", category=UserWarning) 39 | 40 | cudnn.benchmark = True 41 | cudnn.deterministic = True 42 | 43 | 44 | if not os.path.exists('./models'): os.mkdir('./models') 45 | if not os.path.exists('./logs'): os.mkdir('./logs') 46 | if not os.path.exists('./results'): os.mkdir('./results') 47 | logger = None 48 | 49 | torch.backends.cudnn.enabled = True 50 | torch.backends.cudnn.benchmark = True 51 | 52 | models = {} 53 | models['gcn'] = (GCNNet) 54 | models['gat'] = (GATNet) 55 | 
models['sgc'] = (SGCNet) 56 | models['appnp'] = (APPNPNet) 57 | models['gcnii'] = (GCNIINet) 58 | models['mlp'] = (MLP) 59 | models['fagcn'] = (FAGCN) 60 | models['h2gcn'] = (HGCN) 61 | models['link'] = (LINK) 62 | models['gprgnn'] = (GPRGNN) 63 | models['mixhop'] = (MixHop) 64 | models['fagcnnet'] = (FAGCNNet) 65 | models['h2gcnnet'] = (HGCNNet) 66 | 67 | 68 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 69 | print("SAVING") 70 | torch.save(state, filename) 71 | 72 | 73 | def setup_logger(args): 74 | global logger 75 | if logger == None: 76 | logger = logging.getLogger() 77 | else: # wish there was a logger.close() 78 | for handler in logger.handlers[:]: # make a copy of the list 79 | logger.removeHandler(handler) 80 | 81 | args_copy = copy.deepcopy(args) 82 | # copy to get a clean hash 83 | # use the same log file hash if iterations or verbose are different 84 | # these flags do not change the results 85 | args_copy.iters = 1 86 | args_copy.verbose = False 87 | args_copy.log_interval = 1 88 | args_copy.seed = 0 89 | 90 | if args.weight_sparse and not args.adj_sparse and not args.feature_sparse : 91 | sparse_way = 'w' 92 | log_path = './logs/{0}/{1}_{2}_{3}_{4}.log'.format( 93 | sparse_way,args.model,args.data, args.final_density, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 94 | elif args.adj_sparse and not args.weight_sparse and not args.feature_sparse : 95 | sparse_way = 'a' 96 | log_path = './logs/{0}/{1}_{2}_{3}_{4}.log'.format( 97 | sparse_way,args.model,args.data, args.final_density_adj, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 98 | elif args.feature_sparse and not args.weight_sparse and not args.adj_sparse : 99 | sparse_way = 'f' 100 | log_path = './logs/{0}/{1}_{2}_{3}_{4}.log'.format( 101 | sparse_way,args.model,args.data, args.final_density_feature, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 102 | elif args.weight_sparse and args.adj_sparse and not args.feature_sparse : 103 | sparse_way = 'wa' 104 | log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}.log'.format( 105 | sparse_way,args.model,args.data, args.final_density, args.final_density_adj, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 106 | elif args.weight_sparse and args.feature_sparse and not args.adj_sparse : 107 | sparse_way = 'wf' 108 | log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}.log'.format( 109 | sparse_way,args.model,args.data, args.final_density, args.final_density_feature, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 110 | elif args.adj_sparse and args.feature_sparse and not args.weight_sparse : 111 | sparse_way = 'af' 112 | log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}.log'.format( 113 | sparse_way,args.model,args.data, args.final_density_adj, args.final_density_feature, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 114 | elif args.weight_sparse and args.adj_sparse and args.feature_sparse: 115 | sparse_way = 'waf' 116 | log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}_{6}.log'.format( 117 | sparse_way,args.model,args.data, args.final_density, args.final_density_adj, args.final_density_feature, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 118 | else: 119 | sparse_way = 'base' 120 | log_path = './logs/{0}/{1}_{2}_{3}.log'.format( 121 | sparse_way,args.model,args.data, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 122 | 123 | if not os.path.exists('./logs/{}'.format(sparse_way)): os.mkdir('./logs/{}'.format(sparse_way)) 124 | 125 | #log_path = 
'./logs/{0}/{1}_{2}_{3}_{4}_{5}.log'.format(sparse_way, args.model, args.data, args.final_density, args.final_density_adj, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 126 | 127 | logger.setLevel(logging.INFO) 128 | formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%H:%M:%S') 129 | 130 | fh = logging.FileHandler(log_path) 131 | fh.setFormatter(formatter) 132 | logger.addHandler(fh) 133 | 134 | def print_and_log(msg): 135 | global logger 136 | print(msg) 137 | logger.info(msg) 138 | 139 | 140 | def results_to_file(args, train_acc, val_acc, test_acc, train_time, test_time): 141 | 142 | 143 | if args.weight_sparse and not args.adj_sparse and not args.feature_sparse : 144 | sparse_way = 'w' 145 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format( 146 | sparse_way, args.model, args.data, args.init_density, args.final_density) 147 | elif args.adj_sparse and not args.weight_sparse and not args.feature_sparse : 148 | sparse_way = 'a' 149 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format( 150 | sparse_way, args.model, args.data, args.final_density_adj) 151 | elif args.feature_sparse and not args.weight_sparse and not args.adj_sparse : 152 | sparse_way = 'f' 153 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format( 154 | sparse_way, args.model, args.data, args.final_density_feature) 155 | elif args.weight_sparse and args.adj_sparse and not args.feature_sparse : 156 | sparse_way = 'wa' 157 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format( 158 | sparse_way, args.model, args.data, args.final_density, args.final_density_adj) 159 | elif args.weight_sparse and args.feature_sparse and not args.adj_sparse : 160 | sparse_way = 'wf' 161 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format( 162 | sparse_way, args.model, args.data, args.final_density, args.final_density_feature) 163 | elif args.adj_sparse and args.feature_sparse and not args.weight_sparse : 164 | sparse_way = 'af' 165 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format( 166 | sparse_way, args.model, args.data, args.final_density_adj, args.final_density_feature) 167 | elif args.weight_sparse and args.adj_sparse and args.feature_sparse: 168 | sparse_way = 'waf' 169 | filename = "./results/{}/{}_{}_{}_{}_{}_result.csv".format( 170 | sparse_way, args.model, args.data, args.final_density, args.final_density_adj, args.final_density_feature) 171 | else: 172 | sparse_way = 'base' 173 | filename = "./results/{}/{}_{}_result.csv".format( 174 | sparse_way, args.model, args.data) 175 | 176 | if not os.path.exists('./results/{}'.format(sparse_way)): os.mkdir('./results/{}'.format(sparse_way)) 177 | 178 | headerList = ["Method","Growth","Prune Rate", "Update Frequency", "Final Prune Epoch", "::", "train_acc", "val_acc", "test_acc", "train_time", "test_time"] 179 | 180 | #filename = "./results/{}/{}_{}_{}_{}_result.csv".format(sparse_way, args.model, args.data, args.final_density, args.final_density_adj) 181 | with open(filename, "a+") as f: 182 | 183 | # reader = csv.reader(f) 184 | # row1 = next(reader) 185 | f.seek(0) 186 | header = f.read(6) 187 | if header != "Method": 188 | dw = csv.DictWriter(f, delimiter=',', 189 | fieldnames=headerList) 190 | dw.writeheader() 191 | 192 | line = "{}, {}, {}, {}, {}, :::, {:.4f}, {:.4f}, {:.4f},{:.4f}, {:.4f}\n".format( 193 | args.method, args.growth_schedule, args.prune_rate, args.update_frequency, args.final_prune_epoch, train_acc, val_acc, test_acc, train_time, test_time 194 | ) 195 | f.write(line) 196 | 197 | 198 | 199 | def train(args, 
model, device, data, optimizer, epoch, mask=None): 200 | model.train() 201 | train_loss = 0 202 | correct = 0 203 | n = 0 204 | criterion = torch.nn.BCEWithLogitsLoss() 205 | 206 | data = data.to(device) 207 | 208 | target = data.y[data.train_mask].to(device) 209 | 210 | 211 | if args.fp16: data = data.half() 212 | 213 | optimizer.zero_grad() 214 | 215 | output = model(data)[data.train_mask] 216 | if args.data in ['ogbn-proteins']: 217 | loss = criterion(output, target) 218 | acc = 0.0 219 | else: 220 | loss = F.nll_loss(output, target) 221 | pred = output.max(1)[1] 222 | acc = pred.eq(data.y[data.train_mask]).sum().item() / data.train_mask.sum().item() 223 | 224 | if args.fp16: 225 | optimizer.backward(loss) 226 | else: 227 | loss.backward() 228 | 229 | if mask is not None: 230 | #print("Mask!!!!") 231 | mask.step() 232 | else: 233 | optimizer.step() 234 | 235 | # print_and_log('\n{}: Average loss: {:.4f}, Accuracy: {} \n'.format( 236 | # 'Training summary', loss, acc,)) 237 | 238 | def evaluate(args, model, device, data, is_test_set=False): 239 | model.eval() 240 | test_loss = 0 241 | correct = 0 242 | n = 0 243 | with torch.no_grad(): 244 | #target = data.y[data.train_mask].to(device) 245 | data = data.to(device) 246 | 247 | if args.fp16: data = data.half() 248 | 249 | logits, accs = model(data), [] 250 | 251 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): 252 | pred = logits[mask].max(1)[1] 253 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 254 | accs.append(acc) 255 | train_acc, val_acc, tmp_test_acc = accs 256 | 257 | # print_and_log('\n{}: Train Accuracy: {:.4f}, Val Accuracy: {} Test Accuracy: {}\n'.format( 258 | # 'Test evaluation' if is_test_set else 'Evaluation', 259 | # train_acc, val_acc, tmp_test_acc)) 260 | return train_acc, val_acc, tmp_test_acc 261 | 262 | def evaluate_ogb(args, model, device, data, evaluator, is_test_set=False): 263 | 264 | model.eval() 265 | test_loss = 0 266 | correct = 0 267 | n = 0 268 | with torch.no_grad(): 269 | #target = data.y[data.train_mask].to(device) 270 | data = data.to(device) 271 | 272 | if args.fp16: data = data.half() 273 | 274 | out = model(data) 275 | 276 | if args.data in ['ogbn-arxiv','ogbn-products' ]: 277 | y_pred = out.argmax(dim=-1, keepdim=True) 278 | 279 | train_acc = evaluator.eval({ 280 | 'y_true': data.y.unsqueeze(-1)[data.train_mask], 281 | 'y_pred': y_pred[data.train_mask], 282 | })['acc'] 283 | val_acc = evaluator.eval({ 284 | 'y_true': data.y.unsqueeze(-1)[data.valid_mask], 285 | 'y_pred': y_pred[data.valid_mask], 286 | })['acc'] 287 | tmp_test_acc = evaluator.eval({ 288 | 'y_true': data.y.unsqueeze(-1)[data.test_mask], 289 | 'y_pred': y_pred[data.test_mask], 290 | })['acc'] 291 | print_and_log('\n{}: Train Accuracy: {:.4f}, Val Accuracy: {} Test Accuracy: {}\n'.format( 292 | 'Test evaluation' if is_test_set else 'Evaluation', 293 | train_acc, val_acc, tmp_test_acc)) 294 | 295 | elif args.data in ['ogbn-proteins']: 296 | 297 | train_acc = evaluator.eval({ 298 | 'y_true': data.y[data.train_mask], 299 | 'y_pred': out[data.train_mask], 300 | })['rocauc'] # Acutually roc-auc, only name it train_acc 301 | val_acc = evaluator.eval({ 302 | 'y_true': data.y[data.valid_mask], 303 | 'y_pred': out[data.valid_mask], 304 | })['rocauc'] 305 | tmp_test_acc = evaluator.eval({ 306 | 'y_true': data.y[data.test_mask], 307 | 'y_pred': out[data.test_mask], 308 | })['rocauc'] 309 | print_and_log('\n{}: Train ROC-AUC: {:.4f}, Val ROC-AUC: {} Test ROC-AUC: {}\n'.format( 310 | 'Test evaluation' if is_test_set else 
'Evaluation', 311 | train_acc, val_acc, tmp_test_acc)) 312 | 313 | return train_acc, val_acc, tmp_test_acc 314 | 315 | 316 | def main(): 317 | # Training settings 318 | parser = argparse.ArgumentParser(description='PyTorch GraNet for sparse training') 319 | parser.add_argument('--batch-size', type=int, default=100, metavar='N', 320 | help='input batch size for training (default: 100)') 321 | parser.add_argument('--batch-size-jac', type=int, default=200, metavar='N', 322 | help='batch size for jac (default: 1000)') 323 | parser.add_argument('--test-batch-size', type=int, default=100, metavar='N', 324 | help='input batch size for testing (default: 100)') 325 | parser.add_argument('--multiplier', type=int, default=1, metavar='N', 326 | help='extend training time by multiplier times') 327 | parser.add_argument('--epochs', type=int, default=250, metavar='N', 328 | help='number of epochs to train (default: 100)') 329 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', 330 | help='learning rate (default: 0.1)') 331 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 332 | help='SGD momentum (default: 0.9)') 333 | parser.add_argument('--no-cuda', action='store_true', default=False, 334 | help='disables CUDA training') 335 | parser.add_argument('--seed', type=int, default=17, metavar='S', help='random seed (default: 17)') 336 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 337 | help='how many batches to wait before logging training status') 338 | parser.add_argument('--optimizer', type=str, default='adam', help='The optimizer to use. Default: sgd. Options: sgd, adam.') 339 | randomhash = ''.join(str(time.time()).split('.')) 340 | parser.add_argument('--save', type=str, default=randomhash + '.pt', 341 | help='path to save the final model') 342 | parser.add_argument('--data', type=str, default='mnist') 343 | parser.add_argument('--decay_frequency', type=int, default=25000) 344 | parser.add_argument('--l1', type=float, default=0.0) 345 | parser.add_argument('--fp16', action='store_true', help='Run in fp16 mode.') 346 | parser.add_argument('--valid_split', type=float, default=0.1) 347 | parser.add_argument('--resume', type=str) 348 | parser.add_argument('--start-epoch', type=int, default=1) 349 | parser.add_argument('--model', type=str, default='') 350 | parser.add_argument('--l2', type=float, default=1.0e-4) 351 | parser.add_argument('--iters', type=int, default=1, help='How many times the model should be run after each other. Default=1') 352 | parser.add_argument('--save-features', action='store_true', help='Resumes a saved model and saves its feature data to disk for plotting.') 353 | parser.add_argument('--bench', action='store_true', help='Enables the benchmarking of layers and estimates sparse speedups') 354 | parser.add_argument('--max-threads', type=int, default=10, help='How many threads to use for data loading.') 355 | parser.add_argument('--decay-schedule', type=str, default='cosine', help='The decay schedule for the pruning rate. Default: cosine. Choose from: cosine, linear.') 356 | parser.add_argument('--growth_schedule', type=str, default='gradient', help='The growth schedule. Default: gradient. 
Choose from: gradient, momentum, random.') 357 | parser.add_argument('--lr_scheduler', action='store_true', default=False, 358 | help='disables CUDA training') 359 | parser.add_argument('--adj_sparse', action='store_true', help='If Sparse Adj.') 360 | parser.add_argument('--feature_sparse', action='store_true', help='If Sparse Weight.') 361 | parser.add_argument('--weight_sparse', action='store_true', help='If Sparse Feature.',) 362 | parser.add_argument('--dim', type=int, default=512) 363 | parser.add_argument('--cuda', type=int, default=0) 364 | 365 | 366 | # FAGCN 367 | parser.add_argument('--fagcn_layer_num', type=int, default=1) 368 | parser.add_argument('--fagcn_dropout', type=float, default=0) 369 | parser.add_argument('--fagcn_eps', type=float, default=0.1) 370 | 371 | # MixHop 372 | parser.add_argument('--mixhop_layer_num', type=int, default=1) 373 | parser.add_argument('--mixhop_dropout', type=float, default=0) 374 | parser.add_argument('--mixhop_hop', type=int, default=2) 375 | 376 | # GPRGNN 377 | parser.add_argument('--gprgnn_alpha', type=float, default=0.1) 378 | parser.add_argument('--gprgnn_k', type=int, default=10) 379 | 380 | # H2GCN 381 | 382 | parser.add_argument('--h2gcn_dropout', type=float, default=0.1) 383 | 384 | 385 | sparselearning.core.add_sparse_args(parser) 386 | 387 | args = parser.parse_args() 388 | setup_logger(args) 389 | print_and_log(args) 390 | 391 | if args.fp16: 392 | try: 393 | from apex.fp16_utils import FP16_Optimizer 394 | except: 395 | print('WARNING: apex not installed, ignoring --fp16 option') 396 | args.fp16 = False 397 | 398 | use_cuda = not args.no_cuda and torch.cuda.is_available() 399 | args.device = torch.device('cuda:{}'.format(args.cuda) if use_cuda else "cpu") 400 | 401 | 402 | 403 | print_and_log('\n\n') 404 | print_and_log('='*80) 405 | np.random.seed(args.seed) 406 | torch.manual_seed(args.seed) 407 | random.seed(args.seed) 408 | if torch.cuda.is_available(): 409 | torch.cuda.manual_seed(args.seed) 410 | torch.cuda.manual_seed_all(args.seed) 411 | for i in range(args.iters): 412 | 413 | ####################################################################################### 414 | ############################# Datasets ################################################ 415 | ####################################################################################### 416 | print_and_log("\nIteration start: {0}/{1}\n".format(i+1, args.iters)) 417 | 418 | 419 | if args.data in ['cora','citeseer','pubmed']: 420 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../data', args.data) 421 | dataset = Planetoid(path, args.data, transform=T.NormalizeFeatures()) 422 | data = dataset[0] 423 | 424 | data.num_classes = dataset.num_classes 425 | data.num_edges_orig = data.num_edges 426 | #print_and_log(data) 427 | #raise Exception('pause!! ') 428 | 429 | elif args.data in ["Cornell", "Texas", "Wisconsin"] : 430 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data) 431 | dataset = WebKB(path,args.data, transform=T.NormalizeFeatures()) 432 | data = dataset[0] 433 | data.num_classes = dataset.num_classes 434 | print_and_log(data) 435 | data.train_mask = data.train_mask[:, args.seed % 10] 436 | data.val_mask = data.val_mask[:, args.seed % 10] 437 | data.test_mask = data.test_mask[:, args.seed % 10] 438 | #print_and_log(data) 439 | #raise Exception('pause!! 
') 440 | 441 | elif args.data in ["Actor"] : 442 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data) 443 | dataset = Actor(path, transform=T.NormalizeFeatures()) 444 | data = dataset[0] 445 | data.num_classes = dataset.num_classes 446 | 447 | #print_and_log(data) 448 | data.train_mask = data.train_mask[:, args.seed % 10] 449 | data.val_mask = data.val_mask[:, args.seed % 10] 450 | data.test_mask = data.test_mask[:, args.seed % 10] 451 | #print_and_log(data) 452 | #raise Exception('pause!! ') 453 | 454 | elif args.data in ["chameleon", "crocodile", "squirrel"] : 455 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data) 456 | dataset = WikipediaNetwork(path, args.data, transform=T.NormalizeFeatures()) 457 | data = dataset[0] 458 | data.num_classes = dataset.num_classes 459 | 460 | data.train_mask = data.train_mask[:, args.seed % 10] 461 | data.val_mask = data.val_mask[:, args.seed % 10] 462 | data.test_mask = data.test_mask[:, args.seed % 10] 463 | #print_and_log(data) 464 | 465 | # Multi Spilt 466 | raise Exception('pause!! ') 467 | 468 | 469 | elif args.data in ["CS", "Physics"] : 470 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data) 471 | dataset = Coauthor(path, args.data, transform=T.NormalizeFeatures()) 472 | data = dataset[0] 473 | data.num_classes = dataset.num_classes 474 | transform = RandomNodeSplit(split= "test_rest", 475 | num_train_per_class = 20, 476 | num_val = 30* data.num_classes,) 477 | transform(data) 478 | #print_and_log(data) 479 | #raise Exception('pause!! ') 480 | 481 | elif args.data in ["Computers", "Photo"] : 482 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data) 483 | dataset = Amazon(path, args.data, transform=T.NormalizeFeatures()) 484 | data = dataset[0] 485 | data.num_classes = dataset.num_classes 486 | transform = RandomNodeSplit(split= "test_rest", 487 | num_train_per_class = 20, 488 | num_val = 30* data.num_classes,) 489 | transform(data) 490 | #print_and_log(data) 491 | 492 | #raise Exception('pause!! ') 493 | 494 | 495 | elif args.data in ["Flickr"] : 496 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data) 497 | dataset = Flickr(path) 498 | data = dataset[0] 499 | data.num_classes = dataset.num_classes 500 | print_and_log(data) 501 | # Cannot load file containing pickled data when allow_pickle=False 502 | raise Exception('pause!! ') 503 | 504 | elif args.data in ["Yelp"] : 505 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data) 506 | dataset = Yelp(path) 507 | data = dataset[0] 508 | data.num_classes = dataset.num_classes 509 | print_and_log(data) 510 | # Cannot load file containing pickled data when allow_pickle=False 511 | # Fix: allow_pickle=True) 512 | raise Exception('pause!! 
') 513 | 514 | 515 | elif args.data in ["WikiCS"] : 516 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data) 517 | dataset = WikiCS(path, transform=T.NormalizeFeatures()) 518 | data = dataset[0] 519 | data.num_classes = dataset.num_classes 520 | 521 | data.stopping_mask = None 522 | data.train_mask = data.train_mask[:, args.seed % 20] 523 | data.val_mask = data.val_mask[:, args.seed % 20] 524 | print_and_log(data) 525 | 526 | elif args.data == 'reddit': 527 | print("Loading Reddit .....") 528 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', 'Reddit') 529 | dataset = Reddit(path, transform=T.NormalizeFeatures()) 530 | 531 | data = dataset[0] 532 | data.num_classes = dataset.num_classes 533 | #kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True} 534 | print_and_log(data) 535 | 536 | print("Load Reddit Done!") 537 | 538 | elif args.data in['ogbn-arxiv', 'ogbn-products', 'ogbn-proteins', 'ogbn-papers100M']: 539 | 540 | print("Loading Dataset: {}".format(args.data)) 541 | 542 | dataset = PygNodePropPredDataset(name=args.data, root='../data') 543 | data = dataset[0] 544 | split_idx = dataset.get_idx_split() 545 | evaluator = Evaluator(args.data) 546 | 547 | edge_index = to_undirected(data.edge_index, data.num_nodes) 548 | #edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0] 549 | 550 | data.edge_index = edge_index 551 | for split in ['train', 'valid', 'test']: 552 | mask = torch.zeros(data.num_nodes, dtype=torch.bool) 553 | mask[split_idx[split]] = True 554 | data[f'{split}_mask'] = mask 555 | 556 | if args.data in ['ogbn-proteins']: 557 | data.y = data.y.to(torch.float) 558 | data.num_classes = dataset.num_tasks 559 | data.node_species = None 560 | row, col = data.edge_index 561 | data.x = scatter(data.edge_attr, col, 0, dim_size=data.num_nodes, reduce='add') 562 | else: 563 | data.num_classes = dataset.num_classes 564 | 565 | if args.data in ['ogbn-arxiv']: 566 | data.y = data.y.squeeze(1) 567 | 568 | print("Load Done !") 569 | print_and_log(data) 570 | 571 | ####################################################################################### 572 | ############################# Models ################################################ 573 | ####################################################################################### 574 | 575 | if args.model not in models: 576 | print('You need to select an existing model via the --model argument. 
Available models include: ') 577 | for key in models: 578 | print('\t{0}'.format(key)) 579 | raise Exception('You need to select a model') 580 | else: 581 | 582 | if args.model == 'gcn': 583 | model = GCNNet(data, args).to(args.device) 584 | 585 | elif args.model == 'sgc': 586 | model = SGCNet(data, args).to(args.device) 587 | 588 | elif args.model == 'appnp': 589 | model = APPNPNet(data, args).to(args.device) 590 | 591 | elif args.model == 'gat': 592 | model = GATNet(data, args).to(args.device) 593 | 594 | elif args.model == 'gcnii': 595 | model = GCNIINet(data, args).to(args.device) 596 | 597 | elif args.model == 'mlp': 598 | model = MLP(data, args).to(args.device) 599 | 600 | elif args.model == 'fagcn': 601 | model = FAGCN(data, args).to(args.device) 602 | 603 | elif args.model == 'h2gcn': 604 | model = HGCN(data, args).to(args.device) 605 | 606 | elif args.model == 'link': 607 | model = LINK(data, args).to(args.device) 608 | 609 | elif args.model == 'gprgnn': 610 | model = GPRGNN(data, args).to(args.device) 611 | 612 | elif args.model == 'mixhop': 613 | model = MixHop(data, args).to(args.device) 614 | 615 | elif args.model == 'fagcnnet': 616 | model = FAGCNNet(data, args).to(args.device) 617 | 618 | elif args.model == 'h2gcnnet': 619 | model = HGCNNet(data, args).to(args.device) 620 | 621 | else: 622 | cls, cls_args = models[args.model] 623 | if args.data == 'cifar100': 624 | cls_args[2] = 100 625 | model = cls(*(cls_args + [args.save_features, args.bench])).to(args.device) 626 | print_and_log(model) 627 | print_and_log('='*60) 628 | print_and_log(args.model) 629 | print_and_log('='*60) 630 | 631 | print_and_log('='*60) 632 | print_and_log('Prune mode: {0}'.format(args.prune)) 633 | print_and_log('Growth mode: {0}'.format(args.growth)) 634 | print_and_log('Redistribution mode: {0}'.format(args.redistribution)) 635 | print_and_log('='*60) 636 | 637 | 638 | optimizer = None 639 | if args.optimizer == 'sgd': 640 | optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.l2, nesterov=True) 641 | elif args.optimizer == 'adam': 642 | optimizer = optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.l2) 643 | else: 644 | print('Unknown optimizer: {0}'.format(args.optimizer)) 645 | raise Exception('Unknown optimizer.') 646 | 647 | if args.lr_scheduler: 648 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.epochs / 2) * args.multiplier, int(args.epochs * 3 / 4) * args.multiplier], last_epoch=-1) 649 | else: 650 | lr_scheduler = None 651 | 652 | 653 | if args.resume: 654 | if os.path.isfile(args.resume): 655 | print_and_log("=> loading checkpoint '{}'".format(args.resume)) 656 | checkpoint = torch.load(args.resume) 657 | model.load_state_dict(checkpoint) 658 | original_acc = evaluate(args, model, args.device, test_loader) 659 | 660 | 661 | if args.fp16: 662 | print('FP16') 663 | optimizer = FP16_Optimizer(optimizer, 664 | static_loss_scale = None, 665 | dynamic_loss_scale = True, 666 | dynamic_loss_args = {'init_scale': 2 ** 16}) 667 | model = model.half() 668 | 669 | 670 | mask = None 671 | if args.sparse: 672 | decay = CosineDecay(args.prune_rate, (args.epochs*args.multiplier)) 673 | mask = Masking(optimizer, prune_rate=args.prune_rate, death_mode=args.prune, prune_rate_decay=decay, growth_mode=args.growth, 674 | redistribution_mode=args.redistribution, args=args, train_loader=None, device =args.device) 675 | mask.add_module(model, sparse_init=args.sparse_init) 676 | 677 | best_acc = 0.0 678 | t_start = time.time() 
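# NOTE (editor, not in the original source): training protocol for the loop below. Each epoch
# runs one full-batch call to train(); when args.sparse is set, train() calls mask.step() in
# place of a plain optimizer.step(), which is where sparselearning.core.Masking updates the
# sparsity masks. Validation runs every epoch, the best-by-validation checkpoint is written to
# save_subfolder, and best_acc is reset right after final_prune_epoch so that model selection
# only considers checkpoints at (or beyond) the target sparsity; that checkpoint is reloaded
# at the end for the final test evaluation.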
679 | for epoch in range(1, args.epochs*args.multiplier + 1): 680 | 681 | #print_and_log("Epoch:{}".format(epoch)) 682 | #print("="*50) 683 | 684 | # save models 685 | save_path = './save/' + str(args.model) + '/' + str(args.data) + '/' + str(args.method) + '/' + str(args.seed) 686 | save_subfolder = os.path.join(save_path, 'Multiplier=' + str(args.multiplier) + '_sparsity' + str(1-args.final_density)) 687 | if not os.path.exists(save_subfolder): os.makedirs(save_subfolder) 688 | 689 | t0 = time.time() 690 | #print(mask) 691 | 692 | train(args, model, args.device, data, optimizer, epoch, mask) 693 | 694 | 695 | if lr_scheduler is not None: 696 | lr_scheduler.step() 697 | if args.valid_split > 0.0: 698 | if args.data in['ogbn-arxiv', 'ogbn-products', 'ogbn-proteins', 'ogbn-ogbn-papers100M']: 699 | _, val_acc, _ = evaluate_ogb(args, model, args.device, data, evaluator) 700 | else: 701 | _, val_acc, _ = evaluate(args, model, args.device, data) 702 | 703 | # target sparsity is reached 704 | if args.sparse: 705 | if epoch == args.multiplier * args.final_prune_epoch+1: 706 | best_acc = 0.0 707 | 708 | if val_acc > best_acc: 709 | print('Saving model') 710 | best_acc = val_acc 711 | save_checkpoint({ 712 | 'epoch': epoch + 1, 713 | 'state_dict': model.state_dict(), 714 | 'optimizer': optimizer.state_dict(), 715 | }, filename=os.path.join(save_subfolder, 'model_final.pth')) 716 | 717 | #print_and_log(' Time taken for epoch: {:.2f} seconds.\n'.format(time.time() - t0)) 718 | 719 | train_time_total = time.time() - t_start 720 | print('Testing model') 721 | model.load_state_dict(torch.load(os.path.join(save_subfolder, 'model_final.pth'))['state_dict']) 722 | 723 | t_test_0 = time.time() 724 | if args.data in['ogbn-arxiv', 'ogbn-products', 'ogbn-proteins', 'ogbn-ogbn-papers100M']: 725 | train_acc, val_acc, test_acc = evaluate_ogb(args, model, args.device, data, evaluator, is_test_set=True) 726 | else: 727 | train_acc, val_acc, test_acc = evaluate(args, model, args.device, data, is_test_set=True) 728 | print('Test accuracy is:', test_acc) 729 | results_to_file(args, train_acc, val_acc, test_acc, train_time_total, time.time()- t_test_0) 730 | 731 | 732 | 733 | if __name__ == '__main__': 734 | print("Start Runing!") 735 | main() 736 | -------------------------------------------------------------------------------- /sparselearning/core.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.optim as optim 4 | import numpy as np 5 | import math 6 | 7 | # use_cuda = torch.cuda.is_available() 8 | # device = torch.device("cuda" if use_cuda else "cpu") 9 | 10 | def add_sparse_args(parser): 11 | # hyperparameters for Zero-Cost Neuroregeneration 12 | parser.add_argument('--growth', type=str, default='gradient', help='Growth mode. Choose from: momentum, random, and momentum_neuron.') 13 | parser.add_argument('--prune', type=str, default='magnitude', help='Death mode / pruning mode. Choose from: magnitude, SET, threshold, CS_death.') 14 | parser.add_argument('--redistribution', type=str, default='none', help='Redistribution mode. Choose from: momentum, magnitude, nonzeros, or none.') 15 | parser.add_argument('--prune-rate', type=float, default=0.50, help='The pruning rate / death rate for Zero-Cost Neuroregeneration.') 16 | parser.add_argument('--pruning-rate', type=float, default=0.50, help='The pruning rate / death rate.') 17 | parser.add_argument('--sparse', action='store_true', help='Enable sparse mode. 
Default: True.') 18 | parser.add_argument('--fix', action='store_true', help='Fix topology during training. Default: True.') 19 | parser.add_argument('--update-frequency', type=int, default=100, metavar='N', help='how many iterations to train between mask update') 20 | parser.add_argument('--sparse-init', type=str, default='ERK, uniform distributions for sparse training, global pruning and uniform pruning for pruning', help='sparse initialization') 21 | # hyperparameters for gradually pruning 22 | parser.add_argument('--method', type=str, default='GraNet', help='method name: DST, GraNet, GraNet_uniform, GMP, GMO_uniform') 23 | 24 | parser.add_argument('--init-density', type=float, default=0.50, help='The pruning rate / death rate.') 25 | parser.add_argument('--final-density', type=float, default=0.05, help='The density of the overall sparse network.') 26 | parser.add_argument('--init-density_adj', type=float, default=1.0, help='The pruning rate / death rate.') 27 | parser.add_argument('--final-density_adj', type=float, default=0.5, help='The density of the overall sparse network.') 28 | parser.add_argument('--init-density_feature', type=float, default=1.0, help='The pruning rate / death rate.') 29 | parser.add_argument('--final-density_feature', type=float, default=0.5, help='The density of the overall sparse network.') 30 | parser.add_argument('--init-prune-epoch', type=int, default=0, help='The pruning rate / death rate.') 31 | parser.add_argument('--final-prune-epoch', type=int, default=110, help='The density of the overall sparse network.') 32 | parser.add_argument('--rm-first', action='store_true', help='Keep the first layer dense.') 33 | 34 | 35 | 36 | 37 | class CosineDecay(object): 38 | def __init__(self, prune_rate, T_max, eta_min=0.005, last_epoch=-1): 39 | self.sgd = optim.SGD(torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(1))]), lr=prune_rate) 40 | self.cosine_stepper = torch.optim.lr_scheduler.CosineAnnealingLR(self.sgd, T_max, eta_min, last_epoch) 41 | 42 | def step(self): 43 | self.cosine_stepper.step() 44 | 45 | def get_dr(self): 46 | return self.sgd.param_groups[0]['lr'] 47 | 48 | class LinearDecay(object): 49 | def __init__(self, prune_rate, factor=0.99, frequency=600): 50 | self.factor = factor 51 | self.steps = 0 52 | self.frequency = frequency 53 | 54 | def step(self): 55 | self.steps += 1 56 | 57 | def get_dr(self, prune_rate): 58 | if self.steps > 0 and self.steps % self.frequency == 0: 59 | return prune_rate*self.factor 60 | else: 61 | return prune_rate 62 | 63 | 64 | 65 | class Masking(object): 66 | def __init__(self, optimizer, 67 | prune_rate=0.3, 68 | growth_death_ratio=1.0, 69 | prune_rate_decay=None, 70 | death_mode='magnitude', 71 | growth_mode='momentum', 72 | redistribution_mode='momentum', 73 | threshold=0.001, 74 | args=None, 75 | train_loader=None, 76 | device=None): 77 | growth_modes = ['random', 'momentum', 'momentum_neuron', 'gradient'] 78 | if growth_mode not in growth_modes: 79 | print('Growth mode: {0} not supported!'.format(growth_mode)) 80 | print('Supported modes are:', str(growth_modes)) 81 | 82 | self.args = args 83 | self.loader = [1] 84 | self.device = args.device 85 | self.growth_mode = growth_mode 86 | self.death_mode = death_mode 87 | self.growth_death_ratio = growth_death_ratio 88 | self.redistribution_mode = redistribution_mode 89 | self.prune_rate_decay = prune_rate_decay 90 | self.sparse_init = args.sparse_init 91 | 92 | 93 | self.masks = {} 94 | self.final_masks = {} 95 | self.grads = {} 96 | self.nonzero_masks = {} 97 | 
self.scores = {} 98 | self.pruning_rate = {} 99 | self.modules = [] 100 | self.names = [] 101 | self.optimizer = optimizer 102 | 103 | self.adjusted_growth = 0 104 | self.adjustments = [] 105 | self.baseline_nonzero = None 106 | self.name2baseline_nonzero = {} 107 | 108 | # stats 109 | self.name2variance = {} 110 | self.name2zeros = {} 111 | self.name2nonzeros = {} 112 | self.total_variance = 0 113 | self.total_removed = 0 114 | self.total_zero = 0 115 | self.total_nonzero = 0 116 | self.total_params = 0 117 | self.fc_params = 0 118 | self.prune_rate = prune_rate 119 | self.name2prune_rate = {} 120 | self.steps = 0 121 | 122 | if self.args.fix: 123 | self.prune_every_k_steps = None 124 | else: 125 | self.prune_every_k_steps = self.args.update_frequency 126 | 127 | 128 | def init(self, mode='ER', density=0.05, density_adj=0.05, density_feature=0.05, erk_power_scale=1.0, grad_dict=None): 129 | if self.args.method == 'GMP': 130 | print('initialized with GMP, ones') 131 | self.baseline_nonzero = 0 132 | for module in self.modules: 133 | for name, weight in module.named_parameters(): 134 | if name not in self.masks: continue 135 | self.masks[name] = torch.ones_like(weight, dtype=torch.float32, requires_grad=False).to(self.device) 136 | self.baseline_nonzero += (self.masks[name] != 0).sum().int().item() 137 | self.apply_mask() 138 | elif self.sparse_init == 'prune_uniform': 139 | # used for pruning stabability test 140 | print('initialized by prune_uniform') 141 | self.baseline_nonzero = 0 142 | for module in self.modules: 143 | for name, weight in module.named_parameters(): 144 | if name not in self.masks: continue 145 | self.masks[name] = (weight!=0).to(self.device) 146 | num_zeros = (weight==0).sum().item() 147 | num_remove = (self.args.pruning_rate) * self.masks[name].sum().item() 148 | k = math.ceil(num_zeros + num_remove) 149 | if num_remove == 0.0: return weight.data != 0.0 150 | x, idx = torch.sort(torch.abs(weight.data.view(-1))) 151 | self.masks[name].data.view(-1)[idx[:k]] = 0.0 152 | self.baseline_nonzero += (self.masks[name] != 0).sum().int().item() 153 | self.apply_mask() 154 | 155 | elif self.sparse_init == 'prune_global': 156 | # used for pruning stabability test 157 | print('initialized by prune_global') 158 | self.baseline_nonzero = 0 159 | total_num_nonzoros = 0 160 | for module in self.modules: 161 | for name, weight in module.named_parameters(): 162 | if name not in self.masks: continue 163 | self.masks[name] = (weight!=0).to(self.device) 164 | self.name2nonzeros[name] = (weight!=0).sum().item() 165 | total_num_nonzoros += self.name2nonzeros[name] 166 | 167 | weight_abs = [] 168 | for module in self.modules: 169 | for name, weight in module.named_parameters(): 170 | if name not in self.masks: continue 171 | weight_abs.append(torch.abs(weight)) 172 | 173 | # Gather all scores in a single vector and normalise 174 | all_scores = torch.cat([torch.flatten(x) for x in weight_abs]) 175 | num_params_to_keep = int(total_num_nonzoros * (1 - self.args.pruning_rate)) 176 | 177 | threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True) 178 | acceptable_score = threshold[-1] 179 | 180 | for module in self.modules: 181 | for name, weight in module.named_parameters(): 182 | if name not in self.masks: continue 183 | self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float() 184 | self.apply_mask() 185 | 186 | elif self.sparse_init == 'prune_and_grow_uniform': 187 | # used for pruning stabability test 188 | print('initialized by pruning and growing uniformly') 189 | 
190 | self.baseline_nonzero = 0 191 | for module in self.modules: 192 | for name, weight in module.named_parameters(): 193 | if name not in self.masks: continue 194 | # prune 195 | self.masks[name] = (weight!=0).to(self.device) 196 | num_zeros = (weight==0).sum().item() 197 | num_remove = (self.args.pruning_rate) * self.masks[name].sum().item() 198 | k = math.ceil(num_zeros + num_remove) 199 | if num_remove == 0.0: return weight.data != 0.0 200 | x, idx = torch.sort(torch.abs(weight.data.view(-1))) 201 | self.masks[name].data.view(-1)[idx[:k]] = 0.0 202 | total_regrowth = (self.masks[name]==0).sum().item() - num_zeros 203 | 204 | # set the pruned weights to zero 205 | weight.data = weight.data * self.masks[name] 206 | if 'momentum_buffer' in self.optimizer.state[weight]: 207 | self.optimizer.state[weight]['momentum_buffer'] = self.optimizer.state[weight]['momentum_buffer'] * self.masks[name] 208 | 209 | # grow 210 | grad = grad_dict[name] 211 | grad = grad * (self.masks[name] == 0).float() 212 | 213 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True) 214 | self.masks[name].data.view(-1)[idx[:total_regrowth]] = 1.0 215 | self.baseline_nonzero += (self.masks[name] != 0).sum().int().item() 216 | self.apply_mask() 217 | 218 | elif self.sparse_init == 'prune_and_grow_global': 219 | # used for pruning stabability test 220 | print('initialized by pruning and growing globally') 221 | self.baseline_nonzero = 0 222 | total_num_nonzoros = 0 223 | for module in self.modules: 224 | for name, weight in module.named_parameters(): 225 | if name not in self.masks: continue 226 | self.masks[name] = (weight!=0).to(self.device) 227 | self.name2nonzeros[name] = (weight!=0).sum().item() 228 | total_num_nonzoros += self.name2nonzeros[name] 229 | 230 | weight_abs = [] 231 | for module in self.modules: 232 | for name, weight in module.named_parameters(): 233 | if name not in self.masks: continue 234 | weight_abs.append(torch.abs(weight)) 235 | 236 | # Gather all scores in a single vector and normalise 237 | all_scores = torch.cat([torch.flatten(x) for x in weight_abs]) 238 | num_params_to_keep = int(total_num_nonzoros * (1 - self.args.pruning_rate)) 239 | 240 | threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True) 241 | acceptable_score = threshold[-1] 242 | 243 | for module in self.modules: 244 | for name, weight in module.named_parameters(): 245 | if name not in self.masks: continue 246 | self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float() 247 | 248 | # set the pruned weights to zero 249 | weight.data = weight.data * self.masks[name] 250 | if 'momentum_buffer' in self.optimizer.state[weight]: 251 | self.optimizer.state[weight]['momentum_buffer'] = self.optimizer.state[weight]['momentum_buffer'] * self.masks[name] 252 | 253 | ### grow 254 | for module in self.modules: 255 | for name, weight in module.named_parameters(): 256 | if name not in self.masks: continue 257 | total_regrowth = self.name2nonzeros[name] - (self.masks[name]!=0).sum().item() 258 | grad = grad_dict[name] 259 | grad = grad * (self.masks[name] == 0).float() 260 | 261 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True) 262 | self.masks[name].data.view(-1)[idx[:total_regrowth]] = 1.0 263 | self.baseline_nonzero += (self.masks[name] != 0).sum().int().item() 264 | self.apply_mask() 265 | 266 | elif self.sparse_init == 'uniform': 267 | self.baseline_nonzero = 0 268 | for module in self.modules: 269 | for name, weight in module.named_parameters(): 270 | if name not in self.masks: continue 
271 |                     if name == "edge_weight_train":
272 |                         self.masks[name][:] = (torch.rand(weight.shape) < density_adj).float().data.to(self.device) #
273 |                         #self.baseline_nonzero += weight.numel() * density_adj
274 | 
275 |                     elif name == "x_weight":
276 |                         self.masks[name][:] = (torch.rand(weight.shape) < density_feature).float().data.to(self.device) #
277 |                         #self.baseline_nonzero += weight.numel() * density_feature
278 | 
279 |                     else:
280 |                         self.masks[name][:] = (torch.rand(weight.shape) < density).float().data.to(self.device) #
281 |                         #self.baseline_nonzero += weight.numel() * density
282 |             self.apply_mask()
283 | 
284 |         elif self.sparse_init == 'ERK':
285 |             print('initialized by ERK')
286 |             for name, weight in self.masks.items():
287 |                 if name == "edge_weight_train": continue
288 |                 if name == "x_weight": continue
289 |                 self.total_params += weight.numel()
290 |                 if 'classifier' in name:
291 |                     self.fc_params = weight.numel()
292 |             is_epsilon_valid = False
293 |             dense_layers = set()
294 |             while not is_epsilon_valid:
295 | 
296 |                 divisor = 0
297 |                 rhs = 0
298 |                 raw_probabilities = {}
299 |                 for name, mask in self.masks.items():
300 |                     if name == "edge_weight_train": continue
301 |                     if name == "x_weight": continue
302 |                     n_param = np.prod(mask.shape)
303 |                     # if name == "edge_weight_train":
304 |                     #     n_zeros = n_param * (1 - density_adj)
305 |                     #     n_ones = n_param * density_adj
306 |                     # elif name == "x_weight":
307 |                     #     n_zeros = n_param * (1 - density_feature)
308 |                     #     n_ones = n_param * density_feature
309 |                     # else:
310 |                     n_zeros = n_param * (1 - density)
311 |                     n_ones = n_param * density
312 | 
313 | 
314 |                     if name in dense_layers:
315 |                         # See `- default_sparsity * (N_3 + N_4)` part of the equation above.
316 |                         rhs -= n_zeros
317 | 
318 |                     else:
319 |                         # Corresponds to `(1 - default_sparsity) * (N_1 + N_2)` part of the
320 |                         # equation above.
321 |                         rhs += n_ones
322 |                         # Erdos-Renyi probability: epsilon * (n_in + n_out) / (n_in * n_out).
323 |                         raw_probabilities[name] = (
324 |                             np.sum(mask.shape) / np.prod(mask.shape)
325 |                         ) ** erk_power_scale
326 |                         # Note that raw_probabilities[mask] * n_param gives the individual
327 |                         # elements of the divisor.
328 |                         divisor += raw_probabilities[name] * n_param
329 |                 # By multiplying the individual probabilities with epsilon, we should get the
330 |                 # number of parameters per layer correctly.
331 |                 epsilon = rhs / divisor
332 |                 # If epsilon * raw_probabilities[mask.name] > 1, we set the sparsity of that
333 |                 # mask to 0., so it becomes part of the dense_layers set.
334 |                 max_prob = np.max(list(raw_probabilities.values()))
335 |                 max_prob_one = max_prob * epsilon
336 |                 if max_prob_one > 1:
337 |                     is_epsilon_valid = False
338 |                     for mask_name, mask_raw_prob in raw_probabilities.items():
339 |                         if mask_raw_prob == max_prob:
340 |                             print(f"Sparsity of var:{mask_name} had to be set to 0.")
341 |                             dense_layers.add(mask_name)
342 |                 else:
343 |                     is_epsilon_valid = True
344 | 
345 |             density_dict = {}
346 |             total_nonzero = 0.0
347 |             # With the valid epsilon, we can set the sparsities of the remaining layers. (A standalone sketch of this ERK allocation appears at the end of this document.)
348 | for name, mask in self.masks.items(): 349 | if name == "edge_weight_train": continue 350 | if name == "x_weight": continue 351 | n_param = np.prod(mask.shape) 352 | if name in dense_layers: 353 | density_dict[name] = 1.0 354 | else: 355 | probability_one = epsilon * raw_probabilities[name] 356 | density_dict[name] = probability_one 357 | print( 358 | f"layer: {name}, shape: {mask.shape}, density: {density_dict[name]}" 359 | ) 360 | self.masks[name][:] = (torch.rand(mask.shape) < density_dict[name]).float().data.to(self.device) 361 | 362 | total_nonzero += density_dict[name] * mask.numel() 363 | print(f"Overall sparsity {total_nonzero / self.total_params}") 364 | 365 | self.apply_mask() 366 | 367 | 368 | total_size = 0 369 | for name, weight in self.masks.items(): 370 | total_size += weight.numel() 371 | 372 | sparse_size = 0 373 | for name, weight in self.masks.items(): 374 | sparse_size += (weight != 0).sum().int().item() 375 | print('Total parameters under sparsity level of {0}: {1}'.format(density, sparse_size / total_size)) 376 | 377 | def step(self): 378 | self.optimizer.step() 379 | self.apply_mask() 380 | self.prune_rate_decay.step() 381 | self.prune_rate = self.prune_rate_decay.get_dr() 382 | self.steps += 1 383 | 384 | if self.prune_every_k_steps is not None: 385 | if self.args.method == 'GraNet': 386 | if self.steps >= (self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) and self.steps % self.prune_every_k_steps == 0: 387 | self.pruning(self.steps) 388 | self.truncate_weights(self.steps) 389 | self.print_nonzero_counts() 390 | elif self.args.method == 'GraNet_uniform': 391 | if self.steps >= (self.args.init_prune_epoch * len(self.loader)* self.args.multiplier) and self.steps % self.prune_every_k_steps == 0: 392 | self.pruning_uniform(self.steps) 393 | self.truncate_weights(self.steps) 394 | self.print_nonzero_counts() 395 | # _, _ = self.fired_masks_update() 396 | elif self.args.method == 'DST': 397 | if self.steps % self.prune_every_k_steps == 0: 398 | self.truncate_weights() 399 | self.print_nonzero_counts() 400 | elif self.args.method == 'GMP': 401 | if self.steps >= (self.args.init_prune_epoch * len(self.loader) * self.args.multiplier) and self.steps % self.prune_every_k_steps == 0: 402 | self.pruning(self.steps) 403 | elif self.args.method == 'GMP_uniform': 404 | if self.steps >= (self.args.init_prune_epoch * len(self.loader) * self.args.multiplier) and self.steps % self.prune_every_k_steps == 0: 405 | self.pruning_uniform(self.steps) 406 | 407 | 408 | def pruning(self, step): 409 | # prune_rate = 1 - self.args.final_density - self.args.init_density 410 | curr_prune_iter = int(step / self.prune_every_k_steps) 411 | final_iter = int((self.args.final_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps) 412 | ini_iter = int((self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps) 413 | total_prune_iter = final_iter - ini_iter 414 | 415 | 416 | 417 | if curr_prune_iter >= ini_iter and curr_prune_iter <= final_iter - 1: 418 | print('******************************************************') 419 | print(f'Pruning Progress is {curr_prune_iter - ini_iter} / {total_prune_iter}') 420 | print('******************************************************') 421 | print("Pruning Start!!") 422 | prune_decay = (1 - ((curr_prune_iter - ini_iter) / total_prune_iter)) ** 3 423 | curr_prune_rate = (1 - self.args.init_density) + (self.args.init_density - self.args.final_density) * ( 424 | 1 - prune_decay) 425 | 426 
| curr_prune_rate_adj = (1 - self.args.init_density_adj) + (self.args.init_density_adj - self.args.final_density_adj) * ( 427 | 1 - prune_decay) 428 | 429 | curr_prune_rate_feature = (1 - self.args.init_density_feature) + (self.args.init_density_feature - self.args.final_density_feature) * ( 430 | 1 - prune_decay) 431 | 432 | weight_abs = [] 433 | adj_abs =[] 434 | feature_abs =[] 435 | for module in self.modules: 436 | for name, weight in module.named_parameters(): 437 | if name not in self.masks: continue 438 | 439 | if name == "edge_weight_train": 440 | adj_abs.append(torch.abs(weight)) 441 | elif name == "x_weight": 442 | feature_abs.append(torch.abs(weight)) 443 | else: 444 | weight_abs.append(torch.abs(weight)) 445 | 446 | # Gather all scores in a single vector and normalise 447 | if self.args.weight_sparse: 448 | all_scores = torch.cat([torch.flatten(x) for x in weight_abs]) 449 | num_params_to_keep = int(len(all_scores) * (1 - curr_prune_rate)) 450 | 451 | threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True) 452 | acceptable_score_weight = threshold[-1] 453 | 454 | # Gather adj scores 455 | if self.args.adj_sparse: 456 | all_scores = torch.cat([torch.flatten(x) for x in adj_abs]) 457 | num_params_to_keep = int(len(all_scores) * (1 - curr_prune_rate_adj)) 458 | 459 | threshold_adj, _ = torch.topk(all_scores, num_params_to_keep, sorted=True) 460 | acceptable_score_adj = threshold_adj[-1] 461 | 462 | # Gather adj scores 463 | 464 | if self.args.feature_sparse: 465 | all_scores = torch.cat([torch.flatten(x) for x in feature_abs]) 466 | num_params_to_keep = int(len(all_scores) * (1 - curr_prune_rate_feature)) 467 | 468 | threshold_feature, _ = torch.topk(all_scores, num_params_to_keep, sorted=True) 469 | acceptable_score_feature = threshold_feature[-1] 470 | 471 | 472 | for module in self.modules: 473 | for name, weight in module.named_parameters(): 474 | if name not in self.masks: continue 475 | 476 | if self.args.adj_sparse: 477 | if name == "edge_weight_train": 478 | self.masks[name] = ((torch.abs(weight)) > acceptable_score_adj).float() 479 | print("Add Sparse Mask --- Graph Adj: {} !".format(name)) 480 | 481 | if self.args.feature_sparse: 482 | if name == "x_weight": 483 | self.masks[name] = ((torch.abs(weight)) > acceptable_score_feature).float() 484 | print("Add Sparse Mask --- Graph Feature: {} !".format(name)) 485 | 486 | if self.args.weight_sparse: 487 | if len(weight.size()) == 4 or len(weight.size()) == 2: 488 | self.masks[name] = ((torch.abs(weight)) > acceptable_score_weight).float() 489 | #must be > to prevent acceptable_score is zero, leading to dense tensors 490 | print("Add Sparse Mask --- Model Weight: {} !".format(name)) 491 | print("="*40) 492 | self.apply_mask() 493 | 494 | weight_total_size = 1 495 | adj_total_size = 1 496 | feature_total_size = 1 497 | 498 | for name, weight in self.masks.items(): 499 | if name == "edge_weight_train": 500 | adj_total_size += weight.numel() 501 | elif name == "x_weight": 502 | feature_total_size += weight.numel() 503 | else: 504 | weight_total_size += weight.numel() 505 | 506 | print('Total Model parameters:{}, Graph Edge Numbers:{}, Feature Channels:{}'.format(weight_total_size,adj_total_size,feature_total_size)) 507 | 508 | weight_sparse_size = 0 509 | adj_sparse_size = 0 510 | feature_sparse_size = 0 511 | 512 | for name, weight in self.masks.items(): 513 | 514 | if name == "edge_weight_train": 515 | adj_sparse_size += (weight != 0).sum().int().item() 516 | elif name == "x_weight": 517 | feature_sparse_size += 
(weight != 0).sum().int().item() 518 | else: 519 | weight_sparse_size += (weight != 0).sum().int().item() 520 | 521 | print('Model Parameters Sparsity after pruning: {} \nGraph Edge Numbers after pruning: {} \nFeature Channels Sparsity after pruning:{}'.format( 522 | (weight_total_size-weight_sparse_size) / weight_total_size, 523 | (adj_total_size-adj_sparse_size) / adj_total_size, 524 | (feature_total_size-feature_sparse_size) / feature_total_size)) 525 | print("="*40) 526 | 527 | def pruning_uniform(self, step): 528 | # prune_rate = 1 - self.args.final_density - self.args.init_density 529 | curr_prune_iter = int(step / self.prune_every_k_steps) 530 | final_iter = (self.args.final_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps 531 | ini_iter = (self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps 532 | total_prune_iter = final_iter - ini_iter 533 | 534 | 535 | if curr_prune_iter >= ini_iter and curr_prune_iter <= final_iter: 536 | print('******************************************************') 537 | print(f'Pruning Progress is {curr_prune_iter - ini_iter} / {total_prune_iter}') 538 | print('******************************************************') 539 | 540 | prune_decay = (1 - ((curr_prune_iter - ini_iter) / total_prune_iter)) ** 3 541 | curr_prune_rate = (1 - self.args.init_density) + (self.args.init_density - self.args.final_density) * ( 542 | 1 - prune_decay) 543 | 544 | curr_prune_rate_adj = (1 - self.args.init_density_adj) + (self.args.init_density_adj - self.args.final_density_adj) * ( 545 | 1 - prune_decay) 546 | 547 | curr_prune_rate_feature = (1 - self.args.init_density_feature) + (self.args.init_density_feature - self.args.final_density_feature) * ( 548 | 1 - prune_decay) 549 | 550 | # keep the density of the last layer as 0.2 if spasity is larger then 0.8 551 | # if curr_prune_rate >= 0.8: 552 | # curr_prune_rate = 1 - (self.total_params * (1-curr_prune_rate) - 0.2 * self.fc_params)/(self.total_params-self.fc_params) 553 | 554 | # for module in self.modules: 555 | # for name, weight in module.named_parameters(): 556 | # if name not in self.masks: continue 557 | # score = torch.flatten(torch.abs(weight)) 558 | # if 'classifier' in name: 559 | # num_params_to_keep = int(len(score) * 0.2) 560 | # threshold, _ = torch.topk(score, num_params_to_keep, sorted=True) 561 | # acceptable_score = threshold[-1] 562 | # self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float() 563 | # else: 564 | # num_params_to_keep = int(len(score) * (1 - curr_prune_rate)) 565 | # threshold, _ = torch.topk(score, num_params_to_keep, sorted=True) 566 | # acceptable_score = threshold[-1] 567 | # self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float() 568 | 569 | for module in self.modules: 570 | for name, weight in module.named_parameters(): 571 | if name not in self.masks: continue 572 | 573 | score = torch.flatten(torch.abs(weight)) 574 | 575 | if name == "edge_weight_train": 576 | num_params_to_keep = int(len(score) * (1 - curr_prune_rate_adj)) 577 | print("Add Sparse Mask --- Graph Adj: {} !".format(name)) 578 | elif name == "x_weight": 579 | num_params_to_keep = int(len(score) * (1 - curr_prune_rate_feature)) 580 | print("Add Sparse Mask --- Graph Feature: {} !".format(name)) 581 | else: 582 | num_params_to_keep = int(len(score) * (1 - curr_prune_rate)) 583 | print("Add Sparse Mask --- Model Weight: {} !".format(name)) 584 | 585 | threshold, _ = torch.topk(score, num_params_to_keep, sorted=True) 586 | 
acceptable_score = threshold[-1] 587 | self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float() 588 | 589 | 590 | self.apply_mask() 591 | 592 | weight_total_size = 1 593 | adj_total_size = 1 594 | feature_total_size = 1 595 | 596 | for name, weight in self.masks.items(): 597 | if name == "edge_weight_train": 598 | adj_total_size += weight.numel() 599 | elif name == "x_weight": 600 | feature_total_size += weight.numel() 601 | else: 602 | weight_total_size += weight.numel() 603 | 604 | print('Total Model parameters:{}, Graph Edge Numbers:{}, Feature Channels:{}'.format(weight_total_size,adj_total_size,feature_total_size)) 605 | 606 | weight_sparse_size = 0 607 | adj_sparse_size = 0 608 | feature_sparse_size = 0 609 | 610 | for name, weight in self.masks.items(): 611 | 612 | if name == "edge_weight_train": 613 | adj_sparse_size += (weight != 0).sum().int().item() 614 | elif name == "x_weight": 615 | feature_sparse_size += (weight != 0).sum().int().item() 616 | else: 617 | weight_sparse_size += (weight != 0).sum().int().item() 618 | 619 | print('Model Parameters Sparsity after pruning: {} \nGraph Edge Numbers after pruning: {} \nFeature Channels Sparsity after pruning:{}'.format( 620 | (weight_total_size-weight_sparse_size) / weight_total_size, 621 | (adj_total_size-adj_sparse_size) / adj_total_size, 622 | (feature_total_size-feature_sparse_size) / feature_total_size)) 623 | print("="*40) 624 | 625 | # total_size = 0 626 | # for name, weight in self.masks.items(): 627 | # total_size += weight.numel() 628 | # print('Total Model parameters:', total_size) 629 | 630 | # sparse_size = 0 631 | # for name, weight in self.masks.items(): 632 | # sparse_size += (weight != 0).sum().int().item() 633 | 634 | # print('Sparsity after pruning: {0}'.format( 635 | # (total_size-sparse_size) / total_size)) 636 | 637 | 638 | def add_module(self, module, sparse_init='ERK', grad_dic=None): 639 | self.module = module 640 | self.sparse_init = self.sparse_init 641 | self.modules.append(module) 642 | for name, tensor in module.named_parameters(): 643 | 644 | if self.args.adj_sparse: 645 | if name == "edge_weight_train": 646 | self.names.append(name) 647 | self.masks[name] = torch.ones_like(tensor, dtype=torch.float32, requires_grad=False).to(self.device) 648 | print("Add Sparse Module --- Graph Adj:{} Sparse Module!".format(name)) 649 | 650 | if self.args.feature_sparse: 651 | if name == "x_weight": 652 | self.names.append(name) 653 | self.masks[name] = torch.ones_like(tensor, dtype=torch.float32, requires_grad=False).to(self.device) 654 | print("Add Sparse Module --- Graph Feature: {} !".format(name)) 655 | 656 | if self.args.weight_sparse: 657 | if len(tensor.size()) == 4 or len(tensor.size()) == 2: 658 | self.names.append(name) 659 | self.masks[name] = torch.ones_like(tensor, dtype=torch.float32, requires_grad=False).to(self.device) 660 | print("Add Sparse Module --- Model Weight: {} !".format(name)) 661 | 662 | 663 | print("Add Module Done!") 664 | print("="*40) 665 | 666 | if self.args.rm_first: 667 | for name, tensor in module.named_parameters(): 668 | if 'conv.weight' in name or 'feature.0.weight' in name: 669 | self.masks.pop(name) 670 | print(f"pop out {name}") 671 | 672 | self.init( mode=self.args.sparse_init, 673 | density=self.args.init_density, 674 | density_adj =self.args.init_density_adj, 675 | density_feature= self.args.init_density_feature, 676 | grad_dict=grad_dic) # init weight 677 | 678 | 679 | def remove_weight(self, name): 680 | if name in self.masks: 681 | print('Removing {0} of size 
{1} = {2} parameters.'.format(name, self.masks[name].shape, 682 | self.masks[name].numel())) 683 | self.masks.pop(name) 684 | elif name + '.weight' in self.masks: 685 | print('Removing {0} of size {1} = {2} parameters.'.format(name, self.masks[name + '.weight'].shape, 686 | self.masks[name + '.weight'].numel())) 687 | self.masks.pop(name + '.weight') 688 | else: 689 | print('ERROR', name) 690 | 691 | def remove_weight_partial_name(self, partial_name): 692 | removed = set() 693 | for name in list(self.masks.keys()): 694 | if partial_name in name: 695 | 696 | print('Removing {0} of size {1} with {2} parameters...'.format(name, self.masks[name].shape, 697 | np.prod(self.masks[name].shape))) 698 | removed.add(name) 699 | self.masks.pop(name) 700 | 701 | print('Removed {0} layers.'.format(len(removed))) 702 | 703 | i = 0 704 | while i < len(self.names): 705 | name = self.names[i] 706 | if name in removed: 707 | self.names.pop(i) 708 | else: 709 | i += 1 710 | 711 | def remove_type(self, nn_type): 712 | for module in self.modules: 713 | for name, module in module.named_modules(): 714 | if isinstance(module, nn_type): 715 | self.remove_weight(name) 716 | 717 | def apply_mask(self): 718 | for module in self.modules: 719 | for name, tensor in module.named_parameters(): 720 | if name in self.masks: 721 | tensor.data = tensor.data*self.masks[name] 722 | #print("Trying to Apply Mask on {}".format(name)) 723 | if 'momentum_buffer' in self.optimizer.state[tensor]: 724 | self.optimizer.state[tensor]['momentum_buffer'] = self.optimizer.state[tensor]['momentum_buffer']*self.masks[name] 725 | 726 | 727 | def truncate_weights(self, step=None): 728 | 729 | curr_prune_iter = int(step / self.prune_every_k_steps) 730 | final_iter = int((self.args.final_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps) 731 | ini_iter = int((self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps) 732 | total_prune_iter = final_iter - ini_iter 733 | 734 | if curr_prune_iter >= ini_iter and curr_prune_iter <= final_iter - 1: 735 | print('******************************************************') 736 | print(f'Death and Growth Progress is {curr_prune_iter - ini_iter} / {total_prune_iter}') 737 | print('******************************************************') 738 | 739 | self.gather_statistics() 740 | 741 | # prune 742 | for module in self.modules: 743 | for name, weight in module.named_parameters(): 744 | if name not in self.masks: continue 745 | mask = self.masks[name] 746 | 747 | new_mask = self.magnitude_death(mask, weight, name) 748 | self.pruning_rate[name] = int(self.name2nonzeros[name] - new_mask.sum().item()) 749 | self.masks[name][:] = new_mask 750 | 751 | # grow 752 | for module in self.modules: 753 | for name, weight in module.named_parameters(): 754 | if name not in self.masks: continue 755 | new_mask = self.masks[name].data.byte() 756 | 757 | if self.args.growth_schedule == "gradient": 758 | new_mask = self.gradient_growth(name, new_mask, self.pruning_rate[name], weight) 759 | elif self.args.growth_schedule == "momentum": 760 | new_mask = self.momentum_growth(name, new_mask, self.pruning_rate[name], weight) 761 | elif self.args.growth_schedule == "random": 762 | new_mask = self.random_growth(name, new_mask, self.pruning_rate[name], weight) 763 | # exchanging masks 764 | self.masks.pop(name) 765 | self.masks[name] = new_mask.float() 766 | 767 | 768 | self.apply_mask() 769 | 770 | 771 | ''' 772 | REDISTRIBUTION 773 | ''' 774 | 775 | def 
gather_statistics(self): 776 | self.name2nonzeros = {} 777 | self.name2zeros = {} 778 | 779 | for module in self.modules: 780 | for name, tensor in module.named_parameters(): 781 | if name not in self.masks: continue 782 | mask = self.masks[name] 783 | 784 | self.name2nonzeros[name] = mask.sum().item() 785 | self.name2zeros[name] = mask.numel() - self.name2nonzeros[name] 786 | 787 | ############################ DEATH ########################### 788 | 789 | def magnitude_death(self, mask, weight, name): 790 | num_remove = math.ceil(self.prune_rate*self.name2nonzeros[name]) 791 | if num_remove == 0.0: return weight.data != 0.0 792 | num_zeros = self.name2zeros[name] 793 | k = math.ceil(num_zeros + num_remove) 794 | x, idx = torch.sort(torch.abs(weight.data.view(-1))) 795 | threshold = x[k-1].item() 796 | 797 | return (torch.abs(weight.data) > threshold) 798 | 799 | 800 | ########################### GROWTH ########################### 801 | 802 | def random_growth(self, name, new_mask, total_regrowth, weight): 803 | n = (new_mask==0).sum().item() 804 | if n == 0: return new_mask 805 | expeced_growth_probability = (total_regrowth/n) 806 | new_weights = torch.rand(new_mask.shape).to(self.device) < expeced_growth_probability #lsw 807 | # new_weights = torch.rand(new_mask.shape) < expeced_growth_probability 808 | new_mask_ = new_mask.byte() | new_weights 809 | if (new_mask_!=0).sum().item() == 0: 810 | new_mask_ = new_mask 811 | return new_mask_ 812 | 813 | def momentum_growth(self, name, new_mask, total_regrowth, weight): 814 | grad = self.get_momentum_for_weight(weight) 815 | grad = grad*(new_mask==0).float() 816 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True) 817 | new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0 818 | 819 | return new_mask 820 | 821 | 822 | def gradient_growth(self, name, new_mask, total_regrowth, weight): 823 | grad = self.get_gradient_for_weights(weight) 824 | grad = grad*(new_mask==0).float() 825 | 826 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True) 827 | new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0 828 | 829 | return new_mask 830 | 831 | 832 | 833 | ''' 834 | UTILITY 835 | ''' 836 | def get_momentum_for_weight(self, weight): 837 | if 'exp_avg' in self.optimizer.state[weight]: 838 | adam_m1 = self.optimizer.state[weight]['exp_avg'] 839 | adam_m2 = self.optimizer.state[weight]['exp_avg_sq'] 840 | grad = adam_m1/(torch.sqrt(adam_m2) + 1e-08) 841 | elif 'momentum_buffer' in self.optimizer.state[weight]: 842 | grad = self.optimizer.state[weight]['momentum_buffer'] 843 | return grad 844 | 845 | def get_gradient_for_weights(self, weight): 846 | grad = weight.grad.clone() 847 | return grad 848 | 849 | def print_nonzero_counts(self): 850 | for module in self.modules: 851 | for name, tensor in module.named_parameters(): 852 | if name not in self.masks: continue 853 | mask = self.masks[name] 854 | num_nonzeros = (mask != 0).sum().item() 855 | val = '{0}: {1}->{2}, density: {3:.3f}'.format(name, self.name2nonzeros[name], num_nonzeros, num_nonzeros/float(mask.numel())) 856 | print(val) 857 | 858 | print('Death rate: {0}\n'.format(self.prune_rate)) 859 | print("="*40) 860 | 861 | def reset_momentum(self): 862 | """ 863 | Taken from: https://github.com/AlliedToasters/synapses/blob/master/synapses/SET_layer.py 864 | Resets buffers from memory according to passed indices. 865 | When connections are reset, parameters should be treated 866 | as freshly initialized. 
867 | """ 868 | for module in self.modules: 869 | for name, tensor in module.named_parameters(): 870 | if name not in self.masks: continue 871 | mask = self.masks[name] 872 | weights = list(self.optimizer.state[tensor]) 873 | for w in weights: 874 | if w == 'momentum_buffer': 875 | # momentum 876 | if self.args.reset_mom_zero: 877 | print('zero') 878 | self.optimizer.state[tensor][w][mask == 0] = 0 879 | else: 880 | print('mean') 881 | self.optimizer.state[tensor][w][mask==0] = torch.mean(self.optimizer.state[tensor][w][mask.byte()]) 882 | # self.optimizer.state[tensor][w][mask==0] = 0 883 | elif w == 'square_avg' or \ 884 | w == 'exp_avg' or \ 885 | w == 'exp_avg_sq' or \ 886 | w == 'exp_inf': 887 | # Adam 888 | self.optimizer.state[tensor][w][mask==0] = torch.mean(self.optimizer.state[tensor][w][mask.byte()]) 889 | 890 | def fired_masks_update(self): 891 | ntotal_fired_weights = 0.0 892 | ntotal_weights = 0.0 893 | layer_fired_weights = {} 894 | for module in self.modules: 895 | for name, weight in module.named_parameters(): 896 | if name not in self.masks: continue 897 | self.fired_masks[name] = self.masks[name].data.byte() | self.fired_masks[name].data.byte() 898 | ntotal_fired_weights += float(self.fired_masks[name].sum().item()) 899 | ntotal_weights += float(self.fired_masks[name].numel()) 900 | layer_fired_weights[name] = float(self.fired_masks[name].sum().item())/float(self.fired_masks[name].numel()) 901 | print('Layerwise percentage of the fired weights of', name, 'is:', layer_fired_weights[name]) 902 | total_fired_weights = ntotal_fired_weights/ntotal_weights 903 | print('The percentage of the total fired weights is:', total_fired_weights) 904 | return layer_fired_weights, total_fired_weights 905 | --------------------------------------------------------------------------------