├── .DS_Store
├── models
│   ├── .DS_Store
│   ├── init_utils.py
│   ├── initializers.py
│   ├── gcn_conv.py
│   ├── fagcn_conv.py
│   ├── gat_conv.py
│   └── model.py
├── sparselearning
│   ├── .DS_Store
│   ├── __init__.py
│   ├── sparse_sgd.py
│   ├── models.py
│   └── core.py
├── README.md
├── run_base.sh
├── run.sh
├── run_wf.sh
├── run_multi.sh
├── run_waf.sh
└── main_stgnn.py
/sparselearning/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logging.getLogger(__name__).addHandler(logging.NullHandler())
3 |
--------------------------------------------------------------------------------
/models/init_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 |
5 |
6 | def weights_init(m):
7 | # print('=> weights init')
8 | if isinstance(m, nn.Conv2d):
9 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
10 | # nn.init.normal_(m.weight, 0, 0.1)
11 | if m.bias is not None:
12 | m.bias.data.zero_()
13 | elif isinstance(m, nn.Linear):
14 | # nn.init.xavier_normal(m.weight)
15 | nn.init.normal_(m.weight, 0, 0.01)
16 | nn.init.constant_(m.bias, 0)
17 | elif isinstance(m, nn.BatchNorm2d):
18 | # Note that BN's running_var/mean are
19 | # already initialized to 1 and 0 respectively.
20 | if m.weight is not None:
21 | m.weight.data.fill_(1.0)
22 | if m.bias is not None:
23 | m.bias.data.zero_()
--------------------------------------------------------------------------------
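A minimal usage sketch for the `weights_init` helper above: `nn.Module.apply` visits every submodule recursively, so a single call re-initializes all `Conv2d`, `Linear`, and `BatchNorm2d` layers. The toy network below is hypothetical and only illustrates the call.

```python
import torch.nn as nn
from models.init_utils import weights_init

# Hypothetical toy model; any nn.Module works the same way.
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),
)

# apply() walks all submodules and dispatches on their type as defined in weights_init.
net.apply(weights_init)
```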
/models/initializers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 |
9 | def binary(w):
10 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
11 | torch.nn.init.kaiming_normal_(w.weight)
12 | sigma = w.weight.data.std()
13 | w.weight.data = torch.sign(w.weight.data) * sigma
14 |
15 |
16 | def kaiming_normal(w):
17 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
18 | torch.nn.init.kaiming_normal_(w.weight)
19 |
20 |
21 | def kaiming_uniform(w):
22 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
23 | torch.nn.init.kaiming_uniform_(w.weight)
24 |
25 |
26 | def orthogonal(w):
27 | if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
28 | torch.nn.init.orthogonal_(w.weight)
--------------------------------------------------------------------------------
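As a quick illustration of the `binary` initializer above (a hypothetical sketch, not part of the repository): it first draws Kaiming-normal weights and then collapses every entry to a single magnitude, keeping only the sign.

```python
import torch
from models.initializers import binary

layer = torch.nn.Linear(8, 4)  # hypothetical layer, used only for illustration
binary(layer)

# All weights now share one magnitude (the std of the Kaiming draw) and differ only in sign.
print(layer.weight.data.abs().unique())   # a single positive value
print(layer.weight.data.sign().unique())  # typically tensor([-1., 1.])
```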
/README.md:
--------------------------------------------------------------------------------
1 | # Comprehensive Graph Gradual Pruning for Sparse Training in Graph Neural Networks
2 |
3 | Open-source implementation of our TNNLS 2023 paper.
4 |
5 |
6 |
7 | ## Abstract
8 |
9 | 1) We propose a graph gradual pruning framework, namely
10 | CGP, to reduce the training and inference computational costs
11 | of GNN models while preserving their accuracy.
12 |
13 | 2) We comprehensively sparsify the elements of GNNs,
14 | including graph structures, the node feature dimension, and
15 | model parameters, to significantly improve the efficiency
16 | of GNN models.
17 |
18 | 3) Experimental results on various GNN models and datasets
19 | consistently validate the effectiveness and efficiency of
20 | our proposed CGP.
21 |
22 |
23 |
24 | ## Python Dependencies
25 |
26 | Our proposed CGP is implemented in Python 3.7, and the major libraries include:
27 |
28 | * [Pytorch](https://pytorch.org/) = 1.11.0+cu113
29 | * [PyG](https://pytorch-geometric.readthedocs.io/en/latest/) torch-geometric=2.2.0
30 |
31 | More dependencies are provided in **requirements.txt**.
32 |
33 | ## To Run
34 |
35 | Once the requirements are installed, run one of the provided shell scripts:
36 |
37 | `sh xx.sh`
38 |
39 | ## Datasets
40 |
41 | All datasets used in this paper can be downloaded from [PyG](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html).
42 |
--------------------------------------------------------------------------------
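The `xx.sh` in the README's run command stands for one of the sweep scripts listed in the tree above (`run_base.sh`, `run.sh`, `run_wf.sh`, `run_multi.sh`, `run_waf.sh`). Most of them read the CUDA device id from their first positional argument, so a typical invocation looks like `sh run_base.sh 0`.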
/run_base.sh:
--------------------------------------------------------------------------------
1 | for model in gcn gat sgc appnp
2 | do
3 | for data in cora citeseer pubmed Cornell Texas Wisconsin Actor CS Physics Computers Photo WikiCS ogbn-arxiv
4 | do
5 | python main_stgnn.py --method GraNet \
6 | --optimizer adam \
7 | --sparse-init ERK \
8 | --init-density 1.0 \
9 | --l2 0.0005 \
10 | --lr 0.01 \
11 | --cuda $1 \
12 | --epochs 200 \
13 | --model $model \
14 | --data $data
15 | done
16 | done
17 |
18 |
19 | # cora citeseer
20 |
21 |
22 |
23 | # --model: gcn, gat, sgc, appnp, gcnii (5)
24 | # --data: cora, citeseer, pubmed, Cornell, Texas, Wisconsin, Actor
25 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit
26 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17)
27 | # --weight_sparse or --feature_sparse --sparse (7)
28 | # --sparse: base or sparse train (2)
29 |
30 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4)
31 | # --growth_schedule: gradient, momentum, random (3)
32 | # --sparse-init: uniform, ERK (2)
33 |
34 | # --prune-rate : regeneration rate : 0.1, 0.2, 0.3 (3)
35 | # --update-frequency 10 20 30 (3)
36 | # --final-prune-epoch 50 100 150 (3)
37 |
38 | # --init-density: weight init density: 1, (dense to sparse)
39 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4)
40 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
41 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
42 |
43 |
44 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000
45 |
46 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160
47 |
48 |
49 | # python main_stgnn.py --method GraNet \
50 | # --prune-rate 0.5 \
51 | # --optimizer adam \
52 | # --sparse-init ERK \
53 | # --init-density 0.5 \
54 | # --final-density 0.1 \
55 | # --update-frequency 10 \
56 | # --l2 0.0005 \
57 | # --lr 0.01 \
58 | # --epochs 200 \
59 | # --model gcn \
60 | # --data cora \
61 | # --final-prune-epoch 100
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | for model in gcn gat sgc appnp
2 | do
3 | for data in cora citeseer pubmed Cornell Texas Wisconsin Actor CS Physics Computers Photo WikiCS ogbn-arxiv
4 | do
5 | for fde in 0.5 0.1 0.01 0.0001
6 | do
7 | for fda in 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01
8 | do
9 | for pr in 0.1 0.2 0.3
10 | do
11 | for uf in 10 20 30
12 | do
13 | for fpe in 50 100 150
14 | do
15 | python main_stgnn.py --method GraNet \
16 | --prune-rate $pr \
17 | --optimizer adam \
18 | --sparse-init ERK \
19 | --init-density 1.0 \
20 | --final-density $fde \
21 | --final-density_adj $fda \
22 | --final-density_feature 0.5 \
23 | --update-frequency $uf \
24 | --l2 0.0005 \
25 | --lr 0.01 \
26 | --cuda $1 \
27 | --epochs 200 \
28 | --model $model \
29 | --data $data \
30 | --final-prune-epoch $fpe \
31 | --growth_schedule momentum \
32 | --adj_sparse \
33 | --weight_sparse \
34 | --sparse
35 | done
36 | done
37 | done
38 | done
39 | done
40 | done
41 | done
42 |
43 |
44 |
45 |
46 |
47 | # --model: gcn, gat, sgc, appnp, gcnii (5)
48 | # --data: cora, citeseer, pubmed, Cornell, Texas, Wisconsin, Actor
49 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit
50 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17)
51 | # --weight_sparse or --feature_sparse --sparse (7)
52 | # --sparse: base or sparse train (2)
53 |
54 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4)
55 | # --growth_schedule: gradient, momentum, random (3)
56 | # --sparse-init: uniform, ERK (2)
57 |
58 | # --prune-rate : regeneration rate : 0.1, 0.2, 0.3 (3)
59 | # --update-frequency 10 20 30 (3)
60 | # --final-prune-epoch 50 100 150 (3)
61 |
62 | # --init-density: weight init density: 1, (dense to sparse)
63 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4)
64 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
65 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
66 |
67 |
68 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000
69 |
70 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160
71 |
72 |
73 | # python main_stgnn.py --method GraNet \
74 | # --prune-rate 0.5 \
75 | # --optimizer adam \
76 | # --sparse-init ERK \
77 | # --init-density 0.5 \
78 | # --final-density 0.1 \
79 | # --update-frequency 10 \
80 | # --l2 0.0005 \
81 | # --lr 0.01 \
82 | # --epochs 200 \
83 | # --model gcn \
84 | # --data cora \
85 | # --final-prune-epoch 100
--------------------------------------------------------------------------------
/run_wf.sh:
--------------------------------------------------------------------------------
1 | for model in gcn gat sgc appnp
2 | do
3 | for data in cora citeseer pubmed Cornell Texas Wisconsin Actor CS Physics Computers Photo ogbn-arxiv
4 | do
5 | for fde in 0.8 0.5 0.1 0.01 0.001
6 | do
7 | for fdf in 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01
8 | do
9 | for pr in 0.1 0.2 0.3
10 | do
11 | for uf in 10 20 30
12 | do
13 | for fpe in 50 100 150
14 | do
15 | python main_stgnn.py --method GraNet \
16 | --prune-rate $pr \
17 | --optimizer adam \
18 | --sparse-init ERK \
19 | --init-density 1.0 \
20 | --final-density $fde \
21 | --final-density_adj 1.0 \
22 | --final-density_feature $fdf \
23 | --update-frequency $uf \
24 | --l2 0.0005 \
25 | --lr 0.01 \
26 | --cuda $1 \
27 | --epochs 200 \
28 | --model $model \
29 | --data $data \
30 | --final-prune-epoch $fpe \
31 | --growth_schedule momentum \
32 | --feature_sparse \
33 | --weight_sparse \
34 | --sparse
35 | done
36 | done
37 | done
38 | done
39 | done
40 | done
41 | done
42 |
43 |
44 | # cora citeseer
45 |
46 |
47 |
48 | # --model: gcn, gat, sgc, appnp, gcnii (5)
49 | # --data: cora, citeseer, pubmed, Cornell, Texas, Wisconsin, Actor
50 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit
51 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17)
52 | # --weight_sparse or --feature_sparse --sparse (7)
53 | # --sparse: base or sparse train (2)
54 |
55 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4)
56 | # --growth_schedule: gradient, momentum, random (3)
57 | # --sparse-init: uniform, ERK (2)
58 |
59 | # --prune-rate : regeneration rate : 0.1, 0.2, 0.3 (3)
60 | # --update-frequency 10 20 30 (3)
61 | # --final-prune-epoch 50 100 150 (3)
62 |
63 | # --init-density: weight init density: 1, (dense to sparse)
64 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4)
65 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
66 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
67 |
68 |
69 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000
70 |
71 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160
72 |
73 |
74 | # python main_stgnn.py --method GraNet \
75 | # --prune-rate 0.5 \
76 | # --optimizer adam \
77 | # --sparse-init ERK \
78 | # --init-density 0.5 \
79 | # --final-density 0.1 \
80 | # --update-frequency 10 \
81 | # --l2 0.0005 \
82 | # --lr 0.01 \
83 | # --epochs 200 \
84 | # --model gcn \
85 | # --data cora \
86 | # --final-prune-epoch 100
--------------------------------------------------------------------------------
/models/gcn_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_scatter import scatter_add
4 | from torch_geometric.nn.conv import MessagePassing
5 | from torch_geometric.utils import add_remaining_self_loops
6 | from torch_geometric.nn.inits import glorot, zeros
7 | import pdb
8 |
9 | class GCNConv(MessagePassing):
10 |
11 | def __init__(self, in_channels, out_channels, improved=False, cached=False,
12 | bias=True, normalize=True, **kwargs):
13 | super(GCNConv, self).__init__(aggr='add', **kwargs)
14 |
15 | self.in_channels = in_channels
16 | self.out_channels = out_channels
17 | self.improved = improved
18 | self.cached = cached
19 | self.normalize = normalize
20 |
21 | self.weight = Parameter(torch.Tensor(in_channels, out_channels))
22 |
23 | if bias:
24 | self.bias = Parameter(torch.Tensor(out_channels))
25 | else:
26 | self.register_parameter('bias', None)
27 |
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | glorot(self.weight)
32 | zeros(self.bias)
33 | self.cached_result = None
34 | self.cached_num_edges = None
35 |
36 | @staticmethod
37 | def norm(edge_index, num_nodes, edge_weight=None, improved=False,
38 | dtype=None):
39 |
40 | if edge_weight is None:
41 | edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
42 | device=edge_index.device)
43 |
44 |
45 | fill_value = 1 if not improved else 2
46 | edge_index, edge_weight = add_remaining_self_loops(
47 | edge_index, edge_weight, fill_value, num_nodes)
48 |
49 | row, col = edge_index
50 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
51 | deg_inv_sqrt = deg.pow(-0.5)
52 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
53 |
54 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
55 |
56 | def forward(self, x, edge_index, edge_weight=None):
57 | """"""
58 | x = torch.matmul(x, self.weight)
59 |
60 | if self.cached and self.cached_result is not None:
61 | if edge_index.size(1) != self.cached_num_edges:
62 | raise RuntimeError(
63 | 'Cached {} number of edges, but found {}. Please '
64 | 'disable the caching behavior of this layer by removing '
65 | 'the `cached=True` argument in its constructor.'.format(
66 | self.cached_num_edges, edge_index.size(1)))
67 |
68 | if not self.cached or self.cached_result is None:
69 | self.cached_num_edges = edge_index.size(1)
70 | if self.normalize:
71 | edge_index, norm = self.norm(edge_index, x.size(self.node_dim),
72 | edge_weight, self.improved,
73 | x.dtype)
74 | else:
75 | norm = edge_weight
76 | self.cached_result = edge_index, norm
77 |
78 | edge_index, norm = self.cached_result
79 |
80 | return self.propagate(edge_index, x=x, norm=norm)
81 |
82 | def message(self, x_j, norm):
83 |
84 | return norm.view(-1, 1) * x_j if norm is not None else x_j
85 |
86 | def update(self, aggr_out):
87 | if self.bias is not None:
88 | aggr_out = aggr_out + self.bias
89 | return aggr_out
90 |
91 | def __repr__(self):
92 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
93 | self.out_channels)
94 |
--------------------------------------------------------------------------------
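A small sketch (hypothetical toy graph, assumed shapes) of driving the `GCNConv` above: `forward` projects the node features with the layer weight, symmetrically normalizes the self-loop-augmented adjacency in `norm`, and sums the normalized neighbor messages.

```python
import torch
from models.gcn_conv import GCNConv

# Hypothetical 4-node graph; undirected edges are listed in both directions.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
x = torch.randn(4, 16)            # 4 nodes, 16 input features

conv = GCNConv(16, 32, cached=False)
out = conv(x, edge_index)         # normalized neighborhood aggregation; shape [4, 32]
print(out.shape)
```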
/run_multi.sh:
--------------------------------------------------------------------------------
1 | wei_arr=(
2 | 0.0783
3 | 0.0620
4 | 0.0493
5 | 0.0394
6 | 0.0315
7 | 0.0252
8 | 0.0201
9 | 0.0161
10 | 0.0129
11 | 0.0103)
12 |
13 | adj_arr=(
14 | 0.5656
15 | 0.5372
16 | 0.5102
17 | 0.4843
18 | 0.4598
19 | 0.4368
20 | 0.4149
21 | 0.3941
22 | 0.3737
23 | 0.3547)
24 |
25 |
26 | for model in gcn
27 | do
28 | for data in cora citeseer pubmed Cornell Texas Wisconsin Computers Photo
29 | do
30 | for i in ${!wei_arr[@]}
31 | do
32 | for pr in 0.1 0.2 0.3
33 | do
34 | for uf in 10 20 30
35 | do
36 | for fpe in 50 100 150
37 | do
38 | python main_stgnn.py --method GraNet \
39 | --prune-rate $pr \
40 | --optimizer adam \
41 | --sparse-init ERK \
42 | --init-density 1.0 \
43 | --final-density ${wei_arr[$i]} \
44 | --final-density_adj ${adj_arr[$i]} \
45 | --final-density_feature 1.0 \
46 | --update-frequency $uf \
47 | --l2 0.0005 \
48 | --lr 0.01 \
49 | --epochs 200 \
50 | --model $model \
51 | --data $data \
52 | --final-prune-epoch $fpe \
53 | --growth_schedule momentum \
54 | --adj_sparse \
55 | --weight_sparse \
56 | --sparse
57 | done
58 | done
59 | done
60 | done
61 | done
62 | done
63 |
64 |
65 |
66 |
67 |
68 | # --model: gcn, gat, sgc, appnp, gcnii (5)
69 | # --data: cora, citeseer, pubmed, Cornell, Texas, Wisconsin, Actor
70 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit
71 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17)
72 | # --weight_sparse or --feature_sparse --sparse (7)
73 | # --sparse: base or sparse train (2)
74 |
75 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4)
76 | # --growth_schedule: gradient, momentum, random (3)
77 | # --sparse-init: uniform, ERK (2)
78 |
79 | # --prune-rate : regeneration rate : 0.1, 0.2, 0.3 (3)
80 | # --update-frequency 10 20 30 (3)
81 | # --final-prune-epoch 50 100 150 (3)
82 |
83 | # --init-density: weight init density: 1, (dense to sparse)
84 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4)
85 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
86 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
87 |
88 |
89 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000
90 |
91 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160
92 |
93 |
94 | # python main_stgnn.py --method GraNet \
95 | # --prune-rate 0.5 \
96 | # --optimizer adam \
97 | # --sparse-init ERK \
98 | # --init-density 0.5 \
99 | # --final-density 0.1 \
100 | # --update-frequency 10 \
101 | # --l2 0.0005 \
102 | # --lr 0.01 \
103 | # --epochs 200 \
104 | # --model gcn \
105 | # --data cora \
106 | # --final-prune-epoch 100
--------------------------------------------------------------------------------
/run_waf.sh:
--------------------------------------------------------------------------------
1 | for model in gcn gat sgc appnp
2 | do
3 | for data in cora citeseer pubmed Cornell Texas Wisconsin Actor CS Physics Computers Photo ogbn-arxiv
4 | do
5 | for fde in 0.5 0.1 0.01
6 | do
7 | for fdf in 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01
8 | do
9 | for fda in 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01
10 | do
11 | for pr in 0.1 0.2 0.3
12 | do
13 | for uf in 10 20 30
14 | do
15 | for fpe in 50 100 150
16 | do
17 | python main_stgnn.py --method GraNet \
18 | --prune-rate $pr \
19 | --optimizer adam \
20 | --sparse-init ERK \
21 | --init-density 1.0 \
22 | --final-density $fde \
23 | --final-density_adj $fda \
24 | --final-density_feature $fdf \
25 | --update-frequency $uf \
26 | --l2 0.0005 \
27 | --lr 0.01 \
28 | --cuda $1 \
29 | --epochs 200 \
30 | --model $model \
31 | --data $data \
32 | --final-prune-epoch $fpe \
33 | --growth_schedule momentum \
34 | --feature_sparse \
35 | --weight_sparse \
36 | --adj_sparse \
37 | --sparse
38 | done
39 | done
40 | done
41 | done
42 | done
43 | done
44 | done
45 | done
46 |
47 |
48 | # cora citeseer
49 |
50 |
51 |
52 | # --model: gcn, gat, sgc, appnp, gcnii (5)
53 | # --data: cora, citeseer, pubmed, Cornell, Texas, Wisconsin, Actor
54 | # --data: CS, Physics, Computers, Photo, WikiCS, reddit
55 | # --data: ogbn-arxiv, ogbn-proteins, ogbn-products, ogbn-papers100M (17)
56 | # --weight_sparse or --feature_sparse --sparse (7)
57 | # --sparse: base or sparse train (2)
58 |
59 | # --method: GraNet, GraNet_uniform, GMP, GMP_uniform (4)
60 | # --growth_schedule: gradient, momentum, random (3)
61 | # --sparse-init: uniform, ERK (2)
62 |
63 | # --prune-rate : regeneration rate : 0.1, 0.2, 0.3 (3)
64 | # --update-frequency 10 20 30 (3)
65 | # --final-prune-epoch 50 100 150 (3)
66 |
67 | # --init-density: weight init density: 1, (dense to sparse)
68 | # --final-density: weight : 0.5 0.1 0.01 0.0001 (4)
69 | # --final-density_adj : 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
70 | # --final-density_feature: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.01 (10)
71 |
72 |
73 | # 5 x 17 x 7 x 4 x 3 x 2 X 3 X 3 X 3 X 4 X 10 X 10 = 154,224,000
74 |
75 | # Actual: 4 x 13 x 4 x 10 x 3 x 3 x 3 = 56160
76 |
77 |
78 | # python main_stgnn.py --method GraNet \
79 | # --prune-rate 0.5 \
80 | # --optimizer adam \
81 | # --sparse-init ERK \
82 | # --init-density 0.5 \
83 | # --final-density 0.1 \
84 | # --update-frequency 10 \
85 | # --l2 0.0005 \
86 | # --lr 0.01 \
87 | # --epochs 200 \
88 | # --model gcn \
89 | # --data cora \
90 | # --final-prune-epoch 100
--------------------------------------------------------------------------------
/models/fagcn_conv.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 | from torch_sparse import SparseTensor
6 |
7 | from torch_geometric.nn.conv import MessagePassing
8 | from torch_geometric.nn.conv.gcn_conv import gcn_norm
9 | from torch_geometric.nn.dense.linear import Linear
10 | from torch_geometric.typing import Adj, OptTensor
11 |
12 |
13 | class FAConv(MessagePassing):
14 |
15 | _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
16 | _cached_adj_t: Optional[SparseTensor]
17 | _alpha: OptTensor
18 |
19 | def __init__(self, channels: int, eps: float = 0.1, dropout: float = 0.0,
20 | cached: bool = False, add_self_loops: bool = True,
21 | normalize: bool = True, **kwargs):
22 |
23 | kwargs.setdefault('aggr', 'add')
24 | super(FAConv, self).__init__(**kwargs)
25 |
26 | self.channels = channels
27 | self.eps = eps
28 | self.dropout = dropout
29 | self.cached = cached
30 | self.add_self_loops = add_self_loops
31 | self.normalize = normalize
32 |
33 | self._cached_edge_index = None
34 | self._cached_adj_t = None
35 | self._alpha = None
36 |
37 | self.att_l = Linear(channels, 1, bias=False)
38 | self.att_r = Linear(channels, 1, bias=False)
39 |
40 | self.reset_parameters()
41 |
42 | def reset_parameters(self):
43 | self.att_l.reset_parameters()
44 | self.att_r.reset_parameters()
45 | self._cached_edge_index = None
46 | self._cached_adj_t = None
47 |
48 |
49 | def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj, edge_weight: OptTensor = None, return_attention_weights=None):
50 | if self.normalize:
51 | if isinstance(edge_index, Tensor):
52 | #assert edge_weight is None
53 | cache = self._cached_edge_index
54 | if cache is None:
55 | edge_index, edge_weight = gcn_norm( # yapf: disable
56 | edge_index, None, x.size(self.node_dim), False,
57 | self.add_self_loops, dtype=x.dtype)
58 | if self.cached:
59 | self._cached_edge_index = (edge_index, edge_weight)
60 | else:
61 | edge_index, edge_weight = cache[0], cache[1]
62 |
63 | elif isinstance(edge_index, SparseTensor):
64 | assert not edge_index.has_value()
65 | cache = self._cached_adj_t
66 | if cache is None:
67 | edge_index = gcn_norm( # yapf: disable
68 | edge_index, None, x.size(self.node_dim), False,
69 | self.add_self_loops, dtype=x.dtype)
70 | if self.cached:
71 | self._cached_adj_t = edge_index
72 | else:
73 | edge_index = cache
74 | else:
75 | if isinstance(edge_index, Tensor):
76 | assert edge_weight is not None
77 | elif isinstance(edge_index, SparseTensor):
78 | assert edge_index.has_value()
79 |
80 | alpha_l = self.att_l(x)
81 | alpha_r = self.att_r(x)
82 |
83 | # propagate_type: (x: Tensor, alpha: PairTensor, edge_weight: OptTensor) # noqa
84 | out = self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r),
85 | edge_weight=edge_weight, size=None)
86 |
87 | alpha = self._alpha
88 | self._alpha = None
89 |
90 | if self.eps != 0.0:
91 | out += self.eps * x_0
92 |
93 | if isinstance(return_attention_weights, bool):
94 | assert alpha is not None
95 | if isinstance(edge_index, Tensor):
96 | return out, (edge_index, alpha)
97 | elif isinstance(edge_index, SparseTensor):
98 | return out, edge_index.set_value(alpha, layout='coo')
99 | else:
100 | return out
101 |
102 |
103 | def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor,
104 | edge_weight: OptTensor) -> Tensor:
105 | assert edge_weight is not None
106 | alpha = (alpha_j + alpha_i).tanh().squeeze(-1)
107 | self._alpha = alpha
108 | alpha = F.dropout(alpha, p=self.dropout, training=self.training)
109 | return x_j * (alpha * edge_weight).view(-1, 1)
110 |
111 | def __repr__(self) -> str:
112 | return f'{self.__class__.__name__}({self.channels}, eps={self.eps})'
--------------------------------------------------------------------------------
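An illustrative call to the `FAConv` above (toy tensors, assumed shapes): the layer takes both the current representation `x` and the initial representation `x_0`, and its messages are scaled by a signed attention score in [-1, 1], which is what lets FAGCN mix low- and high-frequency signals.

```python
import torch
from models.fagcn_conv import FAConv

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # hypothetical 3-node graph
x_0 = torch.randn(3, 64)                   # initial node representation
x = x_0.clone()                            # current layer input

conv = FAConv(channels=64, eps=0.1, dropout=0.0)
out = conv(x, x_0, edge_index)             # aggregation plus eps * x_0; shape [3, 64]
```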
/sparselearning/sparse_sgd.py:
--------------------------------------------------------------------------------
1 | from torch.optim.optimizer import Optimizer, required
2 | import torch
3 | import numpy as np
4 | class sparse_SGD(Optimizer):
5 | r"""Implements sparse stochastic gradient descent (optionally with momentum), according to the pytorch version 1.5.1.
6 |
7 | Nesterov momentum is based on the formula from
8 | `On the importance of initialization and momentum in deep learning`__.
9 |
10 | Args:
11 | params (iterable): iterable of parameters to optimize or dicts defining
12 | parameter groups
13 | lr (float): learning rate
14 | momentum (float, optional): momentum factor (default: 0)
15 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
16 | dampening (float, optional): dampening for momentum (default: 0)
17 | nesterov (bool, optional): enables Nesterov momentum (default: False)
18 |
19 | Example:
20 | >>> optimizer = sparse_SGD(model.parameters(), lr=0.1, momentum=0.9)
21 | >>> optimizer.zero_grad()
22 | >>> loss_fn(model(input), target).backward()
23 | >>> optimizer.step()
24 |
25 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
26 |
27 | .. note::
28 | The implementation of SGD with Momentum/Nesterov subtly differs from
29 | Sutskever et. al. and implementations in some other frameworks.
30 |
31 | Considering the specific case of Momentum, the update can be written as
32 |
33 | .. math::
34 | v = \rho * v + g \\
35 | p = p - lr * v
36 |
37 | where p, g, v and :math:`\rho` denote the parameters, gradient,
38 | velocity, and momentum respectively.
39 |
40 | This is in contrast to Sutskever et. al. and
41 | other frameworks which employ an update of the form
42 |
43 | .. math::
44 | v = \rho * v + lr * g \\
45 | p = p - v
46 |
47 | The Nesterov version is analogously modified.
48 | """
49 |
50 | def __init__(self, params, lr=required, momentum=0, dampening=0,
51 | weight_decay=0, nesterov=False):
52 | if lr is not required and lr < 0.0:
53 | raise ValueError("Invalid learning rate: {}".format(lr))
54 | if momentum < 0.0:
55 | raise ValueError("Invalid momentum value: {}".format(momentum))
56 | if weight_decay < 0.0:
57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
58 |
59 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
60 | weight_decay=weight_decay, nesterov=nesterov)
61 | if nesterov and (momentum <= 0 or dampening != 0):
62 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
63 | super(sparse_SGD, self).__init__(params, defaults)
64 |
65 | def __setstate__(self, state):
66 | super(sparse_SGD, self).__setstate__(state)
67 | for group in self.param_groups:
68 | group.setdefault('nesterov', False)
69 |
70 | @torch.no_grad()
71 | def step(self, closure=None, nonzero_masks=None, new_masks=None, gamma=None, epoch=None):
72 | """Performs a single optimization step.
73 |
74 | Arguments:
75 | closure (callable, optional): A closure that reevaluates the model
76 | and returns the loss.
77 | """
78 | loss = None
79 | if closure is not None:
80 | with torch.enable_grad():
81 | loss = closure()
82 |
83 | if epoch <= 100:
84 | for group in self.param_groups:
85 | weight_decay = group['weight_decay']
86 | momentum = group['momentum']
87 | dampening = group['dampening']
88 | nesterov = group['nesterov']
89 |
90 | for p in group['params']:
91 | if p.grad is None:
92 | continue
93 | d_p = p.grad
94 | if weight_decay != 0:
95 | d_p = d_p.add(p, alpha=weight_decay)
96 | if momentum != 0:
97 | param_state = self.state[p]
98 | if 'momentum_buffer' not in param_state:
99 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
100 | else:
101 | buf = param_state['momentum_buffer']
102 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
103 | if nesterov:
104 | d_p = d_p.add(buf, alpha=momentum)
105 | else:
106 | d_p = buf
107 |
108 | p.add_(d_p, alpha=-group['lr'])
109 | else:
110 | for group in self.param_groups:
111 | weight_decay = group['weight_decay']
112 | momentum = group['momentum']
113 | dampening = group['dampening']
114 | nesterov = group['nesterov']
115 |
116 | for i, p in enumerate(group['params']):
117 | if p.grad is None:
118 | continue
119 |
120 | sparse_layer_flag = False
121 | for key in nonzero_masks.keys():
122 | if i == float(key.split('_')[-1]):
123 | nonzero_mask = nonzero_masks[key]
124 | new_mask = new_masks[key]
125 | sparse_layer_flag = True
126 |
127 | d_p = p.grad
128 | if weight_decay != 0:
129 | d_p = d_p.add(p, alpha=weight_decay)
130 | if momentum != 0:
131 | param_state = self.state[p]
132 | if 'momentum_buffer' not in param_state:
133 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
134 | else:
135 | buf = param_state['momentum_buffer']
136 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
137 | if nesterov:
138 | d_p = d_p.add(buf, alpha=momentum)
139 | else:
140 | d_p = buf
141 |
142 |
143 |
144 | if sparse_layer_flag:
145 | p.add_(d_p * nonzero_mask, alpha=-group['lr'])
146 | p.add_(d_p * new_mask, alpha=-gamma)
147 |
148 | else:
149 | p.add_(d_p, alpha=-group['lr'])
150 |
151 | return loss
--------------------------------------------------------------------------------
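A hedged usage sketch for `sparse_SGD` above (the model, mask names, and values are hypothetical). `step` switches on `epoch`: early epochs take a plain SGD step, while later epochs apply the update through the supplied masks; a mask key is matched to a parameter by the integer after its last underscore (the parameter's index in the param group), and regrown weights in `new_masks` are updated with the separate step size `gamma`.

```python
import torch
import torch.nn as nn
from sparselearning.sparse_sgd import sparse_SGD

# Hypothetical two-layer model; only the first weight matrix is treated as a sparse layer.
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
opt = sparse_SGD(model.parameters(), lr=0.01, momentum=0.9)

params = list(model.parameters())
nonzero_masks = {'layer_0': (torch.rand_like(params[0]) > 0.5).float()}  # currently active weights
new_masks     = {'layer_0': (torch.rand_like(params[0]) > 0.9).float()}  # newly regrown weights

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)
opt.zero_grad()
loss.backward()

# epoch > 100 takes the masked branch; epoch <= 100 would perform a plain SGD step.
opt.step(nonzero_masks=nonzero_masks, new_masks=new_masks, gamma=0.001, epoch=150)
```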
/models/gat_conv.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Tuple, Optional
2 | from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, OptTensor)
3 | import torch
4 | from torch import Tensor
5 | import torch.nn.functional as F
6 | from torch.nn import Parameter, Linear
7 | from torch_sparse import SparseTensor, set_diag
8 | from torch_geometric.nn.conv import MessagePassing
9 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
10 | from torch_geometric.nn.inits import glorot, zeros
11 | import pdb
12 |
13 | class GATConv(MessagePassing):
14 |
15 | _alpha: OptTensor
16 |
17 | def __init__(self, in_channels: Union[int, Tuple[int, int]],
18 | out_channels: int, heads: int = 1, concat: bool = True,
19 | negative_slope: float = 0.2, dropout: float = 0.0,
20 | add_self_loops: bool = True, bias: bool = True, **kwargs):
21 | kwargs.setdefault('aggr', 'add')
22 | super(GATConv, self).__init__(node_dim=0, **kwargs)
23 |
24 | self.in_channels = in_channels
25 | self.out_channels = out_channels
26 | self.heads = heads
27 | self.concat = concat
28 | self.negative_slope = negative_slope
29 | self.dropout = dropout
30 | self.add_self_loops = add_self_loops
31 |
32 | if isinstance(in_channels, int):
33 | self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
34 | self.lin_r = self.lin_l
35 | else:
36 | self.lin_l = Linear(in_channels[0], heads * out_channels, False)
37 | self.lin_r = Linear(in_channels[1], heads * out_channels, False)
38 |
39 | self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
40 | self.att_r = Parameter(torch.Tensor(1, heads, out_channels))
41 |
42 | if bias and concat:
43 | self.bias = Parameter(torch.Tensor(heads * out_channels))
44 | elif bias and not concat:
45 | self.bias = Parameter(torch.Tensor(out_channels))
46 | else:
47 | self.register_parameter('bias', None)
48 |
49 | self._alpha = None
50 |
51 | self.reset_parameters()
52 |
53 | def reset_parameters(self):
54 | glorot(self.lin_l.weight)
55 | glorot(self.lin_r.weight)
56 | glorot(self.att_l)
57 | glorot(self.att_r)
58 | zeros(self.bias)
59 |
60 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
61 | size: Size = None, return_attention_weights=None, edge_weight=None):
62 | # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa
63 | # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa
64 | # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
65 | # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa
66 | r"""
67 | Args:
68 | return_attention_weights (bool, optional): If set to :obj:`True`,
69 | will additionally return the tuple
70 | :obj:`(edge_index, attention_weights)`, holding the computed
71 | attention weights for each edge. (default: :obj:`None`)
72 | """
73 | H, C = self.heads, self.out_channels # 4, 256
74 |
75 | x_l: OptTensor = None
76 | x_r: OptTensor = None
77 | alpha_l: OptTensor = None
78 | alpha_r: OptTensor = None
79 | if isinstance(x, Tensor):
80 | assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
81 | x_l = x_r = self.lin_l(x).view(-1, H, C)
82 | alpha_l = (x_l * self.att_l).sum(dim=-1)
83 | alpha_r = (x_r * self.att_r).sum(dim=-1)
84 | else:
85 | x_l, x_r = x[0], x[1]
86 | assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
87 | x_l = self.lin_l(x_l).view(-1, H, C)
88 | alpha_l = (x_l * self.att_l).sum(dim=-1)
89 | if x_r is not None:
90 | x_r = self.lin_r(x_r).view(-1, H, C)
91 | alpha_r = (x_r * self.att_r).sum(dim=-1)
92 |
93 | assert x_l is not None
94 | assert alpha_l is not None
95 |
96 | if self.add_self_loops:
97 | if isinstance(edge_index, Tensor):
98 | num_nodes = x_l.size(0)
99 | if x_r is not None:
100 | num_nodes = min(num_nodes, x_r.size(0))
101 | if size is not None:
102 | num_nodes = min(size[0], size[1])
103 | edge_index, _ = remove_self_loops(edge_index)
104 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
105 | if edge_weight is not None:
106 | loop_weight = torch.full((num_nodes, ), 1,
107 | dtype=edge_weight.dtype,
108 | device=edge_weight.device)
109 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0)
110 | elif isinstance(edge_index, SparseTensor):
111 | edge_index = set_diag(edge_index)
112 |
113 | # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
114 | out = self.propagate(edge_index,
115 | x=(x_l, x_r),
116 | alpha=(alpha_l, alpha_r),
117 | size=size,
118 | edge_weight=edge_weight)
119 |
120 | alpha = self._alpha
121 | self._alpha = None
122 |
123 | if self.concat:
124 | out = out.view(-1, self.heads * self.out_channels)
125 | else:
126 | out = out.mean(dim=1)
127 |
128 | if self.bias is not None:
129 | out += self.bias
130 |
131 | if isinstance(return_attention_weights, bool):
132 | assert alpha is not None
133 | if isinstance(edge_index, Tensor):
134 | return out, (edge_index, alpha)
135 | elif isinstance(edge_index, SparseTensor):
136 | return out, edge_index.set_value(alpha, layout='coo')
137 | else:
138 | return out
139 |
140 | def message(self,
141 | x_j: Tensor,
142 | alpha_j: Tensor,
143 | alpha_i: OptTensor,
144 | index: Tensor,
145 | ptr: OptTensor,
146 | size_i: Optional[int],
147 | edge_weight: Tensor) -> Tensor:
148 | alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
149 | alpha = F.leaky_relu(alpha, self.negative_slope)
150 | alpha = softmax(alpha, index, ptr, size_i)
151 | self._alpha = alpha
152 | alpha = F.dropout(alpha, p=self.dropout, training=self.training)
153 | if edge_weight is None:
154 | return x_j * alpha.unsqueeze(-1)
155 | else:
156 | return x_j * alpha.unsqueeze(-1) * edge_weight.expand(alpha.shape[1], alpha.shape[0]).t().unsqueeze(-1)
157 |
158 | def __repr__(self):
159 | return '{}({}, {}, heads={})'.format(self.__class__.__name__,
160 | self.in_channels,
161 | self.out_channels,
162 | self.heads)
163 |
--------------------------------------------------------------------------------
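A brief illustrative call (toy graph, assumed shapes) showing the two extensions this `GATConv` adds over the stock PyG layer: an optional per-edge `edge_weight` that rescales messages after the attention softmax, and `return_attention_weights=True` to recover the learned coefficients.

```python
import torch
from models.gat_conv import GATConv

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])     # hypothetical 3-node graph
x = torch.randn(3, 16)
edge_weight = torch.rand(edge_index.size(1))  # one weight per original edge

conv = GATConv(16, 8, heads=4, add_self_loops=True)
out, (ei, alpha) = conv(x, edge_index, edge_weight=edge_weight,
                        return_attention_weights=True)
print(out.shape)    # [3, 32] with concat=True and 4 heads
print(alpha.shape)  # [num_edges + num_self_loops, 4]
```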
/sparselearning/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | class SparseSpeedupBench(object):
9 | """Class to benchmark speedups for convolutional layers.
10 |
11 | Basic usage:
12 | 1. Assign a single SparseSpeedupBench instance to the class (and to sub-classes with conv layers).
13 | 2. Instead of forwarding inputs through the convolutional layers directly, pass them through the bench:
14 | self.bench = SparseSpeedupBench()
15 | self.conv_layer1 = nn.Conv2d(3, 96, 3)
16 |
17 | if self.bench is not None:
18 | outputs = self.bench.forward(self.conv_layer1, inputs, layer_id='conv_layer1')
19 | else:
20 | outputs = self.conv_layer1(inputs)
21 | 3. Speedups of the convolutional layers will be aggregated and printed every 1000 mini-batches.
22 | """
23 | def __init__(self):
24 | self.layer_timings = {}
25 | self.layer_timings_channel_sparse = {}
26 | self.layer_timings_sparse = {}
27 | self.iter_idx = 0
28 | self.layer_0_idx = None
29 | self.total_timings = []
30 | self.total_timings_channel_sparse = []
31 | self.total_timings_sparse = []
32 |
33 | def get_density(self, x):
34 | return (x.data!=0.0).sum().item()/x.numel()
35 |
36 | def print_weights(self, w, layer):
37 | # w dims: out, in, k1, k2
38 | #outers = []
39 | #for outer in range(w.shape[0]):
40 | # inners = []
41 | # for inner in range(w.shape[1]):
42 | # n = np.prod(w.shape[2:])
43 | # density = (w[outer, inner, :, :] != 0.0).sum().item() / n
44 | # #print(density, w[outer, inner])
45 | # inners.append(density)
46 | # outers.append([np.mean(inners), np.std(inner)])
47 | #print(outers)
48 | #print(w.shape, (w!=0.0).sum().item()/w.numel())
49 | pass
50 |
51 | def forward(self, layer, x, layer_id):
52 | if self.layer_0_idx is None: self.layer_0_idx = layer_id
53 | if layer_id == self.layer_0_idx: self.iter_idx += 1
54 | self.print_weights(layer.weight.data, layer)
55 |
56 | # calc input sparsity
57 | sparse_channels_in = ((x.data != 0.0).sum([2, 3]) == 0.0).sum().item()
58 | num_channels_in = x.shape[1]
59 | batch_size = x.shape[0]
60 | channel_sparsity_input = sparse_channels_in/float(num_channels_in*batch_size)
61 | input_sparsity = self.get_density(x)
62 |
63 | # bench dense layer
64 | start = torch.cuda.Event(enable_timing=True)
65 | end = torch.cuda.Event(enable_timing=True)
66 | start.record()
67 | x = layer(x)
68 | end.record()
69 | start.synchronize()
70 | end.synchronize()
71 | time_taken_s = start.elapsed_time(end)/1000.0
72 |
73 | # calc weight sparsity
74 | num_channels = layer.weight.shape[1]
75 | sparse_channels = ((layer.weight.data != 0.0).sum([0, 2, 3]) == 0.0).sum().item()
76 | channel_sparsity_weight = sparse_channels/float(num_channels)
77 | weight_sparsity = self.get_density(layer.weight)
78 |
79 | # store sparse and dense timings
80 | if layer_id not in self.layer_timings:
81 | self.layer_timings[layer_id] = []
82 | self.layer_timings_channel_sparse[layer_id] = []
83 | self.layer_timings_sparse[layer_id] = []
84 | self.layer_timings[layer_id].append(time_taken_s)
85 | self.layer_timings_channel_sparse[layer_id].append(time_taken_s*(1.0-channel_sparsity_weight)*(1.0-channel_sparsity_input))
86 | self.layer_timings_sparse[layer_id].append(time_taken_s*input_sparsity*weight_sparsity)
87 |
88 | if self.iter_idx % 1000 == 0:
89 | self.print_layer_timings()
90 | self.iter_idx += 1
91 |
92 | return x
93 |
94 | def print_layer_timings(self):
95 | total_time_dense = 0.0
96 | total_time_sparse = 0.0
97 | total_time_channel_sparse = 0.0
98 | print('\n')
99 | for layer_id in self.layer_timings:
100 | t_dense = np.mean(self.layer_timings[layer_id])
101 | t_channel_sparse = np.mean(self.layer_timings_channel_sparse[layer_id])
102 | t_sparse = np.mean(self.layer_timings_sparse[layer_id])
103 | total_time_dense += t_dense
104 | total_time_sparse += t_sparse
105 | total_time_channel_sparse += t_channel_sparse
106 |
107 | print('Layer {0}: Dense {1:.6f} Channel Sparse {2:.6f} vs Full Sparse {3:.6f}'.format(layer_id, t_dense, t_channel_sparse, t_sparse))
108 | self.total_timings.append(total_time_dense)
109 | self.total_timings_sparse.append(total_time_sparse)
110 | self.total_timings_channel_sparse.append(total_time_channel_sparse)
111 |
112 | print('Speedups for this segment:')
113 | print('Dense took {0:.4f}s. Channel Sparse took {1:.4f}s. Speedup of {2:.4f}x'.format(total_time_dense, total_time_channel_sparse, total_time_dense/total_time_channel_sparse))
114 | print('Dense took {0:.4f}s. Sparse took {1:.4f}s. Speedup of {2:.4f}x'.format(total_time_dense, total_time_sparse, total_time_dense/total_time_sparse))
115 | print('\n')
116 |
117 | total_dense = np.sum(self.total_timings)
118 | total_sparse = np.sum(self.total_timings_sparse)
119 | total_channel_sparse = np.sum(self.total_timings_channel_sparse)
120 | print('Speedups for entire training:')
121 | print('Dense took {0:.4f}s. Channel Sparse took {1:.4f}s. Speedup of {2:.4f}x'.format(total_dense, total_channel_sparse, total_dense/total_channel_sparse))
122 | print('Dense took {0:.4f}s. Sparse took {1:.4f}s. Speedup of {2:.4f}x'.format(total_dense, total_sparse, total_dense/total_sparse))
123 | print('\n')
124 |
125 | # clear timings
126 | for layer_id in list(self.layer_timings.keys()):
127 | self.layer_timings.pop(layer_id)
128 | self.layer_timings_channel_sparse.pop(layer_id)
129 | self.layer_timings_sparse.pop(layer_id)
130 |
131 |
132 |
133 | class AlexNet(nn.Module):
134 | """AlexNet with batch normalization and without pooling.
135 |
136 | This is an adapted version of AlexNet as taken from
137 | SNIP: Single-shot Network Pruning based on Connection Sensitivity,
138 | https://arxiv.org/abs/1810.02340
139 |
140 | There are two different versions of AlexNet:
141 | AlexNet-s (small): Has hidden layers with size 1024
142 | AlexNet-b (big): Has hidden layers with size 2048
143 |
144 | Based on https://github.com/mi-lad/snip/blob/master/train.py
145 | by Milad Alizadeh.
146 | """
147 |
148 | def __init__(self, config='s', num_classes=1000, save_features=False, bench_model=False):
149 | super(AlexNet, self).__init__()
150 | self.save_features = save_features
151 | self.feats = []
152 | self.densities = []
153 | self.bench = None if not bench_model else SparseSpeedupBench()
154 |
155 | factor = 1 if config=='s' else 2
156 | self.features = nn.Sequential(
157 | nn.Conv2d(3, 96, kernel_size=11, stride=2, padding=2, bias=True),
158 | nn.BatchNorm2d(96),
159 | nn.ReLU(inplace=True),
160 | nn.Conv2d(96, 256, kernel_size=5, stride=2, padding=2, bias=True),
161 | nn.BatchNorm2d(256),
162 | nn.ReLU(inplace=True),
163 | nn.Conv2d(256, 384, kernel_size=3, stride=2, padding=1, bias=True),
164 | nn.BatchNorm2d(384),
165 | nn.ReLU(inplace=True),
166 | nn.Conv2d(384, 384, kernel_size=3, stride=2, padding=1, bias=True),
167 | nn.BatchNorm2d(384),
168 | nn.ReLU(inplace=True),
169 | nn.Conv2d(384, 256, kernel_size=3, stride=2, padding=1, bias=True),
170 | nn.BatchNorm2d(256),
171 | nn.ReLU(inplace=True),
172 | )
173 | self.classifier = nn.Sequential(
174 | nn.Linear(256, 1024*factor),
175 | nn.BatchNorm1d(1024*factor),
176 | nn.ReLU(inplace=True),
177 | nn.Linear(1024*factor, 1024*factor),
178 | nn.BatchNorm1d(1024*factor),
179 | nn.ReLU(inplace=True),
180 | nn.Linear(1024*factor, num_classes),
181 | )
182 |
183 | def forward(self, x):
184 | for layer_id, layer in enumerate(self.features):
185 | if self.bench is not None and isinstance(layer, nn.Conv2d):
186 | x = self.bench.forward(layer, x, layer_id)
187 | else:
188 | x = layer(x)
189 |
190 | if self.save_features:
191 | if isinstance(layer, nn.ReLU):
192 | self.feats.append(x.clone().detach())
193 | if isinstance(layer, nn.Conv2d):
194 | self.densities.append((layer.weight.data != 0.0).sum().item()/layer.weight.numel())
195 |
196 | x = x.view(x.size(0), -1)
197 | x = self.classifier(x)
198 | return F.log_softmax(x, dim=1)
199 |
200 | class LeNet_300_100(nn.Module):
201 | """Simple NN with hidden layers [300, 100]
202 |
203 | Based on https://github.com/mi-lad/snip/blob/master/train.py
204 | by Milad Alizadeh.
205 | """
206 | def __init__(self, save_features=None, bench_model=False):
207 | super(LeNet_300_100, self).__init__()
208 | self.fc1 = nn.Linear(28*28, 300, bias=True)
209 | self.fc2 = nn.Linear(300, 100, bias=True)
210 | self.fc3 = nn.Linear(100, 10, bias=True)
211 | self.mask = None
212 |
213 | def forward(self, x):
214 | x0 = x.view(-1, 28*28)
215 | x1 = F.relu(self.fc1(x0))
216 | x2 = F.relu(self.fc2(x1))
217 | x3 = self.fc3(x2)
218 | return F.log_softmax(x3, dim=1)
219 |
220 | class MLP_CIFAR10(nn.Module):
221 | def __init__(self, save_features=None, bench_model=False):
222 | super(MLP_CIFAR10, self).__init__()
223 |
224 | self.fc1 = nn.Linear(3*32*32, 1024)
225 | self.fc2 = nn.Linear(1024, 512)
226 | self.fc3 = nn.Linear(512, 10)
227 |
228 | def forward(self, x):
229 | x0 = F.relu(self.fc1(x.view(-1, 3*32*32)))
230 | x1 = F.relu(self.fc2(x0))
231 | return F.log_softmax(self.fc3(x1), dim=1)
232 |
233 |
234 | class LeNet_5_Caffe(nn.Module):
235 | """LeNet-5 without padding in the first layer.
236 | This is based on Caffe's implementation of LeNet-5 and is slightly different
237 | from the vanilla LeNet-5. Note that the first layer does NOT have padding
238 | and therefore intermediate shapes do not match the official LeNet-5.
239 |
240 | Based on https://github.com/mi-lad/snip/blob/master/train.py
241 | by Milad Alizadeh.
242 | """
243 |
244 | def __init__(self, save_features=None, bench_model=False):
245 | super().__init__()
246 | self.conv1 = nn.Conv2d(1, 20, 5, padding=0, bias=True)
247 | self.conv2 = nn.Conv2d(20, 50, 5, bias=True)
248 | self.fc3 = nn.Linear(50 * 4 * 4, 500)
249 | self.fc4 = nn.Linear(500, 10)
250 |
251 | def forward(self, x):
252 | x = F.relu(self.conv1(x))
253 | x = F.max_pool2d(x, 2)
254 | x = F.relu(self.conv2(x))
255 | x = F.max_pool2d(x, 2)
256 | x = F.relu(self.fc3(x.view(-1, 50 * 4 * 4)))
257 | x = F.log_softmax(self.fc4(x), dim=1)
258 |
259 | return x
260 |
261 |
262 | VGG_CONFIGS = {
263 | # M for MaxPool, Number for channels
264 | 'like': [
265 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
266 | 512, 512, 512, 'M'
267 | ],
268 | 'D': [
269 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
270 | 512, 512, 512, 'M'
271 | ],
272 | 'C': [
273 | 64, 64, 'M', 128, 128, 'M', 256, 256, (1, 256), 'M', 512, 512, (1, 512), 'M',
274 | 512, 512, (1, 512), 'M' # tuples indicate (kernel size, output channels)
275 | ]
276 | }
277 |
278 |
279 | class VGG16(nn.Module):
280 | """
281 | This is a base class to generate three VGG variants used in SNIP paper:
282 | 1. VGG-C (16 layers)
283 | 2. VGG-D (16 layers)
284 | 3. VGG-like
285 |
286 | Some of the differences:
287 | * Reduced size of FC layers to 512
288 | * Adjusted flattening to match CIFAR-10 shapes
289 | * Replaced dropout layers with BatchNorm
290 |
291 | Based on https://github.com/mi-lad/snip/blob/master/train.py
292 | by Milad Alizadeh.
293 | """
294 |
295 | def __init__(self, config, num_classes=10, save_features=False, bench_model=False):
296 | super().__init__()
297 |
298 | self.features = self.make_layers(VGG_CONFIGS[config], batch_norm=True)
299 | self.feats = []
300 | self.densities = []
301 | self.save_features = save_features
302 | self.bench = None if not bench_model else SparseSpeedupBench()
303 |
304 | if config == 'C' or config == 'D':
305 | self.classifier = nn.Sequential(
306 | nn.Linear((512 if config == 'D' else 2048), 512), # 512 * 7 * 7 in the original VGG
307 | nn.ReLU(True),
308 | nn.BatchNorm1d(512), # instead of dropout
309 | nn.Linear(512, 512),
310 | nn.ReLU(True),
311 | nn.BatchNorm1d(512), # instead of dropout
312 | nn.Linear(512, num_classes),
313 | )
314 | else:
315 | self.classifier = nn.Sequential(
316 | nn.Linear(512, 512), # 512 * 7 * 7 in the original VGG
317 | nn.ReLU(True),
318 | nn.BatchNorm1d(512), # instead of dropout
319 | nn.Linear(512, num_classes),
320 | )
321 |
322 | @staticmethod
323 | def make_layers(config, batch_norm=False):
324 | layers = []
325 | in_channels = 3
326 | for v in config:
327 | if v == 'M':
328 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
329 | else:
330 | kernel_size = 3
331 | if isinstance(v, tuple):
332 | kernel_size, v = v
333 | conv2d = nn.Conv2d(in_channels, v, kernel_size=kernel_size, padding=1)
334 | if batch_norm:
335 | layers += [
336 | conv2d,
337 | nn.BatchNorm2d(v),
338 | nn.ReLU(inplace=True)
339 | ]
340 | else:
341 | layers += [conv2d, nn.ReLU(inplace=True)]
342 | in_channels = v
343 | return nn.Sequential(*layers)
344 |
345 | def forward(self, x):
346 | for layer_id, layer in enumerate(self.features):
347 | if self.bench is not None and isinstance(layer, nn.Conv2d):
348 | x = self.bench.forward(layer, x, layer_id)
349 | else:
350 | x = layer(x)
351 |
352 | if self.save_features:
353 | if isinstance(layer, nn.ReLU):
354 | self.feats.append(x.clone().detach())
355 | self.densities.append((x.data != 0.0).sum().item()/x.numel())
356 |
357 | x = x.view(x.size(0), -1)
358 | x = self.classifier(x)
359 | x = F.log_softmax(x, dim=1)
360 | return x
361 |
362 | class VGG16_Srelu(nn.Module):
363 | """
364 | This is a base class to generate three VGG variants used in SNIP paper:
365 | 1. VGG-C (16 layers)
366 | 2. VGG-D (16 layers)
367 | 3. VGG-like
368 |
369 | Some of the differences:
370 | * Reduced size of FC layers to 512
371 | * Adjusted flattening to match CIFAR-10 shapes
372 | * Replaced dropout layers with BatchNorm
373 |
374 | Based on https://github.com/mi-lad/snip/blob/master/train.py
375 | by Milad Alizadeh.
376 | """
377 |
378 | def __init__(self, config, num_classes=10, save_features=False, bench_model=False):
379 | super().__init__()
380 |
381 | self.features = self.make_layers(VGG_CONFIGS[config], batch_norm=True)
382 | self.feats = []
383 | self.densities = []
384 | self.save_features = save_features
385 | self.bench = None if not bench_model else SparseSpeedupBench()
386 |
387 | if config == 'C' or config == 'D':
388 | self.classifier = nn.Sequential(
389 | nn.Linear((512 if config == 'D' else 2048), 512), # 512 * 7 * 7 in the original VGG
390 | nn.ReLU(True),
391 | nn.BatchNorm1d(512), # instead of dropout
392 | nn.Linear(512, 512),
393 | nn.ReLU(True),
394 | nn.BatchNorm1d(512), # instead of dropout
395 | nn.Linear(512, num_classes),
396 | )
397 | else:
398 | self.classifier = nn.Sequential(
399 | nn.Linear(512, 512), # 512 * 7 * 7 in the original VGG
400 | nn.ReLU(True),
401 | nn.BatchNorm1d(512), # instead of dropout
402 | nn.Linear(512, num_classes),
403 | )
404 |
405 | @staticmethod
406 | def make_layers(config, batch_norm=False):
407 | layers = []
408 | in_channels = 3
409 | for v in config:
410 | if v == 'M':
411 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
412 | else:
413 | kernel_size = 3
414 | if isinstance(v, tuple):
415 | kernel_size, v = v
416 | conv2d = nn.Conv2d(in_channels, v, kernel_size=kernel_size, padding=1)
417 | if batch_norm:
418 | layers += [
419 | conv2d,
420 | nn.BatchNorm2d(v),
421 | nn.ReLU(inplace=True)
422 | ]
423 | else:
424 | layers += [conv2d, nn.ReLU(inplace=True)]
425 | in_channels = v
426 | return nn.Sequential(*layers)
427 |
428 | def forward(self, x):
429 | for layer_id, layer in enumerate(self.features):
430 | if self.bench is not None and isinstance(layer, nn.Conv2d):
431 | x = self.bench.forward(layer, x, layer_id)
432 | else:
433 | x = layer(x)
434 |
435 | if self.save_features:
436 | if isinstance(layer, nn.ReLU):
437 | self.feats.append(x.clone().detach())
438 | self.densities.append((x.data != 0.0).sum().item()/x.numel())
439 |
440 | x = x.view(x.size(0), -1)
441 | x = self.classifier(x)
442 | x = F.log_softmax(x, dim=1)
443 | return x
444 |
445 | class WideResNet(nn.Module):
446 | """Wide Residual Network with varying depth and width.
447 |
448 | For more info, see the paper: Wide Residual Networks by Sergey Zagoruyko, Nikos Komodakis
449 | https://arxiv.org/abs/1605.07146
450 | """
451 | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.3, save_features=False, bench_model=False):
452 | super(WideResNet, self).__init__()
453 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
454 | assert((depth - 4) % 6 == 0)
455 | n = (depth - 4) / 6
456 | block = BasicBlock
457 | # 1st conv before any network block
458 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
459 | padding=1, bias=False)
460 | self.bench = None if not bench_model else SparseSpeedupBench()
461 | # 1st block
462 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, save_features=save_features, bench=self.bench)
463 | # 2nd block
464 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate, save_features=save_features, bench=self.bench)
465 | # 3rd block
466 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate, save_features=save_features, bench=self.bench)
467 | # global average pooling and classifier
468 | self.bn1 = nn.BatchNorm2d(nChannels[3])
469 | self.relu = nn.ReLU(inplace=True)
470 | self.fc = nn.Linear(nChannels[3], num_classes)
471 | self.nChannels = nChannels[3]
472 | self.feats = []
473 | self.densities = []
474 | self.save_features = save_features
475 |
476 | for m in self.modules():
477 | if isinstance(m, nn.Conv2d):
478 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
479 | m.weight.data.normal_(0, math.sqrt(2. / n))
480 | elif isinstance(m, nn.BatchNorm2d):
481 | m.weight.data.fill_(1)
482 | m.bias.data.zero_()
483 | elif isinstance(m, nn.Linear):
484 | m.bias.data.zero_()
485 |
486 | def forward(self, x):
487 | if self.bench is not None:
488 | out = self.bench.forward(self.conv1, x, 'conv1')
489 | else:
490 | out = self.conv1(x)
491 |
492 | out = self.block1(out)
493 | out = self.block2(out)
494 | out = self.block3(out)
495 |
496 | if self.save_features:
497 | # this is a mess, but I do not have time to refactor it now
498 | self.feats += self.block1.feats
499 | self.densities += self.block1.densities
500 | del self.block1.feats[:]
501 | del self.block1.densities[:]
502 | self.feats += self.block2.feats
503 | self.densities += self.block2.densities
504 | del self.block2.feats[:]
505 | del self.block2.densities[:]
506 | self.feats += self.block3.feats
507 | self.densities += self.block3.densities
508 | del self.block3.feats[:]
509 | del self.block3.densities[:]
510 |
511 | out = self.relu(self.bn1(out))
512 | out = F.avg_pool2d(out, 8)
513 | out = out.view(-1, self.nChannels)
514 | out = self.fc(out)
515 | return F.log_softmax(out, dim=1)
516 |
517 |
518 | class BasicBlock(nn.Module):
519 | """Wide Residual Network basic block
520 |
521 | For more info, see the paper: Wide Residual Networks by Sergey Zagoruyko, Nikos Komodakis
522 | https://arxiv.org/abs/1605.07146
523 | """
524 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, save_features=False, bench=None):
525 | super(BasicBlock, self).__init__()
526 | self.bn1 = nn.BatchNorm2d(in_planes)
527 | self.relu1 = nn.ReLU(inplace=True)
528 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
529 | padding=1, bias=False)
530 | self.bn2 = nn.BatchNorm2d(out_planes)
531 | self.relu2 = nn.ReLU(inplace=True)
532 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
533 | padding=1, bias=False)
534 | self.droprate = dropRate
535 | self.equalInOut = (in_planes == out_planes)
536 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
537 | padding=0, bias=False) or None
538 | self.feats = []
539 | self.densities = []
540 | self.save_features = save_features
541 | self.bench = bench
542 | self.in_planes = in_planes
543 |
544 | def forward(self, x):
545 | conv_layers = []
546 | if not self.equalInOut:
547 | x = self.relu1(self.bn1(x))
548 | if self.save_features:
549 | self.feats.append(x.clone().detach())
550 | self.densities.append((x.data != 0.0).sum().item()/x.numel())
551 | else:
552 | out = self.relu1(self.bn1(x))
553 | if self.save_features:
554 | self.feats.append(out.clone().detach())
555 | self.densities.append((out.data != 0.0).sum().item()/out.numel())
556 | if self.bench:
557 | out0 = self.bench.forward(self.conv1, (out if self.equalInOut else x), str(self.in_planes) + '.conv1')
558 | else:
559 | out0 = self.conv1(out if self.equalInOut else x)
560 |
561 | out = self.relu2(self.bn2(out0))
562 | if self.save_features:
563 | self.feats.append(out.clone().detach())
564 | self.densities.append((out.data != 0.0).sum().item()/out.numel())
565 | if self.droprate > 0:
566 | out = F.dropout(out, p=self.droprate, training=self.training)
567 | if self.bench:
568 | out = self.bench.forward(self.conv2, out, str(self.in_planes) + '.conv2')
569 | else:
570 | out = self.conv2(out)
571 |
572 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
573 |
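# Note: when save_features is enabled, each BasicBlock records its post-ReLU activations
# and their density (fraction of non-zero entries) in self.feats / self.densities; the
# enclosing NetworkBlock and WideResNet drain these buffers after every forward pass.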
574 | class NetworkBlock(nn.Module):
575 | """Wide Residual Network network block which holds basic blocks.
576 |
577 | For more info, see the paper: Wide Residual Networks by Sergey Zagoruyko, Nikos Komodakis
578 | https://arxiv.org/abs/1605.07146
579 | """
580 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, save_features=False, bench=None):
581 | super(NetworkBlock, self).__init__()
582 | self.feats = []
583 | self.densities = []
584 | self.save_features = save_features
585 | self.bench = bench
586 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
587 |
588 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
589 | layers = []
590 | for i in range(int(nb_layers)):
591 |             layers.append(block(in_planes if i == 0 else out_planes, out_planes, stride if i == 0 else 1, dropRate, save_features=self.save_features, bench=self.bench))
592 | return nn.Sequential(*layers)
593 |
594 | def forward(self, x):
595 | for layer in self.layer:
596 | x = layer(x)
597 | if self.save_features:
598 | self.feats += layer.feats
599 | self.densities += layer.densities
600 | del layer.feats[:]
601 | del layer.densities[:]
602 | return x
603 |
604 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Linear
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import ChebConv, GCNConv, SGConv, APPNP, GCN2Conv, JumpingKnowledge, MessagePassing # noqa
6 | from torch_geometric.nn.conv.gcn_conv import gcn_norm
7 | from models.gat_conv import GATConv
8 | from models.fagcn_conv import FAConv
9 |
10 | from torch_geometric.utils import to_scipy_sparse_matrix
11 | import torch_sparse
12 | from torch_sparse import SparseTensor, matmul
13 | import scipy.sparse
14 | import numpy as np
15 |
16 |
17 |
18 | class GCNNet(torch.nn.Module):
19 | def __init__(self, dataset, args):
20 | super(GCNNet, self).__init__()
21 |
22 | self.args = args
23 | self.conv1 = GCNConv(dataset.num_features, args.dim, cached=False, add_self_loops = True, normalize = True)
24 | self.conv2 = GCNConv(args.dim, dataset.num_classes, cached=False, add_self_loops = True, normalize = True)
25 |
26 | if args.adj_sparse:
27 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32)))
28 |
29 | if args.feature_sparse:
30 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32))
31 |
32 | def forward(self, data, data_mask=None):
33 |
34 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
35 |
36 | if self.args.adj_sparse:
37 | edge_mask = torch.abs(self.edge_weight_train) > 0
38 | row, col = data.edge_index
39 | row, col= row[edge_mask], col[edge_mask]
40 | edge_index = torch.stack([row, col], dim=0)
41 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask]
42 |
43 | if self.args.feature_sparse:
44 | x_mask = (torch.abs(self.x_weight) > 0).float()
45 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask)
46 | x = torch.mul(x,x_weight)
47 |
48 |
49 | x = F.relu(self.conv1(x, edge_index, edge_weight=edge_weight))
50 | x = F.dropout(x, training=self.training)
51 | x = self.conv2(x, edge_index, edge_weight=edge_weight)
52 | return F.log_softmax(x, dim=1)
53 |
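# Illustrative sketch (not part of the original file): with args.adj_sparse enabled, the
# learnable edge_weight_train acts as a differentiable edge mask. Entries pruned to exactly
# zero (e.g. by the Masking object in sparselearning/core.py) remove their edges from message
# passing, while the surviving entries re-weight edges via a sigmoid:
#
#   model = GCNNet(data, args)                       # assumes args.adj_sparse = True
#   with torch.no_grad():
#       model.edge_weight_train[:100] = 0.0          # hypothetical: prune the first 100 edges
#   out = model(data)                                # those edges are dropped in both conv layers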
54 |
55 | class GATNet(torch.nn.Module):
56 | def __init__(self, dataset, args):
57 | super(GATNet, self).__init__()
58 |
59 | self.args = args
60 | self.conv1 = GATConv(dataset.num_features, 8, heads=8)
61 | self.conv2 = GATConv(8 * 8, dataset.num_classes, heads=1, concat=False)
62 |
63 | if args.adj_sparse:
64 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32)))
65 |
66 | if args.feature_sparse:
67 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32))
68 |
69 |
70 | def forward(self, data, data_mask=None):
71 |
72 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
73 |
74 | if self.args.adj_sparse:
75 | edge_mask = torch.abs(self.edge_weight_train) > 0
76 | row, col = data.edge_index
77 | row, col= row[edge_mask], col[edge_mask]
78 | edge_index = torch.stack([row, col], dim=0)
79 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask]
80 |
81 | if self.args.feature_sparse:
82 | x_mask = (torch.abs(self.x_weight) > 0).float()
83 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask)
84 | x = torch.mul(x,x_weight)
85 |
86 | x = F.dropout(x, p=0.5, training=self.training)
87 | x = F.elu(self.conv1(x, edge_index))
88 | x = F.dropout(x, p=0.5, training=self.training)
89 | x = self.conv2(x, edge_index, edge_weight=edge_weight)
90 | return x.log_softmax(dim=-1)
91 |
92 |
93 | class SGCNet(torch.nn.Module):
94 | def __init__(self, dataset, args):
95 | super().__init__()
96 |
97 | self.args = args
98 | self.conv1 = SGConv(dataset.num_features, dataset.num_classes, K=2,
99 | cached=True)
100 | if args.adj_sparse:
101 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32)))
102 |
103 | if args.feature_sparse:
104 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32))
105 |
106 | def forward(self, data):
107 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
108 |
109 | if self.args.adj_sparse:
110 | edge_mask = torch.abs(self.edge_weight_train) > 0
111 | row, col = data.edge_index
112 | row, col= row[edge_mask], col[edge_mask]
113 | edge_index = torch.stack([row, col], dim=0)
114 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask]
115 |
116 | if self.args.feature_sparse:
117 | x_mask = (torch.abs(self.x_weight) > 0).float()
118 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask)
119 | x = torch.mul(x,x_weight)
120 |
121 | x = self.conv1(x, edge_index, edge_weight= edge_weight)
122 | return F.log_softmax(x, dim=1)
123 |
124 |
125 | class APPNPNet(torch.nn.Module):
126 | def __init__(self, dataset, args):
127 | super().__init__()
128 | self.args = args
129 | self.lin1 = Linear(dataset.num_features, args.dim)
130 | self.lin2 = Linear(args.dim, dataset.num_classes)
131 | self.prop1 = APPNP(10, 0.1)
132 |
133 | if args.adj_sparse:
134 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32)))
135 |
136 | if args.feature_sparse:
137 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32))
138 |
139 |
140 |
141 | def reset_parameters(self):
142 | self.lin1.reset_parameters()
143 | self.lin2.reset_parameters()
144 |
145 | def forward(self, data):
146 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
147 |
148 | if self.args.adj_sparse:
149 | edge_mask = torch.abs(self.edge_weight_train) > 0
150 | row, col = data.edge_index
151 | row, col= row[edge_mask], col[edge_mask]
152 | edge_index = torch.stack([row, col], dim=0)
153 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask]
154 |
155 | if self.args.feature_sparse:
156 | x_mask = (torch.abs(self.x_weight) > 0).float()
157 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask)
158 | x = torch.mul(x,x_weight)
159 |
160 | x = F.relu(self.lin1(x))
161 | x = F.dropout(x, training=self.training)
162 | x = self.lin2(x)
163 | x = self.prop1(x, edge_index, edge_weight=edge_weight)
164 | return F.log_softmax(x, dim=1)
165 |
166 | class GCNIINet(torch.nn.Module):
167 | def __init__(self, dataset, args,):
168 | super().__init__()
169 | # alpha =0.1
170 | # theta =0.5
171 | self.args = args
172 | self.lins = torch.nn.ModuleList()
173 | self.lins.append(Linear(dataset.num_features, args.dim))
174 | self.lins.append(Linear(args.dim, dataset.num_classes))
175 |
176 | self.convs = torch.nn.ModuleList()
177 | for layer in range(20):
178 | self.convs.append(
179 | GCN2Conv(args.dim, 0.1, 0.5, layer + 1,
180 | shared_weights =True, normalize=True))
181 |
182 | if args.adj_sparse:
183 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32)))
184 |
185 | if args.feature_sparse:
186 | self.x_weight = nn.Parameter(torch.ones_like(dataset.x)[0].to(torch.float32))
187 |
188 |
189 |
190 | def forward(self, data):
191 |
192 |
193 | x, edge_index = data.x, data.edge_index
194 |
195 | if self.args.feature_sparse:
196 | x_mask = (torch.abs(self.x_weight) > 0).float()
197 | x_weight = torch.mul(self.x_weight.sigmoid(), x_mask)
198 | x = torch.mul(x,x_weight)
199 |
200 | x = F.dropout(x, training=self.training)
201 | x = x_0 = self.lins[0](x).relu()
202 |
203 | if self.args.adj_sparse:
204 | edge_mask = torch.abs(self.edge_weight_train) > 0
205 | row, col = data.edge_index
206 | row, col= row[edge_mask], col[edge_mask]
207 | edge_index = torch.stack([row, col], dim=0)
208 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask]
209 |
210 |
211 |
212 |
213 |
214 | for conv in self.convs:
215 | x = F.dropout(x, training=self.training)
216 | x = conv(x, x_0, edge_index, edge_weight= edge_weight)
217 | x = x.relu()
218 |
219 | x = F.dropout(x, training=self.training)
220 | x = self.lins[1](x)
221 |
222 | return x.log_softmax(dim=-1)
223 |
224 |
225 | class MLP(nn.Module):
226 |     """
227 |     Simple multi-layer perceptron baseline: classifies nodes from their input features alone, ignoring the graph structure.
228 |     """
229 | def __init__(self, dataset, args):
230 | super(MLP,self).__init__()
231 |
232 | self.num_layers=2
233 | self.dropout_rate=0.5
234 |
235 | self.lins=nn.ModuleList()
236 | self.lins.append(nn.Linear(dataset.num_features, args.dim))
237 | for i in range(self.num_layers-2):
238 | self.lins.append(nn.Linear(args.dim, args.dim))
239 | self.lins.append(nn.Linear(args.dim,dataset.num_classes))
240 |
241 | self.bns=nn.ModuleList()
242 | for i in range(self.num_layers-1):
243 | self.bns.append(nn.BatchNorm1d(args.dim))
244 | self.bns.append(nn.BatchNorm1d(dataset.num_classes))
245 |
246 | def forward(self, data):
247 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
248 |
249 |         for i in range(self.num_layers-1):
250 |             x = self.lins[i](x)
251 |             x = F.dropout(F.relu(self.bns[i](x)), p=self.dropout_rate, training=self.training)
252 |         x = self.lins[self.num_layers-1](x)
253 |
254 | return F.log_softmax(x, dim=1)
255 |
256 |
257 | class LINK(nn.Module):
258 | """ logistic regression on adjacency matrix """
259 |
260 | def __init__(self, dataset, args):
261 | super(LINK, self).__init__()
262 |
263 | self.W = nn.Linear(dataset.x.size(0), dataset.num_classes)
264 |
265 | def reset_parameters(self):
266 | self.W.reset_parameters()
267 |
268 | def forward(self, data):
269 | N = data.x.size(0)
270 | edge_index = data.edge_index
271 | if isinstance(edge_index, torch.Tensor):
272 | row, col = edge_index
273 | A = SparseTensor(row=row, col=col, sparse_sizes=(N, N)).to_torch_sparse_coo_tensor()
274 | elif isinstance(edge_index, SparseTensor):
275 | A = edge_index.to_torch_sparse_coo_tensor()
276 | logits = self.W(A)
277 | return F.log_softmax(logits, dim=1)
278 |
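# Note: LINK treats each node's row of the N x N adjacency matrix as its feature vector,
# so self.W is a Linear layer with in_features = num_nodes; the forward pass builds the
# sparse adjacency from edge_index and applies the logistic-regression layer directly to it.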
279 |
280 | class FAGCN(nn.Module):
281 | def __init__(self, dataset, args):
282 | super(FAGCN, self).__init__()
283 | self.eps = args.fagcn_eps
284 | self.layer_num = args.fagcn_layer_num
285 | self.dropout = args.fagcn_dropout
286 |
287 | self.layers = nn.ModuleList()
288 | for _ in range(self.layer_num):
289 | self.layers.append(FAConv(args.dim, self.eps, self.dropout))
290 |
291 | self.t1 = nn.Linear(dataset.num_features, args.dim)
292 | self.t2 = nn.Linear(args.dim, dataset.num_classes)
293 | self.reset_parameters()
294 |
295 | def reset_parameters(self):
296 | nn.init.xavier_normal_(self.t1.weight, gain=1.414)
297 | nn.init.xavier_normal_(self.t2.weight, gain=1.414)
298 |
299 | def forward(self, data):
300 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
301 | h = F.dropout(x, p=self.dropout, training=self.training)
302 | h = torch.relu(self.t1(h))
303 | h = F.dropout(h, p=self.dropout, training=self.training)
304 | raw = h
305 | for i in range(self.layer_num):
306 | h = self.layers[i](h,raw,edge_index)
307 | h = self.t2(h)
308 |
309 | return F.log_softmax(h, dim=1)
310 |
311 |
312 | class GPR_prop(MessagePassing):
313 | '''
314 | GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN
315 | propagation class for GPR_GNN
316 | '''
317 |
318 | def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
319 | super(GPR_prop, self).__init__(aggr='add', **kwargs)
320 | self.K = K
321 | self.Init = Init
322 | self.alpha = alpha
323 |
324 | assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
325 | if Init == 'SGC':
326 | # SGC-like
327 | TEMP = 0.0*np.ones(K+1)
328 | TEMP[alpha] = 1.0
329 | elif Init == 'PPR':
330 | # PPR-like
331 | TEMP = alpha*(1-alpha)**np.arange(K+1)
332 | TEMP[-1] = (1-alpha)**K
333 | elif Init == 'NPPR':
334 | # Negative PPR
335 | TEMP = (alpha)**np.arange(K+1)
336 | TEMP = TEMP/np.sum(np.abs(TEMP))
337 | elif Init == 'Random':
338 | # Random
339 | bound = np.sqrt(3/(K+1))
340 | TEMP = np.random.uniform(-bound, bound, K+1)
341 | TEMP = TEMP/np.sum(np.abs(TEMP))
342 | elif Init == 'WS':
343 | # Specify Gamma
344 | TEMP = Gamma
345 |
346 | self.temp = nn.Parameter(torch.tensor(TEMP))
347 |
348 | def reset_parameters(self):
349 | nn.init.zeros_(self.temp)
350 | for k in range(self.K+1):
351 | self.temp.data[k] = self.alpha*(1-self.alpha)**k
352 | self.temp.data[-1] = (1-self.alpha)**self.K
353 |
354 | def forward(self, x, edge_index, edge_weight=None):
355 | if isinstance(edge_index, torch.Tensor):
356 | edge_index, norm = gcn_norm(
357 | edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
358 | elif isinstance(edge_index, SparseTensor):
359 | edge_index = gcn_norm(
360 | edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
361 | norm = None
362 |
363 | hidden = x*(self.temp[0])
364 | for k in range(self.K):
365 | x = self.propagate(edge_index, x=x, norm=norm)
366 | gamma = self.temp[k+1]
367 | hidden = hidden + gamma*x
368 | return hidden
369 |
370 | def message(self, x_j, norm):
371 | return norm.view(-1, 1) * x_j
372 |
373 | def __repr__(self):
374 | return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
375 | self.temp)
376 |
377 |
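# Worked example of the 'PPR' initialization above: with alpha = 0.1 and K = 3 the initial
# propagation weights are gamma_k = alpha * (1 - alpha)**k for k < K and (1 - alpha)**K for
# the last hop, i.e. temp = [0.1, 0.09, 0.081, 0.729]; forward() then returns the learnable
# generalized-PageRank combination sum_k gamma_k * (normalized adjacency)^k x.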
378 | class GPRGNN(nn.Module):
379 | """GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN"""
380 |
381 | def __init__(self, dataset, args):
382 | super(GPRGNN, self).__init__()
383 |
384 | Init='PPR'
385 | dprate=.5
386 | dropout=.5
387 | K=args.gprgnn_k
388 | alpha= args.gprgnn_alpha
389 | Gamma=None
390 | ppnp='GPR_prop'
391 | self.lin1 = nn.Linear(dataset.num_features, args.dim)
392 | self.lin2 = nn.Linear(args.dim, dataset.num_classes)
393 |
394 | if ppnp == 'PPNP':
395 | self.prop1 = APPNP(K, alpha)
396 | elif ppnp == 'GPR_prop':
397 | self.prop1 = GPR_prop(K, alpha, Init, Gamma)
398 |
399 | self.Init = Init
400 | self.dprate = dprate
401 | self.dropout = dropout
402 |
403 | def reset_parameters(self):
404 | self.lin1.reset_parameters()
405 | self.lin2.reset_parameters()
406 | self.prop1.reset_parameters()
407 |
408 | def forward(self, data):
409 | x, edge_index = data.x, data.edge_index
410 |
411 | x = F.dropout(x, p=self.dropout, training=self.training)
412 | x = F.relu(self.lin1(x))
413 | x = F.dropout(x, p=self.dropout, training=self.training)
414 | x = self.lin2(x)
415 |
416 | if self.dprate == 0.0:
417 | x = self.prop1(x, edge_index)
418 | return F.log_softmax(x, dim=1)
419 | else:
420 | x = F.dropout(x, p=self.dprate, training=self.training)
421 | x = self.prop1(x, edge_index)
422 | return F.log_softmax(x, dim=1)
423 |
424 | class MixHopLayer(nn.Module):
425 | """ Our MixHop layer """
426 | def __init__(self, in_channels, out_channels, hops=2):
427 | super(MixHopLayer, self).__init__()
428 | self.hops = hops
429 | self.lins = nn.ModuleList()
430 | for hop in range(self.hops+1):
431 | lin = nn.Linear(in_channels, out_channels)
432 | self.lins.append(lin)
433 |
434 | def reset_parameters(self):
435 | for lin in self.lins:
436 | lin.reset_parameters()
437 |
438 | def forward(self, x, adj_t):
439 | xs = [self.lins[0](x) ]
440 | for j in range(1,self.hops+1):
441 | # less runtime efficient but usually more memory efficient to mult weight matrix first
442 | x_j = self.lins[j](x)
443 | for hop in range(j):
444 | x_j = matmul(adj_t, x_j)
445 | xs += [x_j]
446 | return torch.cat(xs, dim=1)
447 |
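# Note: MixHopLayer concatenates one linear projection per adjacency power 0..hops, so its
# output dimension is out_channels * (hops + 1); this is why the BatchNorm1d layers and the
# follow-up MixHopLayers in MixHop below are sized with args.dim * (hops + 1).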
448 | class MixHop(nn.Module):
449 | """ our implementation of MixHop
450 | some assumptions: the powers of the adjacency are [0, 1, ..., hops],
451 | with every power in between
452 | each concatenated layer has the same dimension --- hidden_channels
453 | """
454 | def __init__(self, dataset, args):
455 | super(MixHop, self).__init__()
456 |
457 | num_layers= args.mixhop_layer_num
458 | dropout=args.mixhop_dropout
459 | hops=args.mixhop_hop
460 |
461 | self.convs = nn.ModuleList()
462 | self.convs.append(MixHopLayer(dataset.num_features, args.dim, hops=hops))
463 |
464 | self.bns = nn.ModuleList()
465 | self.bns.append(nn.BatchNorm1d(args.dim*(hops+1)))
466 | for _ in range(num_layers - 2):
467 | self.convs.append(
468 | MixHopLayer(args.dim*(hops+1), args.dim, hops=hops))
469 | self.bns.append(nn.BatchNorm1d(args.dim*(hops+1)))
470 |
471 | self.convs.append(
472 | MixHopLayer(args.dim*(hops+1), dataset.num_classes, hops=hops))
473 |
474 | # note: uses linear projection instead of paper's attention output
475 | self.final_project = nn.Linear(dataset.num_classes*(hops+1), dataset.num_classes)
476 |
477 | self.dropout = dropout
478 | self.activation = F.relu
479 |
480 | def reset_parameters(self):
481 | for conv in self.convs:
482 | conv.reset_parameters()
483 | for bn in self.bns:
484 | bn.reset_parameters()
485 | self.final_project.reset_parameters()
486 |
487 |
488 | def forward(self, data):
489 | x, edge_index = data.x, data.edge_index
490 | n = data.x.size(0)
491 | edge_weight = None
492 |
493 | if isinstance(edge_index, torch.Tensor):
494 | edge_index, edge_weight = gcn_norm(
495 | edge_index, edge_weight, n, False,
496 | dtype=x.dtype)
497 | row, col = edge_index
498 | adj_t = SparseTensor(row=col, col=row, value=edge_weight, sparse_sizes=(n, n))
499 | elif isinstance(edge_index, SparseTensor):
500 | edge_index = gcn_norm(
501 | edge_index, edge_weight, n, False,
502 | dtype=x.dtype)
503 | edge_weight=None
504 | adj_t = edge_index
505 |
506 | for i, conv in enumerate(self.convs[:-1]):
507 | x = conv(x, adj_t)
508 | x = self.bns[i](x)
509 | x = self.activation(x)
510 | x = F.dropout(x, p=self.dropout, training=self.training)
511 | x = self.convs[-1](x, adj_t)
512 |
513 | x = self.final_project(x)
514 | return x
515 |
516 |
517 | class HGCN(nn.Module):
518 | def __init__(self,dataset, args):
519 | super(HGCN, self).__init__()
520 | self.args = args
521 | self.lin1 = nn.Linear(dataset.num_features, args.dim)
522 | self.lin = nn.Linear(args.dim*5,dataset.num_classes)
523 |
524 | def forward(self,data):
525 |
526 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
527 |
528 | edge_index, edge_weight = gcn_norm(edge_index, edge_weight)
529 |
530 | adj_coo = to_scipy_sparse_matrix(edge_index)
531 | adj_row = adj_coo.row
532 | adj_col = adj_coo.col
533 | adj_value = edge_weight.detach().cpu().numpy()
534 | # print(adj_value)
535 | # raise exception("pause")
536 | adj_size = adj_coo.shape
537 | edge_index = torch_sparse.SparseTensor(sparse_sizes=[adj_size[0], adj_size[1]], row=torch.tensor(adj_row, dtype=torch.long),
538 | col=torch.tensor(adj_col, dtype=torch.long),
539 | value=torch.tensor(adj_value, dtype=torch.float32)).to(self.args.device)
540 | temp = self.lin1(x)
541 | temp = F.relu(temp)
542 | temp1 = torch_sparse.matmul(edge_index,temp)
543 | temp1 =torch.cat((temp,temp1),dim=1)
544 | temp2 = torch_sparse.matmul(edge_index,temp1)
545 | temp = torch.cat((temp,temp1,temp2),dim=1)
546 | temp = F.dropout(temp,p=self.args.h2gcn_dropout)
547 | ans = self.lin(temp)
548 |
549 | return F.log_softmax(ans, dim=1)
550 |
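# Note: the concatenation above follows the H2GCN recipe: temp has args.dim columns,
# temp1 = [temp | A temp] has 2 * args.dim, temp2 = A temp1 has 2 * args.dim, and the final
# [temp | temp1 | temp2] has 5 * args.dim columns, matching self.lin's input size of args.dim * 5.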
551 |
552 | class FAGCNNet(nn.Module):
553 | def __init__(self, dataset, args):
554 | super(FAGCNNet, self).__init__()
555 | self.eps = 0.3
556 | self.layer_num = 2
557 | self.dropout = 0.6
558 | self.args = args
559 |
560 | self.layers = nn.ModuleList()
561 | for _ in range(self.layer_num):
562 | self.layers.append(FAConv(args.dim, self.eps, self.dropout))
563 |
564 | self.t1 = nn.Linear(dataset.num_features, args.dim)
565 | self.t2 = nn.Linear(args.dim, dataset.num_classes)
566 | self.reset_parameters()
567 |
568 | if args.adj_sparse:
569 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32)))
570 |
571 | def reset_parameters(self):
572 | nn.init.xavier_normal_(self.t1.weight, gain=1.414)
573 | nn.init.xavier_normal_(self.t2.weight, gain=1.414)
574 |
575 | def forward(self, data):
576 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
577 |
578 | if self.args.adj_sparse:
579 | edge_mask = torch.abs(self.edge_weight_train) > 0
580 | row, col = data.edge_index
581 | row, col= row[edge_mask], col[edge_mask]
582 | edge_index = torch.stack([row, col], dim=0)
583 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask]
584 |
585 |
586 | h = F.dropout(x, p=self.dropout, training=self.training)
587 | h = torch.relu(self.t1(h))
588 | h = F.dropout(h, p=self.dropout, training=self.training)
589 | raw = h
590 | for i in range(self.layer_num):
591 | h = self.layers[i](h,raw,edge_index, edge_weight)
592 | h = self.t2(h)
593 |
594 | return F.log_softmax(h, dim=1)
595 |
596 |
597 |
598 | class HGCNNet(nn.Module):
599 | def __init__(self,dataset, args):
600 | super(HGCNNet, self).__init__()
601 | self.args = args
602 | self.lin1 = nn.Linear(dataset.num_features, args.dim)
603 | self.lin = nn.Linear(args.dim*5,dataset.num_classes)
604 |
605 | if args.adj_sparse:
606 | self.edge_weight_train = nn.Parameter(torch.randn_like(dataset.edge_index[0].to(torch.float32)))
607 |
608 |
609 | def forward(self,data):
610 |
611 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
612 |
613 | if self.args.adj_sparse:
614 | edge_mask = torch.abs(self.edge_weight_train) > 0
615 | row, col = data.edge_index
616 | row, col= row[edge_mask], col[edge_mask]
617 | edge_index = torch.stack([row, col], dim=0)
618 | edge_weight = self.edge_weight_train.sigmoid()[edge_mask]
619 |
620 | edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes = x.size(0))
621 |
622 | adj_coo = to_scipy_sparse_matrix(edge_index)
623 | adj_row = adj_coo.row
624 | adj_col = adj_coo.col
625 | adj_value = edge_weight.detach().cpu().numpy()
626 | # print(adj_value)
627 | # raise exception("pause")
628 | adj_size = adj_coo.shape
629 | edge_index = torch_sparse.SparseTensor(sparse_sizes=[adj_size[0], adj_size[1]], row=torch.tensor(adj_row, dtype=torch.long),
630 | col=torch.tensor(adj_col, dtype=torch.long),
631 | value=torch.tensor(adj_value, dtype=torch.float32)).to(self.args.device)
632 | temp = self.lin1(x)
633 | temp = F.relu(temp)
634 | temp1 = torch_sparse.matmul(edge_index,temp)
635 | temp1 =torch.cat((temp,temp1),dim=1)
636 | temp2 = torch_sparse.matmul(edge_index,temp1)
637 | temp = torch.cat((temp,temp1,temp2),dim=1)
638 | temp = F.dropout(temp,p=0.5)
639 | ans = self.lin(temp)
640 |
641 | return F.log_softmax(ans, dim=1)
642 |
643 |
644 |
645 |
646 | class GCNmasker(torch.nn.Module):
647 |
648 | def __init__(self, dataset, args):
649 | super(GCNmasker, self).__init__()
650 |
651 | self.conv1 = GCNConv(dataset.num_features, args.masker_dim, cached=False)
652 | self.conv2 = GCNConv(args.masker_dim, args.masker_dim, cached=False)
653 | self.mlp = nn.Linear(args.masker_dim * 2, 1)
654 | self.score_function = args.score_function
655 | self.sigmoid = nn.Sigmoid()
656 |
657 | def forward(self, data):
658 |
659 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
660 | x = F.relu(self.conv1(x, edge_index, edge_weight))
661 | x = F.dropout(x, training=self.training)
662 | x = self.conv2(x, edge_index, edge_weight)
663 |
664 | if self.score_function == 'inner_product':
665 | link_score = self.inner_product_score(x, edge_index)
666 | elif self.score_function == 'concat_mlp':
667 | link_score = self.concat_mlp_score(x, edge_index)
668 | else:
669 |             raise ValueError('Unknown score_function: {}'.format(self.score_function))
670 |
671 | return link_score
672 |
673 | def inner_product_score(self, x, edge_index):
674 |
675 | row, col = edge_index
676 | link_score = torch.sum(x[row] * x[col], dim=1)
677 | #print("max:{:.2f} min:{:.2f} mean:{:.2f}".format(link_score.max(), link_score.min(), link_score.mean()))
678 | link_score = self.sigmoid(link_score).view(-1)
679 | return link_score
680 |
681 | def concat_mlp_score(self, x, edge_index):
682 |
683 | row, col = edge_index
684 | link_score = torch.cat((x[row], x[col]), dim=1)
685 | link_score = self.mlp(link_score)
686 | # weight = self.mlp.weight
687 | # print("max:{:.2f} min:{:.2f} mean:{:.2f}".format(link_score.max(), link_score.min(), link_score.mean()))
688 | link_score = self.sigmoid(link_score).view(-1)
689 | return link_score
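# Illustrative usage sketch (hypothetical, not part of the original file): the masker maps
# every edge to a score in (0, 1) that can be thresholded or used as a soft edge weight:
#
#   masker = GCNmasker(dataset, args)        # assumes args.masker_dim and args.score_function are set
#   link_score = masker(data)                # shape: [num_edges]
#   keep = link_score > 0.5                  # hypothetical hard threshold
#   pruned_edge_index = data.edge_index[:, keep]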
--------------------------------------------------------------------------------
/main_stgnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import os
4 | import os.path as osp
5 | import shutil
6 | import time
7 | import argparse
8 | import logging
9 | import hashlib
10 | import copy
11 | import csv
12 | import numpy as np
13 | import random
14 |
15 | import torch
16 | import torch.nn.functional as F
17 | import torch.optim as optim
18 | import torch.backends.cudnn as cudnn
19 | import torch_geometric.transforms as T
20 |
21 | from torch_geometric.datasets import Planetoid, Reddit, WebKB, Actor, WikipediaNetwork, Coauthor, Amazon, Flickr, WikiCS, Yelp
22 | from torch_geometric.loader import NeighborLoader, RandomNodeSampler
23 | from torch_geometric.transforms import RandomNodeSplit
24 | from torch_geometric.nn.conv.gcn_conv import gcn_norm
25 |
26 |
27 | from ogb.nodeproppred import Evaluator
28 | from ogb.nodeproppred import PygNodePropPredDataset
29 | from torch_scatter import scatter
30 | from torch_geometric.utils import to_undirected, add_self_loops
31 |
32 | import sparselearning
33 | from models import initializers
34 | from sparselearning.core import Masking, CosineDecay, LinearDecay
35 | from models.model import GCNNet, SGCNet, APPNPNet, GCNIINet, GATNet, MLP, FAGCN, HGCN, LINK, GPRGNN, MixHop, FAGCNNet, HGCNNet
36 |
37 | import warnings
38 | warnings.filterwarnings("ignore", category=UserWarning)
39 |
40 | cudnn.benchmark = True
41 | cudnn.deterministic = True
42 |
43 |
44 | if not os.path.exists('./models'): os.mkdir('./models')
45 | if not os.path.exists('./logs'): os.mkdir('./logs')
46 | if not os.path.exists('./results'): os.mkdir('./results')
47 | logger = None
48 |
49 | torch.backends.cudnn.enabled = True
50 | torch.backends.cudnn.benchmark = True
51 |
52 | models = {}
53 | models['gcn'] = (GCNNet)
54 | models['gat'] = (GATNet)
55 | models['sgc'] = (SGCNet)
56 | models['appnp'] = (APPNPNet)
57 | models['gcnii'] = (GCNIINet)
58 | models['mlp'] = (MLP)
59 | models['fagcn'] = (FAGCN)
60 | models['h2gcn'] = (HGCN)
61 | models['link'] = (LINK)
62 | models['gprgnn'] = (GPRGNN)
63 | models['mixhop'] = (MixHop)
64 | models['fagcnnet'] = (FAGCNNet)
65 | models['h2gcnnet'] = (HGCNNet)
66 |
67 |
68 | def save_checkpoint(state, filename='checkpoint.pth.tar'):
69 | print("SAVING")
70 | torch.save(state, filename)
71 |
72 |
73 | def setup_logger(args):
74 | global logger
75 |     if logger is None:
76 | logger = logging.getLogger()
77 | else: # wish there was a logger.close()
78 | for handler in logger.handlers[:]: # make a copy of the list
79 | logger.removeHandler(handler)
80 |
81 | args_copy = copy.deepcopy(args)
82 | # copy to get a clean hash
83 | # use the same log file hash if iterations or verbose are different
84 | # these flags do not change the results
85 | args_copy.iters = 1
86 | args_copy.verbose = False
87 | args_copy.log_interval = 1
88 | args_copy.seed = 0
89 |
90 | if args.weight_sparse and not args.adj_sparse and not args.feature_sparse :
91 | sparse_way = 'w'
92 | log_path = './logs/{0}/{1}_{2}_{3}_{4}.log'.format(
93 | sparse_way,args.model,args.data, args.final_density, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
94 | elif args.adj_sparse and not args.weight_sparse and not args.feature_sparse :
95 | sparse_way = 'a'
96 | log_path = './logs/{0}/{1}_{2}_{3}_{4}.log'.format(
97 | sparse_way,args.model,args.data, args.final_density_adj, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
98 | elif args.feature_sparse and not args.weight_sparse and not args.adj_sparse :
99 | sparse_way = 'f'
100 | log_path = './logs/{0}/{1}_{2}_{3}_{4}.log'.format(
101 | sparse_way,args.model,args.data, args.final_density_feature, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
102 | elif args.weight_sparse and args.adj_sparse and not args.feature_sparse :
103 | sparse_way = 'wa'
104 | log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}.log'.format(
105 | sparse_way,args.model,args.data, args.final_density, args.final_density_adj, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
106 | elif args.weight_sparse and args.feature_sparse and not args.adj_sparse :
107 | sparse_way = 'wf'
108 | log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}.log'.format(
109 | sparse_way,args.model,args.data, args.final_density, args.final_density_feature, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
110 | elif args.adj_sparse and args.feature_sparse and not args.weight_sparse :
111 | sparse_way = 'af'
112 | log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}.log'.format(
113 | sparse_way,args.model,args.data, args.final_density_adj, args.final_density_feature, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
114 | elif args.weight_sparse and args.adj_sparse and args.feature_sparse:
115 | sparse_way = 'waf'
116 | log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}_{6}.log'.format(
117 | sparse_way,args.model,args.data, args.final_density, args.final_density_adj, args.final_density_feature, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
118 | else:
119 | sparse_way = 'base'
120 | log_path = './logs/{0}/{1}_{2}_{3}.log'.format(
121 | sparse_way,args.model,args.data, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
122 |
123 | if not os.path.exists('./logs/{}'.format(sparse_way)): os.mkdir('./logs/{}'.format(sparse_way))
124 |
125 | #log_path = './logs/{0}/{1}_{2}_{3}_{4}_{5}.log'.format(sparse_way, args.model, args.data, args.final_density, args.final_density_adj, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
126 |
127 | logger.setLevel(logging.INFO)
128 | formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%H:%M:%S')
129 |
130 | fh = logging.FileHandler(log_path)
131 | fh.setFormatter(formatter)
132 | logger.addHandler(fh)
133 |
134 | def print_and_log(msg):
135 | global logger
136 | print(msg)
137 | logger.info(msg)
138 |
139 |
140 | def results_to_file(args, train_acc, val_acc, test_acc, train_time, test_time):
141 |
142 |
143 | if args.weight_sparse and not args.adj_sparse and not args.feature_sparse :
144 | sparse_way = 'w'
145 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format(
146 | sparse_way, args.model, args.data, args.init_density, args.final_density)
147 | elif args.adj_sparse and not args.weight_sparse and not args.feature_sparse :
148 | sparse_way = 'a'
149 |         filename = "./results/{}/{}_{}_{}_result.csv".format(
150 |             sparse_way, args.model, args.data, args.final_density_adj)
151 | elif args.feature_sparse and not args.weight_sparse and not args.adj_sparse :
152 | sparse_way = 'f'
153 |         filename = "./results/{}/{}_{}_{}_result.csv".format(
154 |             sparse_way, args.model, args.data, args.final_density_feature)
155 | elif args.weight_sparse and args.adj_sparse and not args.feature_sparse :
156 | sparse_way = 'wa'
157 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format(
158 | sparse_way, args.model, args.data, args.final_density, args.final_density_adj)
159 | elif args.weight_sparse and args.feature_sparse and not args.adj_sparse :
160 | sparse_way = 'wf'
161 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format(
162 | sparse_way, args.model, args.data, args.final_density, args.final_density_feature)
163 | elif args.adj_sparse and args.feature_sparse and not args.weight_sparse :
164 | sparse_way = 'af'
165 | filename = "./results/{}/{}_{}_{}_{}_result.csv".format(
166 | sparse_way, args.model, args.data, args.final_density_adj, args.final_density_feature)
167 | elif args.weight_sparse and args.adj_sparse and args.feature_sparse:
168 | sparse_way = 'waf'
169 | filename = "./results/{}/{}_{}_{}_{}_{}_result.csv".format(
170 | sparse_way, args.model, args.data, args.final_density, args.final_density_adj, args.final_density_feature)
171 | else:
172 | sparse_way = 'base'
173 | filename = "./results/{}/{}_{}_result.csv".format(
174 | sparse_way, args.model, args.data)
175 |
176 | if not os.path.exists('./results/{}'.format(sparse_way)): os.mkdir('./results/{}'.format(sparse_way))
177 |
178 | headerList = ["Method","Growth","Prune Rate", "Update Frequency", "Final Prune Epoch", "::", "train_acc", "val_acc", "test_acc", "train_time", "test_time"]
179 |
180 | #filename = "./results/{}/{}_{}_{}_{}_result.csv".format(sparse_way, args.model, args.data, args.final_density, args.final_density_adj)
181 | with open(filename, "a+") as f:
182 |
183 | # reader = csv.reader(f)
184 | # row1 = next(reader)
185 | f.seek(0)
186 | header = f.read(6)
187 | if header != "Method":
188 | dw = csv.DictWriter(f, delimiter=',',
189 | fieldnames=headerList)
190 | dw.writeheader()
191 |
192 | line = "{}, {}, {}, {}, {}, :::, {:.4f}, {:.4f}, {:.4f},{:.4f}, {:.4f}\n".format(
193 | args.method, args.growth_schedule, args.prune_rate, args.update_frequency, args.final_prune_epoch, train_acc, val_acc, test_acc, train_time, test_time
194 | )
195 | f.write(line)
196 |
197 |
198 |
199 | def train(args, model, device, data, optimizer, epoch, mask=None):
200 | model.train()
201 | train_loss = 0
202 | correct = 0
203 | n = 0
204 | criterion = torch.nn.BCEWithLogitsLoss()
205 |
206 | data = data.to(device)
207 |
208 | target = data.y[data.train_mask].to(device)
209 |
210 |
211 | if args.fp16: data = data.half()
212 |
213 | optimizer.zero_grad()
214 |
215 | output = model(data)[data.train_mask]
216 | if args.data in ['ogbn-proteins']:
217 | loss = criterion(output, target)
218 | acc = 0.0
219 | else:
220 | loss = F.nll_loss(output, target)
221 | pred = output.max(1)[1]
222 | acc = pred.eq(data.y[data.train_mask]).sum().item() / data.train_mask.sum().item()
223 |
224 | if args.fp16:
225 | optimizer.backward(loss)
226 | else:
227 | loss.backward()
228 |
229 | if mask is not None:
230 | #print("Mask!!!!")
231 | mask.step()
232 | else:
233 | optimizer.step()
234 |
235 | # print_and_log('\n{}: Average loss: {:.4f}, Accuracy: {} \n'.format(
236 | # 'Training summary', loss, acc,))
237 |
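# Note (hedged, based on the sparselearning Masking implementation): mask.step() typically
# performs the optimizer update, re-applies the binary masks so pruned weights stay at zero,
# and advances the prune-rate decay schedule; without a mask, a plain optimizer.step() is used.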
238 | def evaluate(args, model, device, data, is_test_set=False):
239 | model.eval()
240 | test_loss = 0
241 | correct = 0
242 | n = 0
243 | with torch.no_grad():
244 | #target = data.y[data.train_mask].to(device)
245 | data = data.to(device)
246 |
247 | if args.fp16: data = data.half()
248 |
249 | logits, accs = model(data), []
250 |
251 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
252 | pred = logits[mask].max(1)[1]
253 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
254 | accs.append(acc)
255 | train_acc, val_acc, tmp_test_acc = accs
256 |
257 | # print_and_log('\n{}: Train Accuracy: {:.4f}, Val Accuracy: {} Test Accuracy: {}\n'.format(
258 | # 'Test evaluation' if is_test_set else 'Evaluation',
259 | # train_acc, val_acc, tmp_test_acc))
260 | return train_acc, val_acc, tmp_test_acc
261 |
262 | def evaluate_ogb(args, model, device, data, evaluator, is_test_set=False):
263 |
264 | model.eval()
265 | test_loss = 0
266 | correct = 0
267 | n = 0
268 | with torch.no_grad():
269 | #target = data.y[data.train_mask].to(device)
270 | data = data.to(device)
271 |
272 | if args.fp16: data = data.half()
273 |
274 | out = model(data)
275 |
276 | if args.data in ['ogbn-arxiv','ogbn-products' ]:
277 | y_pred = out.argmax(dim=-1, keepdim=True)
278 |
279 | train_acc = evaluator.eval({
280 | 'y_true': data.y.unsqueeze(-1)[data.train_mask],
281 | 'y_pred': y_pred[data.train_mask],
282 | })['acc']
283 | val_acc = evaluator.eval({
284 | 'y_true': data.y.unsqueeze(-1)[data.valid_mask],
285 | 'y_pred': y_pred[data.valid_mask],
286 | })['acc']
287 | tmp_test_acc = evaluator.eval({
288 | 'y_true': data.y.unsqueeze(-1)[data.test_mask],
289 | 'y_pred': y_pred[data.test_mask],
290 | })['acc']
291 | print_and_log('\n{}: Train Accuracy: {:.4f}, Val Accuracy: {} Test Accuracy: {}\n'.format(
292 | 'Test evaluation' if is_test_set else 'Evaluation',
293 | train_acc, val_acc, tmp_test_acc))
294 |
295 | elif args.data in ['ogbn-proteins']:
296 |
297 | train_acc = evaluator.eval({
298 | 'y_true': data.y[data.train_mask],
299 | 'y_pred': out[data.train_mask],
300 |             })['rocauc'] # actually ROC-AUC; stored in train_acc for naming consistency
301 | val_acc = evaluator.eval({
302 | 'y_true': data.y[data.valid_mask],
303 | 'y_pred': out[data.valid_mask],
304 | })['rocauc']
305 | tmp_test_acc = evaluator.eval({
306 | 'y_true': data.y[data.test_mask],
307 | 'y_pred': out[data.test_mask],
308 | })['rocauc']
309 | print_and_log('\n{}: Train ROC-AUC: {:.4f}, Val ROC-AUC: {} Test ROC-AUC: {}\n'.format(
310 | 'Test evaluation' if is_test_set else 'Evaluation',
311 | train_acc, val_acc, tmp_test_acc))
312 |
313 | return train_acc, val_acc, tmp_test_acc
314 |
315 |
316 | def main():
317 | # Training settings
318 | parser = argparse.ArgumentParser(description='PyTorch GraNet for sparse training')
319 | parser.add_argument('--batch-size', type=int, default=100, metavar='N',
320 | help='input batch size for training (default: 100)')
321 | parser.add_argument('--batch-size-jac', type=int, default=200, metavar='N',
322 |                         help='batch size for jac (default: 200)')
323 | parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
324 | help='input batch size for testing (default: 100)')
325 | parser.add_argument('--multiplier', type=int, default=1, metavar='N',
326 | help='extend training time by multiplier times')
327 | parser.add_argument('--epochs', type=int, default=250, metavar='N',
328 |                         help='number of epochs to train (default: 250)')
329 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
330 | help='learning rate (default: 0.1)')
331 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
332 | help='SGD momentum (default: 0.9)')
333 | parser.add_argument('--no-cuda', action='store_true', default=False,
334 | help='disables CUDA training')
335 | parser.add_argument('--seed', type=int, default=17, metavar='S', help='random seed (default: 17)')
336 | parser.add_argument('--log-interval', type=int, default=100, metavar='N',
337 | help='how many batches to wait before logging training status')
338 |     parser.add_argument('--optimizer', type=str, default='adam', help='The optimizer to use. Default: adam. Options: sgd, adam.')
339 | randomhash = ''.join(str(time.time()).split('.'))
340 | parser.add_argument('--save', type=str, default=randomhash + '.pt',
341 | help='path to save the final model')
342 | parser.add_argument('--data', type=str, default='mnist')
343 | parser.add_argument('--decay_frequency', type=int, default=25000)
344 | parser.add_argument('--l1', type=float, default=0.0)
345 | parser.add_argument('--fp16', action='store_true', help='Run in fp16 mode.')
346 | parser.add_argument('--valid_split', type=float, default=0.1)
347 | parser.add_argument('--resume', type=str)
348 | parser.add_argument('--start-epoch', type=int, default=1)
349 | parser.add_argument('--model', type=str, default='')
350 | parser.add_argument('--l2', type=float, default=1.0e-4)
351 | parser.add_argument('--iters', type=int, default=1, help='How many times the model should be run after each other. Default=1')
352 | parser.add_argument('--save-features', action='store_true', help='Resumes a saved model and saves its feature data to disk for plotting.')
353 | parser.add_argument('--bench', action='store_true', help='Enables the benchmarking of layers and estimates sparse speedups')
354 | parser.add_argument('--max-threads', type=int, default=10, help='How many threads to use for data loading.')
355 | parser.add_argument('--decay-schedule', type=str, default='cosine', help='The decay schedule for the pruning rate. Default: cosine. Choose from: cosine, linear.')
356 | parser.add_argument('--growth_schedule', type=str, default='gradient', help='The growth schedule. Default: gradient. Choose from: gradient, momentum, random.')
357 | parser.add_argument('--lr_scheduler', action='store_true', default=False,
358 |                         help='use a MultiStepLR learning-rate scheduler during training')
359 | parser.add_argument('--adj_sparse', action='store_true', help='If Sparse Adj.')
360 |     parser.add_argument('--feature_sparse', action='store_true', help='If Sparse Feature.')
361 |     parser.add_argument('--weight_sparse', action='store_true', help='If Sparse Weight.')
362 | parser.add_argument('--dim', type=int, default=512)
363 | parser.add_argument('--cuda', type=int, default=0)
364 |
365 |
366 | # FAGCN
367 | parser.add_argument('--fagcn_layer_num', type=int, default=1)
368 | parser.add_argument('--fagcn_dropout', type=float, default=0)
369 | parser.add_argument('--fagcn_eps', type=float, default=0.1)
370 |
371 | # MixHop
372 | parser.add_argument('--mixhop_layer_num', type=int, default=1)
373 | parser.add_argument('--mixhop_dropout', type=float, default=0)
374 | parser.add_argument('--mixhop_hop', type=int, default=2)
375 |
376 | # GPRGNN
377 | parser.add_argument('--gprgnn_alpha', type=float, default=0.1)
378 | parser.add_argument('--gprgnn_k', type=int, default=10)
379 |
380 | # H2GCN
381 |
382 | parser.add_argument('--h2gcn_dropout', type=float, default=0.1)
383 |
384 |
385 | sparselearning.core.add_sparse_args(parser)
386 |
387 | args = parser.parse_args()
388 | setup_logger(args)
389 | print_and_log(args)
390 |
391 | if args.fp16:
392 | try:
393 | from apex.fp16_utils import FP16_Optimizer
394 |         except ImportError:
395 | print('WARNING: apex not installed, ignoring --fp16 option')
396 | args.fp16 = False
397 |
398 | use_cuda = not args.no_cuda and torch.cuda.is_available()
399 | args.device = torch.device('cuda:{}'.format(args.cuda) if use_cuda else "cpu")
400 |
401 |
402 |
403 | print_and_log('\n\n')
404 | print_and_log('='*80)
405 | np.random.seed(args.seed)
406 | torch.manual_seed(args.seed)
407 | random.seed(args.seed)
408 | if torch.cuda.is_available():
409 | torch.cuda.manual_seed(args.seed)
410 | torch.cuda.manual_seed_all(args.seed)
411 | for i in range(args.iters):
412 |
413 | #######################################################################################
414 | ############################# Datasets ################################################
415 | #######################################################################################
416 | print_and_log("\nIteration start: {0}/{1}\n".format(i+1, args.iters))
417 |
418 |
419 | if args.data in ['cora','citeseer','pubmed']:
420 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../data', args.data)
421 | dataset = Planetoid(path, args.data, transform=T.NormalizeFeatures())
422 | data = dataset[0]
423 |
424 | data.num_classes = dataset.num_classes
425 | data.num_edges_orig = data.num_edges
426 | #print_and_log(data)
427 | #raise Exception('pause!! ')
428 |
429 | elif args.data in ["Cornell", "Texas", "Wisconsin"] :
430 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data)
431 | dataset = WebKB(path,args.data, transform=T.NormalizeFeatures())
432 | data = dataset[0]
433 | data.num_classes = dataset.num_classes
434 | print_and_log(data)
435 | data.train_mask = data.train_mask[:, args.seed % 10]
436 | data.val_mask = data.val_mask[:, args.seed % 10]
437 | data.test_mask = data.test_mask[:, args.seed % 10]
438 | #print_and_log(data)
439 | #raise Exception('pause!! ')
440 |
441 | elif args.data in ["Actor"] :
442 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data)
443 | dataset = Actor(path, transform=T.NormalizeFeatures())
444 | data = dataset[0]
445 | data.num_classes = dataset.num_classes
446 |
447 | #print_and_log(data)
448 | data.train_mask = data.train_mask[:, args.seed % 10]
449 | data.val_mask = data.val_mask[:, args.seed % 10]
450 | data.test_mask = data.test_mask[:, args.seed % 10]
451 | #print_and_log(data)
452 | #raise Exception('pause!! ')
453 |
454 | elif args.data in ["chameleon", "crocodile", "squirrel"] :
455 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data)
456 | dataset = WikipediaNetwork(path, args.data, transform=T.NormalizeFeatures())
457 | data = dataset[0]
458 | data.num_classes = dataset.num_classes
459 |
460 | data.train_mask = data.train_mask[:, args.seed % 10]
461 | data.val_mask = data.val_mask[:, args.seed % 10]
462 | data.test_mask = data.test_mask[:, args.seed % 10]
463 | #print_and_log(data)
464 |
465 |             # Multi Split
466 | raise Exception('pause!! ')
467 |
468 |
469 | elif args.data in ["CS", "Physics"] :
470 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data)
471 | dataset = Coauthor(path, args.data, transform=T.NormalizeFeatures())
472 | data = dataset[0]
473 | data.num_classes = dataset.num_classes
474 | transform = RandomNodeSplit(split= "test_rest",
475 | num_train_per_class = 20,
476 | num_val = 30* data.num_classes,)
477 | transform(data)
478 | #print_and_log(data)
479 | #raise Exception('pause!! ')
480 |
481 | elif args.data in ["Computers", "Photo"] :
482 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data)
483 | dataset = Amazon(path, args.data, transform=T.NormalizeFeatures())
484 | data = dataset[0]
485 | data.num_classes = dataset.num_classes
486 | transform = RandomNodeSplit(split= "test_rest",
487 | num_train_per_class = 20,
488 | num_val = 30* data.num_classes,)
489 | transform(data)
490 | #print_and_log(data)
491 |
492 | #raise Exception('pause!! ')
493 |
494 |
495 | elif args.data in ["Flickr"] :
496 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data)
497 | dataset = Flickr(path)
498 | data = dataset[0]
499 | data.num_classes = dataset.num_classes
500 | print_and_log(data)
501 | # Cannot load file containing pickled data when allow_pickle=False
502 | raise Exception('pause!! ')
503 |
504 | elif args.data in ["Yelp"] :
505 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data)
506 | dataset = Yelp(path)
507 | data = dataset[0]
508 | data.num_classes = dataset.num_classes
509 | print_and_log(data)
510 | # Cannot load file containing pickled data when allow_pickle=False
511 | # Fix: allow_pickle=True)
512 | raise Exception('pause!! ')
513 |
514 |
515 | elif args.data in ["WikiCS"] :
516 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', args.data)
517 | dataset = WikiCS(path, transform=T.NormalizeFeatures())
518 | data = dataset[0]
519 | data.num_classes = dataset.num_classes
520 |
521 | data.stopping_mask = None
522 | data.train_mask = data.train_mask[:, args.seed % 20]
523 | data.val_mask = data.val_mask[:, args.seed % 20]
524 | print_and_log(data)
525 |
526 | elif args.data == 'reddit':
527 | print("Loading Reddit .....")
528 | path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data', 'Reddit')
529 | dataset = Reddit(path, transform=T.NormalizeFeatures())
530 |
531 | data = dataset[0]
532 | data.num_classes = dataset.num_classes
533 | #kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
534 | print_and_log(data)
535 |
536 | print("Load Reddit Done!")
537 |
538 | elif args.data in['ogbn-arxiv', 'ogbn-products', 'ogbn-proteins', 'ogbn-papers100M']:
539 |
540 | print("Loading Dataset: {}".format(args.data))
541 |
542 | dataset = PygNodePropPredDataset(name=args.data, root='../data')
543 | data = dataset[0]
544 | split_idx = dataset.get_idx_split()
545 | evaluator = Evaluator(args.data)
546 |
547 | edge_index = to_undirected(data.edge_index, data.num_nodes)
548 | #edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0]
549 |
550 | data.edge_index = edge_index
551 | for split in ['train', 'valid', 'test']:
552 | mask = torch.zeros(data.num_nodes, dtype=torch.bool)
553 | mask[split_idx[split]] = True
554 | data[f'{split}_mask'] = mask
555 |
556 | if args.data in ['ogbn-proteins']:
557 | data.y = data.y.to(torch.float)
558 | data.num_classes = dataset.num_tasks
559 | data.node_species = None
560 | row, col = data.edge_index
561 | data.x = scatter(data.edge_attr, col, 0, dim_size=data.num_nodes, reduce='add')
562 | else:
563 | data.num_classes = dataset.num_classes
564 |
565 | if args.data in ['ogbn-arxiv']:
566 | data.y = data.y.squeeze(1)
567 |
568 | print("Load Done !")
569 | print_and_log(data)
570 |
571 | #######################################################################################
572 | ############################# Models ################################################
573 | #######################################################################################
574 |
575 | if args.model not in models:
576 | print('You need to select an existing model via the --model argument. Available models include: ')
577 | for key in models:
578 | print('\t{0}'.format(key))
579 | raise Exception('You need to select a model')
580 | else:
581 |
582 | if args.model == 'gcn':
583 | model = GCNNet(data, args).to(args.device)
584 |
585 | elif args.model == 'sgc':
586 | model = SGCNet(data, args).to(args.device)
587 |
588 | elif args.model == 'appnp':
589 | model = APPNPNet(data, args).to(args.device)
590 |
591 | elif args.model == 'gat':
592 | model = GATNet(data, args).to(args.device)
593 |
594 | elif args.model == 'gcnii':
595 | model = GCNIINet(data, args).to(args.device)
596 |
597 | elif args.model == 'mlp':
598 | model = MLP(data, args).to(args.device)
599 |
600 | elif args.model == 'fagcn':
601 | model = FAGCN(data, args).to(args.device)
602 |
603 | elif args.model == 'h2gcn':
604 | model = HGCN(data, args).to(args.device)
605 |
606 | elif args.model == 'link':
607 | model = LINK(data, args).to(args.device)
608 |
609 | elif args.model == 'gprgnn':
610 | model = GPRGNN(data, args).to(args.device)
611 |
612 | elif args.model == 'mixhop':
613 | model = MixHop(data, args).to(args.device)
614 |
615 | elif args.model == 'fagcnnet':
616 | model = FAGCNNet(data, args).to(args.device)
617 |
618 | elif args.model == 'h2gcnnet':
619 | model = HGCNNet(data, args).to(args.device)
620 |
621 | else:
622 | cls, cls_args = models[args.model]
623 | if args.data == 'cifar100':
624 | cls_args[2] = 100
625 | model = cls(*(cls_args + [args.save_features, args.bench])).to(args.device)
626 | print_and_log(model)
627 | print_and_log('='*60)
628 | print_and_log(args.model)
629 | print_and_log('='*60)
630 |
631 | print_and_log('='*60)
632 | print_and_log('Prune mode: {0}'.format(args.prune))
633 | print_and_log('Growth mode: {0}'.format(args.growth))
634 | print_and_log('Redistribution mode: {0}'.format(args.redistribution))
635 | print_and_log('='*60)
636 |
637 |
638 | optimizer = None
639 | if args.optimizer == 'sgd':
640 | optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.l2, nesterov=True)
641 | elif args.optimizer == 'adam':
642 | optimizer = optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.l2)
643 | else:
644 | print('Unknown optimizer: {0}'.format(args.optimizer))
645 | raise Exception('Unknown optimizer.')
646 |
647 | if args.lr_scheduler:
648 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.epochs / 2) * args.multiplier, int(args.epochs * 3 / 4) * args.multiplier], last_epoch=-1)
649 | else:
650 | lr_scheduler = None
651 |
652 |
653 | if args.resume:
654 | if os.path.isfile(args.resume):
655 | print_and_log("=> loading checkpoint '{}'".format(args.resume))
656 | checkpoint = torch.load(args.resume)
657 | model.load_state_dict(checkpoint)
658 |                 original_acc = evaluate(args, model, args.device, data)
659 |
660 |
661 | if args.fp16:
662 | print('FP16')
663 | optimizer = FP16_Optimizer(optimizer,
664 | static_loss_scale = None,
665 | dynamic_loss_scale = True,
666 | dynamic_loss_args = {'init_scale': 2 ** 16})
667 | model = model.half()
668 |
669 |
670 | mask = None
671 | if args.sparse:
672 | decay = CosineDecay(args.prune_rate, (args.epochs*args.multiplier))
673 | mask = Masking(optimizer, prune_rate=args.prune_rate, death_mode=args.prune, prune_rate_decay=decay, growth_mode=args.growth,
674 | redistribution_mode=args.redistribution, args=args, train_loader=None, device =args.device)
675 | mask.add_module(model, sparse_init=args.sparse_init)
676 |
677 | best_acc = 0.0
678 | t_start = time.time()
679 | for epoch in range(1, args.epochs*args.multiplier + 1):
680 |
681 | #print_and_log("Epoch:{}".format(epoch))
682 | #print("="*50)
683 |
684 | # save models
685 | save_path = './save/' + str(args.model) + '/' + str(args.data) + '/' + str(args.method) + '/' + str(args.seed)
686 | save_subfolder = os.path.join(save_path, 'Multiplier=' + str(args.multiplier) + '_sparsity' + str(1-args.final_density))
687 | if not os.path.exists(save_subfolder): os.makedirs(save_subfolder)
688 |
689 | t0 = time.time()
690 | #print(mask)
691 |
692 | train(args, model, args.device, data, optimizer, epoch, mask)
693 |
694 |
695 | if lr_scheduler is not None:
696 | lr_scheduler.step()
697 | if args.valid_split > 0.0:
698 |                 if args.data in ['ogbn-arxiv', 'ogbn-products', 'ogbn-proteins', 'ogbn-papers100M']:
699 | _, val_acc, _ = evaluate_ogb(args, model, args.device, data, evaluator)
700 | else:
701 | _, val_acc, _ = evaluate(args, model, args.device, data)
702 |
703 | # target sparsity is reached
704 | if args.sparse:
705 | if epoch == args.multiplier * args.final_prune_epoch+1:
706 | best_acc = 0.0
707 |
708 | if val_acc > best_acc:
709 | print('Saving model')
710 | best_acc = val_acc
711 | save_checkpoint({
712 | 'epoch': epoch + 1,
713 | 'state_dict': model.state_dict(),
714 | 'optimizer': optimizer.state_dict(),
715 | }, filename=os.path.join(save_subfolder, 'model_final.pth'))
716 |
717 | #print_and_log(' Time taken for epoch: {:.2f} seconds.\n'.format(time.time() - t0))
718 |
719 | train_time_total = time.time() - t_start
720 | print('Testing model')
721 | model.load_state_dict(torch.load(os.path.join(save_subfolder, 'model_final.pth'))['state_dict'])
722 |
723 | t_test_0 = time.time()
724 |         if args.data in ['ogbn-arxiv', 'ogbn-products', 'ogbn-proteins', 'ogbn-papers100M']:
725 | train_acc, val_acc, test_acc = evaluate_ogb(args, model, args.device, data, evaluator, is_test_set=True)
726 | else:
727 | train_acc, val_acc, test_acc = evaluate(args, model, args.device, data, is_test_set=True)
728 | print('Test accuracy is:', test_acc)
729 | results_to_file(args, train_acc, val_acc, test_acc, train_time_total, time.time()- t_test_0)
730 |
731 |
732 |
733 | if __name__ == '__main__':
734 |     print("Start Running!")
735 | main()
736 |
--------------------------------------------------------------------------------
/sparselearning/core.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.optim as optim
4 | import numpy as np
5 | import math
6 |
7 | # use_cuda = torch.cuda.is_available()
8 | # device = torch.device("cuda" if use_cuda else "cpu")
9 |
10 | def add_sparse_args(parser):
11 | # hyperparameters for Zero-Cost Neuroregeneration
12 |     parser.add_argument('--growth', type=str, default='gradient', help='Growth mode. Choose from: gradient, momentum, random, and momentum_neuron.')
13 | parser.add_argument('--prune', type=str, default='magnitude', help='Death mode / pruning mode. Choose from: magnitude, SET, threshold, CS_death.')
14 | parser.add_argument('--redistribution', type=str, default='none', help='Redistribution mode. Choose from: momentum, magnitude, nonzeros, or none.')
15 | parser.add_argument('--prune-rate', type=float, default=0.50, help='The pruning rate / death rate for Zero-Cost Neuroregeneration.')
16 | parser.add_argument('--pruning-rate', type=float, default=0.50, help='The pruning rate / death rate.')
17 |     parser.add_argument('--sparse', action='store_true', help='Enable sparse mode. Default: False.')
18 |     parser.add_argument('--fix', action='store_true', help='Fix the sparse topology during training. Default: False.')
19 |     parser.add_argument('--update-frequency', type=int, default=100, metavar='N', help='how many iterations to train between mask updates')
20 |     parser.add_argument('--sparse-init', type=str, default='ERK', help='Sparse initialization: ERK or uniform for sparse training; prune_global/prune_uniform and prune_and_grow_global/prune_and_grow_uniform for pruning.')
21 | # hyperparameters for gradually pruning
22 |     parser.add_argument('--method', type=str, default='GraNet', help='method name: DST, GraNet, GraNet_uniform, GMP, GMP_uniform')
23 |
24 |     parser.add_argument('--init-density', type=float, default=0.50, help='Initial density of the model weights.')
25 |     parser.add_argument('--final-density', type=float, default=0.05, help='Final density of the model weights.')
26 |     parser.add_argument('--init-density_adj', type=float, default=1.0, help='Initial density of the graph adjacency (edge) mask.')
27 |     parser.add_argument('--final-density_adj', type=float, default=0.5, help='Final density of the graph adjacency (edge) mask.')
28 |     parser.add_argument('--init-density_feature', type=float, default=1.0, help='Initial density of the node-feature mask.')
29 |     parser.add_argument('--final-density_feature', type=float, default=0.5, help='Final density of the node-feature mask.')
30 |     parser.add_argument('--init-prune-epoch', type=int, default=0, help='Epoch at which gradual pruning starts.')
31 |     parser.add_argument('--final-prune-epoch', type=int, default=110, help='Epoch at which gradual pruning ends.')
32 | parser.add_argument('--rm-first', action='store_true', help='Keep the first layer dense.')
33 |
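# A typical flag combination for GraNet-style gradual pruning might look as follows
# (a sketch only; dataset- and model-specific flags are defined in main_stgnn.py):
#
#   python main_stgnn.py --sparse --method GraNet --sparse-init ERK \
#       --init-density 0.5 --final-density 0.05 \
#       --init-prune-epoch 0 --final-prune-epoch 110 --update-frequency 100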
34 |
35 |
36 |
37 | class CosineDecay(object):
38 | def __init__(self, prune_rate, T_max, eta_min=0.005, last_epoch=-1):
39 | self.sgd = optim.SGD(torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(1))]), lr=prune_rate)
40 | self.cosine_stepper = torch.optim.lr_scheduler.CosineAnnealingLR(self.sgd, T_max, eta_min, last_epoch)
41 |
42 | def step(self):
43 | self.cosine_stepper.step()
44 |
45 | def get_dr(self):
46 | return self.sgd.param_groups[0]['lr']
47 |
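# CosineDecay anneals the prune (death) rate with a cosine schedule by piggybacking on a
# dummy SGD optimizer and CosineAnnealingLR; Masking.step() calls step()/get_dr() once per
# optimizer step. A minimal sketch of how it is wired up (the values are illustrative):
#
#   decay = CosineDecay(prune_rate=0.5, T_max=total_training_steps)
#   decay.step()                 # advance the schedule
#   current_rate = decay.get_dr()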
48 | class LinearDecay(object):
49 | def __init__(self, prune_rate, factor=0.99, frequency=600):
50 | self.factor = factor
51 | self.steps = 0
52 | self.frequency = frequency
53 |
54 | def step(self):
55 | self.steps += 1
56 |
57 | def get_dr(self, prune_rate):
58 | if self.steps > 0 and self.steps % self.frequency == 0:
59 | return prune_rate*self.factor
60 | else:
61 | return prune_rate
62 |
63 |
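# Masking owns the per-parameter binary masks and drives the prune/regrow schedule.
# The expected lifecycle (a rough sketch; the concrete wiring lives in main_stgnn.py) is:
#
#   decay = CosineDecay(args.prune_rate, T_max=total_steps)
#   mask = Masking(optimizer, prune_rate=args.prune_rate, prune_rate_decay=decay,
#                  growth_mode=args.growth, death_mode=args.prune, args=args)
#   mask.add_module(model)       # register masks and run the sparse initialization
#   ...
#   mask.step()                  # replaces optimizer.step() inside the training loop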
64 |
65 | class Masking(object):
66 | def __init__(self, optimizer,
67 | prune_rate=0.3,
68 | growth_death_ratio=1.0,
69 | prune_rate_decay=None,
70 | death_mode='magnitude',
71 | growth_mode='momentum',
72 | redistribution_mode='momentum',
73 | threshold=0.001,
74 | args=None,
75 | train_loader=None,
76 | device=None):
77 | growth_modes = ['random', 'momentum', 'momentum_neuron', 'gradient']
78 | if growth_mode not in growth_modes:
79 | print('Growth mode: {0} not supported!'.format(growth_mode))
80 | print('Supported modes are:', str(growth_modes))
81 |
82 | self.args = args
83 | self.loader = [1]
84 | self.device = args.device
85 | self.growth_mode = growth_mode
86 | self.death_mode = death_mode
87 | self.growth_death_ratio = growth_death_ratio
88 | self.redistribution_mode = redistribution_mode
89 | self.prune_rate_decay = prune_rate_decay
90 | self.sparse_init = args.sparse_init
91 |
92 |
93 | self.masks = {}
94 | self.final_masks = {}
95 | self.grads = {}
96 | self.nonzero_masks = {}
97 | self.scores = {}
98 | self.pruning_rate = {}
99 | self.modules = []
100 | self.names = []
101 | self.optimizer = optimizer
102 |
103 | self.adjusted_growth = 0
104 | self.adjustments = []
105 | self.baseline_nonzero = None
106 | self.name2baseline_nonzero = {}
107 |
108 | # stats
109 | self.name2variance = {}
110 | self.name2zeros = {}
111 | self.name2nonzeros = {}
112 | self.total_variance = 0
113 | self.total_removed = 0
114 | self.total_zero = 0
115 | self.total_nonzero = 0
116 | self.total_params = 0
117 | self.fc_params = 0
118 | self.prune_rate = prune_rate
119 | self.name2prune_rate = {}
120 | self.steps = 0
121 |
122 | if self.args.fix:
123 | self.prune_every_k_steps = None
124 | else:
125 | self.prune_every_k_steps = self.args.update_frequency
126 |
127 |
128 | def init(self, mode='ER', density=0.05, density_adj=0.05, density_feature=0.05, erk_power_scale=1.0, grad_dict=None):
129 | if self.args.method == 'GMP':
130 | print('initialized with GMP, ones')
131 | self.baseline_nonzero = 0
132 | for module in self.modules:
133 | for name, weight in module.named_parameters():
134 | if name not in self.masks: continue
135 | self.masks[name] = torch.ones_like(weight, dtype=torch.float32, requires_grad=False).to(self.device)
136 | self.baseline_nonzero += (self.masks[name] != 0).sum().int().item()
137 | self.apply_mask()
138 | elif self.sparse_init == 'prune_uniform':
139 |             # used for pruning stability test
140 | print('initialized by prune_uniform')
141 | self.baseline_nonzero = 0
142 | for module in self.modules:
143 | for name, weight in module.named_parameters():
144 | if name not in self.masks: continue
145 | self.masks[name] = (weight!=0).to(self.device)
146 | num_zeros = (weight==0).sum().item()
147 | num_remove = (self.args.pruning_rate) * self.masks[name].sum().item()
148 | k = math.ceil(num_zeros + num_remove)
149 |                     if num_remove == 0.0: continue  # nothing to prune in this layer
150 | x, idx = torch.sort(torch.abs(weight.data.view(-1)))
151 | self.masks[name].data.view(-1)[idx[:k]] = 0.0
152 | self.baseline_nonzero += (self.masks[name] != 0).sum().int().item()
153 | self.apply_mask()
154 |
155 | elif self.sparse_init == 'prune_global':
156 |             # used for pruning stability test
157 | print('initialized by prune_global')
158 | self.baseline_nonzero = 0
159 |             total_num_nonzeros = 0
160 | for module in self.modules:
161 | for name, weight in module.named_parameters():
162 | if name not in self.masks: continue
163 | self.masks[name] = (weight!=0).to(self.device)
164 | self.name2nonzeros[name] = (weight!=0).sum().item()
165 |                     total_num_nonzeros += self.name2nonzeros[name]
166 |
167 | weight_abs = []
168 | for module in self.modules:
169 | for name, weight in module.named_parameters():
170 | if name not in self.masks: continue
171 | weight_abs.append(torch.abs(weight))
172 |
173 | # Gather all scores in a single vector and normalise
174 | all_scores = torch.cat([torch.flatten(x) for x in weight_abs])
175 |             num_params_to_keep = int(total_num_nonzeros * (1 - self.args.pruning_rate))
176 |
177 | threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
178 | acceptable_score = threshold[-1]
179 |
180 | for module in self.modules:
181 | for name, weight in module.named_parameters():
182 | if name not in self.masks: continue
183 | self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float()
184 | self.apply_mask()
185 |
186 | elif self.sparse_init == 'prune_and_grow_uniform':
187 |             # used for pruning stability test
188 | print('initialized by pruning and growing uniformly')
189 |
190 | self.baseline_nonzero = 0
191 | for module in self.modules:
192 | for name, weight in module.named_parameters():
193 | if name not in self.masks: continue
194 | # prune
195 | self.masks[name] = (weight!=0).to(self.device)
196 | num_zeros = (weight==0).sum().item()
197 | num_remove = (self.args.pruning_rate) * self.masks[name].sum().item()
198 | k = math.ceil(num_zeros + num_remove)
199 |                     if num_remove == 0.0: continue  # nothing to prune in this layer
200 | x, idx = torch.sort(torch.abs(weight.data.view(-1)))
201 | self.masks[name].data.view(-1)[idx[:k]] = 0.0
202 | total_regrowth = (self.masks[name]==0).sum().item() - num_zeros
203 |
204 | # set the pruned weights to zero
205 | weight.data = weight.data * self.masks[name]
206 | if 'momentum_buffer' in self.optimizer.state[weight]:
207 | self.optimizer.state[weight]['momentum_buffer'] = self.optimizer.state[weight]['momentum_buffer'] * self.masks[name]
208 |
209 | # grow
210 | grad = grad_dict[name]
211 | grad = grad * (self.masks[name] == 0).float()
212 |
213 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
214 | self.masks[name].data.view(-1)[idx[:total_regrowth]] = 1.0
215 | self.baseline_nonzero += (self.masks[name] != 0).sum().int().item()
216 | self.apply_mask()
217 |
218 | elif self.sparse_init == 'prune_and_grow_global':
219 |             # used for pruning stability test
220 | print('initialized by pruning and growing globally')
221 | self.baseline_nonzero = 0
222 |             total_num_nonzeros = 0
223 | for module in self.modules:
224 | for name, weight in module.named_parameters():
225 | if name not in self.masks: continue
226 | self.masks[name] = (weight!=0).to(self.device)
227 | self.name2nonzeros[name] = (weight!=0).sum().item()
228 |                     total_num_nonzeros += self.name2nonzeros[name]
229 |
230 | weight_abs = []
231 | for module in self.modules:
232 | for name, weight in module.named_parameters():
233 | if name not in self.masks: continue
234 | weight_abs.append(torch.abs(weight))
235 |
236 | # Gather all scores in a single vector and normalise
237 | all_scores = torch.cat([torch.flatten(x) for x in weight_abs])
238 |             num_params_to_keep = int(total_num_nonzeros * (1 - self.args.pruning_rate))
239 |
240 | threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
241 | acceptable_score = threshold[-1]
242 |
243 | for module in self.modules:
244 | for name, weight in module.named_parameters():
245 | if name not in self.masks: continue
246 | self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float()
247 |
248 | # set the pruned weights to zero
249 | weight.data = weight.data * self.masks[name]
250 | if 'momentum_buffer' in self.optimizer.state[weight]:
251 | self.optimizer.state[weight]['momentum_buffer'] = self.optimizer.state[weight]['momentum_buffer'] * self.masks[name]
252 |
253 | ### grow
254 | for module in self.modules:
255 | for name, weight in module.named_parameters():
256 | if name not in self.masks: continue
257 | total_regrowth = self.name2nonzeros[name] - (self.masks[name]!=0).sum().item()
258 | grad = grad_dict[name]
259 | grad = grad * (self.masks[name] == 0).float()
260 |
261 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
262 | self.masks[name].data.view(-1)[idx[:total_regrowth]] = 1.0
263 | self.baseline_nonzero += (self.masks[name] != 0).sum().int().item()
264 | self.apply_mask()
265 |
266 | elif self.sparse_init == 'uniform':
267 | self.baseline_nonzero = 0
268 | for module in self.modules:
269 | for name, weight in module.named_parameters():
270 | if name not in self.masks: continue
271 | if name == "edge_weight_train":
272 | self.masks[name][:] = (torch.rand(weight.shape) < density_adj).float().data.to(self.device) #
273 | #self.baseline_nonzero += weight.numel() * density_adj
274 |
275 | elif name == "x_weight":
276 | self.masks[name][:] = (torch.rand(weight.shape) < density_feature).float().data.to(self.device) #
277 | #self.baseline_nonzero += weight.numel() * density_feature
278 |
279 | else:
280 | self.masks[name][:] = (torch.rand(weight.shape) < density).float().data.to(self.device) #
281 | #self.baseline_nonzero += weight.numel() * density
282 | self.apply_mask()
283 |
284 | elif self.sparse_init == 'ERK':
285 | print('initialize by ERK')
286 | for name, weight in self.masks.items():
287 | if name == "edge_weight_train": continue
288 | if name == "x_weight": continue
289 | self.total_params += weight.numel()
290 | if 'classifier' in name:
291 | self.fc_params = weight.numel()
292 | is_epsilon_valid = False
293 | dense_layers = set()
294 | while not is_epsilon_valid:
295 |
296 | divisor = 0
297 | rhs = 0
298 | raw_probabilities = {}
299 | for name, mask in self.masks.items():
300 | if name == "edge_weight_train": continue
301 | if name == "x_weight": continue
302 | n_param = np.prod(mask.shape)
303 | # if name == "edge_weight_train":
304 | # n_zeros = n_param * (1 - density_adj)
305 | # n_ones = n_param * density_adj
306 | # elif name == "x_weight":
307 | # n_zeros = n_param * (1 - density_feature)
308 | # n_ones = n_param * density_feature
309 | # else:
310 | n_zeros = n_param * (1 - density)
311 | n_ones = n_param * density
312 |
313 |
314 | if name in dense_layers:
315 | # See `- default_sparsity * (N_3 + N_4)` part of the equation above.
316 | rhs -= n_zeros
317 |
318 | else:
319 | # Corresponds to `(1 - default_sparsity) * (N_1 + N_2)` part of the
320 | # equation above.
321 | rhs += n_ones
322 |                     # Erdos-Renyi probability: epsilon * (n_in + n_out) / (n_in * n_out).
323 | raw_probabilities[name] = (
324 | np.sum(mask.shape) / np.prod(mask.shape)
325 | ) ** erk_power_scale
326 |                     # Note that raw_probabilities[name] * n_param gives the individual
327 | # elements of the divisor.
328 | divisor += raw_probabilities[name] * n_param
329 |                     # By multiplying individual probabilities with epsilon, we should get the
330 | # number of parameters per layer correctly.
331 | epsilon = rhs / divisor
332 |                 # If epsilon * raw_probabilities[name] > 1, we set the sparsity of that
333 |                 # mask to 0 (i.e. keep it dense), so it becomes part of the dense_layers set.
334 | max_prob = np.max(list(raw_probabilities.values()))
335 | max_prob_one = max_prob * epsilon
336 | if max_prob_one > 1:
337 | is_epsilon_valid = False
338 | for mask_name, mask_raw_prob in raw_probabilities.items():
339 | if mask_raw_prob == max_prob:
340 | print(f"Sparsity of var:{mask_name} had to be set to 0.")
341 | dense_layers.add(mask_name)
342 | else:
343 | is_epsilon_valid = True
344 |
345 | density_dict = {}
346 | total_nonzero = 0.0
347 |             # With the valid epsilon, we can set the sparsities of the remaining layers.
348 | for name, mask in self.masks.items():
349 | if name == "edge_weight_train": continue
350 | if name == "x_weight": continue
351 | n_param = np.prod(mask.shape)
352 | if name in dense_layers:
353 | density_dict[name] = 1.0
354 | else:
355 | probability_one = epsilon * raw_probabilities[name]
356 | density_dict[name] = probability_one
357 | print(
358 | f"layer: {name}, shape: {mask.shape}, density: {density_dict[name]}"
359 | )
360 | self.masks[name][:] = (torch.rand(mask.shape) < density_dict[name]).float().data.to(self.device)
361 |
362 | total_nonzero += density_dict[name] * mask.numel()
363 |             print(f"Overall density {total_nonzero / self.total_params}")
364 |
365 | self.apply_mask()
366 |
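        # For intuition, ERK assigns each layer a density proportional to
        # (n_in + n_out) / (n_in * n_out), rescaled by a global epsilon so that the overall
        # density matches the requested value. As a rough illustration (numbers made up):
        # a 1433x64 linear layer gets raw probability (1433 + 64) / (1433 * 64) ~= 0.0163,
        # so larger layers end up relatively sparser than small ones.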
367 |
368 | total_size = 0
369 | for name, weight in self.masks.items():
370 | total_size += weight.numel()
371 |
372 | sparse_size = 0
373 | for name, weight in self.masks.items():
374 | sparse_size += (weight != 0).sum().int().item()
375 |         print('Overall density at target density {0}: {1}'.format(density, sparse_size / total_size))
376 |
377 | def step(self):
378 | self.optimizer.step()
379 | self.apply_mask()
380 | self.prune_rate_decay.step()
381 | self.prune_rate = self.prune_rate_decay.get_dr()
382 | self.steps += 1
383 |
384 | if self.prune_every_k_steps is not None:
385 | if self.args.method == 'GraNet':
386 | if self.steps >= (self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) and self.steps % self.prune_every_k_steps == 0:
387 | self.pruning(self.steps)
388 | self.truncate_weights(self.steps)
389 | self.print_nonzero_counts()
390 | elif self.args.method == 'GraNet_uniform':
391 | if self.steps >= (self.args.init_prune_epoch * len(self.loader)* self.args.multiplier) and self.steps % self.prune_every_k_steps == 0:
392 | self.pruning_uniform(self.steps)
393 | self.truncate_weights(self.steps)
394 | self.print_nonzero_counts()
395 | # _, _ = self.fired_masks_update()
396 | elif self.args.method == 'DST':
397 | if self.steps % self.prune_every_k_steps == 0:
398 |                     self.truncate_weights(self.steps)
399 | self.print_nonzero_counts()
400 | elif self.args.method == 'GMP':
401 | if self.steps >= (self.args.init_prune_epoch * len(self.loader) * self.args.multiplier) and self.steps % self.prune_every_k_steps == 0:
402 | self.pruning(self.steps)
403 | elif self.args.method == 'GMP_uniform':
404 | if self.steps >= (self.args.init_prune_epoch * len(self.loader) * self.args.multiplier) and self.steps % self.prune_every_k_steps == 0:
405 | self.pruning_uniform(self.steps)
406 |
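    # step() above dispatches on args.method: GraNet / GraNet_uniform first prune gradually
    # (pruning / pruning_uniform) and then prune-and-regrow connections via truncate_weights;
    # GMP / GMP_uniform only prune gradually; DST only prunes-and-regrows at a fixed sparsity.
    # GraNet and GMP variants fire every `update_frequency` steps once `init_prune_epoch` is
    # reached; DST fires every `update_frequency` steps throughout training.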
407 |
408 | def pruning(self, step):
409 | # prune_rate = 1 - self.args.final_density - self.args.init_density
410 | curr_prune_iter = int(step / self.prune_every_k_steps)
411 | final_iter = int((self.args.final_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps)
412 | ini_iter = int((self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps)
413 | total_prune_iter = final_iter - ini_iter
414 |
415 |
416 |
417 | if curr_prune_iter >= ini_iter and curr_prune_iter <= final_iter - 1:
418 | print('******************************************************')
419 | print(f'Pruning Progress is {curr_prune_iter - ini_iter} / {total_prune_iter}')
420 | print('******************************************************')
421 | print("Pruning Start!!")
422 | prune_decay = (1 - ((curr_prune_iter - ini_iter) / total_prune_iter)) ** 3
423 | curr_prune_rate = (1 - self.args.init_density) + (self.args.init_density - self.args.final_density) * (
424 | 1 - prune_decay)
425 |
426 | curr_prune_rate_adj = (1 - self.args.init_density_adj) + (self.args.init_density_adj - self.args.final_density_adj) * (
427 | 1 - prune_decay)
428 |
429 | curr_prune_rate_feature = (1 - self.args.init_density_feature) + (self.args.init_density_feature - self.args.final_density_feature) * (
430 | 1 - prune_decay)
431 |
432 | weight_abs = []
433 | adj_abs =[]
434 | feature_abs =[]
435 | for module in self.modules:
436 | for name, weight in module.named_parameters():
437 | if name not in self.masks: continue
438 |
439 | if name == "edge_weight_train":
440 | adj_abs.append(torch.abs(weight))
441 | elif name == "x_weight":
442 | feature_abs.append(torch.abs(weight))
443 | else:
444 | weight_abs.append(torch.abs(weight))
445 |
446 | # Gather all scores in a single vector and normalise
447 | if self.args.weight_sparse:
448 | all_scores = torch.cat([torch.flatten(x) for x in weight_abs])
449 | num_params_to_keep = int(len(all_scores) * (1 - curr_prune_rate))
450 |
451 | threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
452 | acceptable_score_weight = threshold[-1]
453 |
454 | # Gather adj scores
455 | if self.args.adj_sparse:
456 | all_scores = torch.cat([torch.flatten(x) for x in adj_abs])
457 | num_params_to_keep = int(len(all_scores) * (1 - curr_prune_rate_adj))
458 |
459 | threshold_adj, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
460 | acceptable_score_adj = threshold_adj[-1]
461 |
462 | # Gather adj scores
463 |
464 | if self.args.feature_sparse:
465 | all_scores = torch.cat([torch.flatten(x) for x in feature_abs])
466 | num_params_to_keep = int(len(all_scores) * (1 - curr_prune_rate_feature))
467 |
468 | threshold_feature, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
469 | acceptable_score_feature = threshold_feature[-1]
470 |
471 |
472 | for module in self.modules:
473 | for name, weight in module.named_parameters():
474 | if name not in self.masks: continue
475 |
476 | if self.args.adj_sparse:
477 | if name == "edge_weight_train":
478 | self.masks[name] = ((torch.abs(weight)) > acceptable_score_adj).float()
479 | print("Add Sparse Mask --- Graph Adj: {} !".format(name))
480 |
481 | if self.args.feature_sparse:
482 | if name == "x_weight":
483 | self.masks[name] = ((torch.abs(weight)) > acceptable_score_feature).float()
484 | print("Add Sparse Mask --- Graph Feature: {} !".format(name))
485 |
486 | if self.args.weight_sparse:
487 | if len(weight.size()) == 4 or len(weight.size()) == 2:
488 | self.masks[name] = ((torch.abs(weight)) > acceptable_score_weight).float()
489 |                         # must be strictly '>' so that a zero acceptable_score does not keep every weight (which would yield dense tensors)
490 | print("Add Sparse Mask --- Model Weight: {} !".format(name))
491 | print("="*40)
492 | self.apply_mask()
493 |
494 | weight_total_size = 1
495 | adj_total_size = 1
496 | feature_total_size = 1
497 |
498 | for name, weight in self.masks.items():
499 | if name == "edge_weight_train":
500 | adj_total_size += weight.numel()
501 | elif name == "x_weight":
502 | feature_total_size += weight.numel()
503 | else:
504 | weight_total_size += weight.numel()
505 |
506 | print('Total Model parameters:{}, Graph Edge Numbers:{}, Feature Channels:{}'.format(weight_total_size,adj_total_size,feature_total_size))
507 |
508 | weight_sparse_size = 0
509 | adj_sparse_size = 0
510 | feature_sparse_size = 0
511 |
512 | for name, weight in self.masks.items():
513 |
514 | if name == "edge_weight_train":
515 | adj_sparse_size += (weight != 0).sum().int().item()
516 | elif name == "x_weight":
517 | feature_sparse_size += (weight != 0).sum().int().item()
518 | else:
519 | weight_sparse_size += (weight != 0).sum().int().item()
520 |
521 |             print('Model Parameters Sparsity after pruning: {} \nGraph Edge Sparsity after pruning: {} \nFeature Channels Sparsity after pruning: {}'.format(
522 | (weight_total_size-weight_sparse_size) / weight_total_size,
523 | (adj_total_size-adj_sparse_size) / adj_total_size,
524 | (feature_total_size-feature_sparse_size) / feature_total_size))
525 | print("="*40)
526 |
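    # The schedule above is the cubic gradual-pruning rule: with progress p in [0, 1],
    # prune_decay = (1 - p)**3 and
    # curr_prune_rate = (1 - init_density) + (init_density - final_density) * (1 - prune_decay).
    # For example (illustrative numbers), with init_density = 0.5 and final_density = 0.05,
    # halfway through the schedule (p = 0.5): prune_decay = 0.125 and
    # curr_prune_rate = 0.5 + 0.45 * 0.875 = 0.89375, i.e. roughly 89% of weights are pruned.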
527 | def pruning_uniform(self, step):
528 | # prune_rate = 1 - self.args.final_density - self.args.init_density
529 | curr_prune_iter = int(step / self.prune_every_k_steps)
530 | final_iter = (self.args.final_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps
531 | ini_iter = (self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps
532 | total_prune_iter = final_iter - ini_iter
533 |
534 |
535 | if curr_prune_iter >= ini_iter and curr_prune_iter <= final_iter:
536 | print('******************************************************')
537 | print(f'Pruning Progress is {curr_prune_iter - ini_iter} / {total_prune_iter}')
538 | print('******************************************************')
539 |
540 | prune_decay = (1 - ((curr_prune_iter - ini_iter) / total_prune_iter)) ** 3
541 | curr_prune_rate = (1 - self.args.init_density) + (self.args.init_density - self.args.final_density) * (
542 | 1 - prune_decay)
543 |
544 | curr_prune_rate_adj = (1 - self.args.init_density_adj) + (self.args.init_density_adj - self.args.final_density_adj) * (
545 | 1 - prune_decay)
546 |
547 | curr_prune_rate_feature = (1 - self.args.init_density_feature) + (self.args.init_density_feature - self.args.final_density_feature) * (
548 | 1 - prune_decay)
549 |
550 |             # keep the density of the last layer at 0.2 if sparsity is larger than 0.8
551 | # if curr_prune_rate >= 0.8:
552 | # curr_prune_rate = 1 - (self.total_params * (1-curr_prune_rate) - 0.2 * self.fc_params)/(self.total_params-self.fc_params)
553 |
554 | # for module in self.modules:
555 | # for name, weight in module.named_parameters():
556 | # if name not in self.masks: continue
557 | # score = torch.flatten(torch.abs(weight))
558 | # if 'classifier' in name:
559 | # num_params_to_keep = int(len(score) * 0.2)
560 | # threshold, _ = torch.topk(score, num_params_to_keep, sorted=True)
561 | # acceptable_score = threshold[-1]
562 | # self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float()
563 | # else:
564 | # num_params_to_keep = int(len(score) * (1 - curr_prune_rate))
565 | # threshold, _ = torch.topk(score, num_params_to_keep, sorted=True)
566 | # acceptable_score = threshold[-1]
567 | # self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float()
568 |
569 | for module in self.modules:
570 | for name, weight in module.named_parameters():
571 | if name not in self.masks: continue
572 |
573 | score = torch.flatten(torch.abs(weight))
574 |
575 | if name == "edge_weight_train":
576 | num_params_to_keep = int(len(score) * (1 - curr_prune_rate_adj))
577 | print("Add Sparse Mask --- Graph Adj: {} !".format(name))
578 | elif name == "x_weight":
579 | num_params_to_keep = int(len(score) * (1 - curr_prune_rate_feature))
580 | print("Add Sparse Mask --- Graph Feature: {} !".format(name))
581 | else:
582 | num_params_to_keep = int(len(score) * (1 - curr_prune_rate))
583 | print("Add Sparse Mask --- Model Weight: {} !".format(name))
584 |
585 | threshold, _ = torch.topk(score, num_params_to_keep, sorted=True)
586 | acceptable_score = threshold[-1]
587 | self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float()
588 |
589 |
590 | self.apply_mask()
591 |
592 | weight_total_size = 1
593 | adj_total_size = 1
594 | feature_total_size = 1
595 |
596 | for name, weight in self.masks.items():
597 | if name == "edge_weight_train":
598 | adj_total_size += weight.numel()
599 | elif name == "x_weight":
600 | feature_total_size += weight.numel()
601 | else:
602 | weight_total_size += weight.numel()
603 |
604 | print('Total Model parameters:{}, Graph Edge Numbers:{}, Feature Channels:{}'.format(weight_total_size,adj_total_size,feature_total_size))
605 |
606 | weight_sparse_size = 0
607 | adj_sparse_size = 0
608 | feature_sparse_size = 0
609 |
610 | for name, weight in self.masks.items():
611 |
612 | if name == "edge_weight_train":
613 | adj_sparse_size += (weight != 0).sum().int().item()
614 | elif name == "x_weight":
615 | feature_sparse_size += (weight != 0).sum().int().item()
616 | else:
617 | weight_sparse_size += (weight != 0).sum().int().item()
618 |
619 |             print('Model Parameters Sparsity after pruning: {} \nGraph Edge Sparsity after pruning: {} \nFeature Channels Sparsity after pruning: {}'.format(
620 | (weight_total_size-weight_sparse_size) / weight_total_size,
621 | (adj_total_size-adj_sparse_size) / adj_total_size,
622 | (feature_total_size-feature_sparse_size) / feature_total_size))
623 | print("="*40)
624 |
625 | # total_size = 0
626 | # for name, weight in self.masks.items():
627 | # total_size += weight.numel()
628 | # print('Total Model parameters:', total_size)
629 |
630 | # sparse_size = 0
631 | # for name, weight in self.masks.items():
632 | # sparse_size += (weight != 0).sum().int().item()
633 |
634 | # print('Sparsity after pruning: {0}'.format(
635 | # (total_size-sparse_size) / total_size))
636 |
637 |
638 | def add_module(self, module, sparse_init='ERK', grad_dic=None):
639 | self.module = module
640 |         self.sparse_init = self.sparse_init  # no-op: self.sparse_init was already set from args in __init__; the sparse_init argument here is unused
641 | self.modules.append(module)
642 | for name, tensor in module.named_parameters():
643 |
644 | if self.args.adj_sparse:
645 | if name == "edge_weight_train":
646 | self.names.append(name)
647 | self.masks[name] = torch.ones_like(tensor, dtype=torch.float32, requires_grad=False).to(self.device)
648 | print("Add Sparse Module --- Graph Adj:{} Sparse Module!".format(name))
649 |
650 | if self.args.feature_sparse:
651 | if name == "x_weight":
652 | self.names.append(name)
653 | self.masks[name] = torch.ones_like(tensor, dtype=torch.float32, requires_grad=False).to(self.device)
654 | print("Add Sparse Module --- Graph Feature: {} !".format(name))
655 |
656 | if self.args.weight_sparse:
657 | if len(tensor.size()) == 4 or len(tensor.size()) == 2:
658 | self.names.append(name)
659 | self.masks[name] = torch.ones_like(tensor, dtype=torch.float32, requires_grad=False).to(self.device)
660 | print("Add Sparse Module --- Model Weight: {} !".format(name))
661 |
662 |
663 | print("Add Module Done!")
664 | print("="*40)
665 |
666 | if self.args.rm_first:
667 | for name, tensor in module.named_parameters():
668 | if 'conv.weight' in name or 'feature.0.weight' in name:
669 | self.masks.pop(name)
670 | print(f"pop out {name}")
671 |
672 | self.init( mode=self.args.sparse_init,
673 | density=self.args.init_density,
674 | density_adj =self.args.init_density_adj,
675 | density_feature= self.args.init_density_feature,
676 | grad_dict=grad_dic) # init weight
677 |
678 |
679 | def remove_weight(self, name):
680 | if name in self.masks:
681 | print('Removing {0} of size {1} = {2} parameters.'.format(name, self.masks[name].shape,
682 | self.masks[name].numel()))
683 | self.masks.pop(name)
684 | elif name + '.weight' in self.masks:
685 | print('Removing {0} of size {1} = {2} parameters.'.format(name, self.masks[name + '.weight'].shape,
686 | self.masks[name + '.weight'].numel()))
687 | self.masks.pop(name + '.weight')
688 | else:
689 | print('ERROR', name)
690 |
691 | def remove_weight_partial_name(self, partial_name):
692 | removed = set()
693 | for name in list(self.masks.keys()):
694 | if partial_name in name:
695 |
696 | print('Removing {0} of size {1} with {2} parameters...'.format(name, self.masks[name].shape,
697 | np.prod(self.masks[name].shape)))
698 | removed.add(name)
699 | self.masks.pop(name)
700 |
701 | print('Removed {0} layers.'.format(len(removed)))
702 |
703 | i = 0
704 | while i < len(self.names):
705 | name = self.names[i]
706 | if name in removed:
707 | self.names.pop(i)
708 | else:
709 | i += 1
710 |
711 | def remove_type(self, nn_type):
712 | for module in self.modules:
713 | for name, module in module.named_modules():
714 | if isinstance(module, nn_type):
715 | self.remove_weight(name)
716 |
717 | def apply_mask(self):
718 | for module in self.modules:
719 | for name, tensor in module.named_parameters():
720 | if name in self.masks:
721 | tensor.data = tensor.data*self.masks[name]
722 | #print("Trying to Apply Mask on {}".format(name))
723 | if 'momentum_buffer' in self.optimizer.state[tensor]:
724 | self.optimizer.state[tensor]['momentum_buffer'] = self.optimizer.state[tensor]['momentum_buffer']*self.masks[name]
725 |
726 |
727 | def truncate_weights(self, step=None):
728 |
729 | curr_prune_iter = int(step / self.prune_every_k_steps)
730 | final_iter = int((self.args.final_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps)
731 | ini_iter = int((self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps)
732 | total_prune_iter = final_iter - ini_iter
733 |
734 | if curr_prune_iter >= ini_iter and curr_prune_iter <= final_iter - 1:
735 | print('******************************************************')
736 | print(f'Death and Growth Progress is {curr_prune_iter - ini_iter} / {total_prune_iter}')
737 | print('******************************************************')
738 |
739 | self.gather_statistics()
740 |
741 | # prune
742 | for module in self.modules:
743 | for name, weight in module.named_parameters():
744 | if name not in self.masks: continue
745 | mask = self.masks[name]
746 |
747 | new_mask = self.magnitude_death(mask, weight, name)
748 | self.pruning_rate[name] = int(self.name2nonzeros[name] - new_mask.sum().item())
749 | self.masks[name][:] = new_mask
750 |
751 | # grow
752 | for module in self.modules:
753 | for name, weight in module.named_parameters():
754 | if name not in self.masks: continue
755 | new_mask = self.masks[name].data.byte()
756 |
757 | if self.args.growth_schedule == "gradient":
758 | new_mask = self.gradient_growth(name, new_mask, self.pruning_rate[name], weight)
759 | elif self.args.growth_schedule == "momentum":
760 | new_mask = self.momentum_growth(name, new_mask, self.pruning_rate[name], weight)
761 | elif self.args.growth_schedule == "random":
762 | new_mask = self.random_growth(name, new_mask, self.pruning_rate[name], weight)
763 | # exchanging masks
764 | self.masks.pop(name)
765 | self.masks[name] = new_mask.float()
766 |
767 |
768 | self.apply_mask()
769 |
770 |
771 | '''
772 | REDISTRIBUTION
773 | '''
774 |
775 | def gather_statistics(self):
776 | self.name2nonzeros = {}
777 | self.name2zeros = {}
778 |
779 | for module in self.modules:
780 | for name, tensor in module.named_parameters():
781 | if name not in self.masks: continue
782 | mask = self.masks[name]
783 |
784 | self.name2nonzeros[name] = mask.sum().item()
785 | self.name2zeros[name] = mask.numel() - self.name2nonzeros[name]
786 |
787 | ############################ DEATH ###########################
788 |
789 | def magnitude_death(self, mask, weight, name):
790 | num_remove = math.ceil(self.prune_rate*self.name2nonzeros[name])
791 | if num_remove == 0.0: return weight.data != 0.0
792 | num_zeros = self.name2zeros[name]
793 | k = math.ceil(num_zeros + num_remove)
794 | x, idx = torch.sort(torch.abs(weight.data.view(-1)))
795 | threshold = x[k-1].item()
796 |
797 | return (torch.abs(weight.data) > threshold)
798 |
799 |
800 | ########################### GROWTH ###########################
801 |
802 | def random_growth(self, name, new_mask, total_regrowth, weight):
803 | n = (new_mask==0).sum().item()
804 | if n == 0: return new_mask
805 |         expected_growth_probability = (total_regrowth/n)
806 |         new_weights = torch.rand(new_mask.shape).to(self.device) < expected_growth_probability #lsw
807 | # new_weights = torch.rand(new_mask.shape) < expeced_growth_probability
808 | new_mask_ = new_mask.byte() | new_weights
809 | if (new_mask_!=0).sum().item() == 0:
810 | new_mask_ = new_mask
811 | return new_mask_
812 |
813 | def momentum_growth(self, name, new_mask, total_regrowth, weight):
814 | grad = self.get_momentum_for_weight(weight)
815 | grad = grad*(new_mask==0).float()
816 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
817 | new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0
818 |
819 | return new_mask
820 |
821 |
822 | def gradient_growth(self, name, new_mask, total_regrowth, weight):
823 | grad = self.get_gradient_for_weights(weight)
824 | grad = grad*(new_mask==0).float()
825 |
826 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
827 | new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0
828 |
829 | return new_mask
830 |
831 |
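    # Taken together, magnitude_death and the *_growth functions implement one
    # prune-and-regrow step: magnitude_death removes the prune_rate fraction of the
    # smallest-magnitude surviving weights in each layer, and the growth function
    # reactivates the same number of currently-masked weights (by gradient magnitude,
    # momentum, or at random), so each layer's nonzero count stays constant within
    # truncate_weights.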
832 |
833 | '''
834 | UTILITY
835 | '''
836 | def get_momentum_for_weight(self, weight):
837 | if 'exp_avg' in self.optimizer.state[weight]:
838 | adam_m1 = self.optimizer.state[weight]['exp_avg']
839 | adam_m2 = self.optimizer.state[weight]['exp_avg_sq']
840 | grad = adam_m1/(torch.sqrt(adam_m2) + 1e-08)
841 | elif 'momentum_buffer' in self.optimizer.state[weight]:
842 | grad = self.optimizer.state[weight]['momentum_buffer']
843 | return grad
844 |
845 | def get_gradient_for_weights(self, weight):
846 | grad = weight.grad.clone()
847 | return grad
848 |
849 | def print_nonzero_counts(self):
850 | for module in self.modules:
851 | for name, tensor in module.named_parameters():
852 | if name not in self.masks: continue
853 | mask = self.masks[name]
854 | num_nonzeros = (mask != 0).sum().item()
855 | val = '{0}: {1}->{2}, density: {3:.3f}'.format(name, self.name2nonzeros[name], num_nonzeros, num_nonzeros/float(mask.numel()))
856 | print(val)
857 |
858 | print('Death rate: {0}\n'.format(self.prune_rate))
859 | print("="*40)
860 |
861 | def reset_momentum(self):
862 | """
863 | Taken from: https://github.com/AlliedToasters/synapses/blob/master/synapses/SET_layer.py
864 | Resets buffers from memory according to passed indices.
865 | When connections are reset, parameters should be treated
866 | as freshly initialized.
867 | """
868 | for module in self.modules:
869 | for name, tensor in module.named_parameters():
870 | if name not in self.masks: continue
871 | mask = self.masks[name]
872 | weights = list(self.optimizer.state[tensor])
873 | for w in weights:
874 | if w == 'momentum_buffer':
875 | # momentum
876 | if self.args.reset_mom_zero:
877 | print('zero')
878 | self.optimizer.state[tensor][w][mask == 0] = 0
879 | else:
880 | print('mean')
881 | self.optimizer.state[tensor][w][mask==0] = torch.mean(self.optimizer.state[tensor][w][mask.byte()])
882 | # self.optimizer.state[tensor][w][mask==0] = 0
883 | elif w == 'square_avg' or \
884 | w == 'exp_avg' or \
885 | w == 'exp_avg_sq' or \
886 | w == 'exp_inf':
887 | # Adam
888 | self.optimizer.state[tensor][w][mask==0] = torch.mean(self.optimizer.state[tensor][w][mask.byte()])
889 |
890 | def fired_masks_update(self):
891 | ntotal_fired_weights = 0.0
892 | ntotal_weights = 0.0
893 | layer_fired_weights = {}
894 | for module in self.modules:
895 | for name, weight in module.named_parameters():
896 | if name not in self.masks: continue
897 | self.fired_masks[name] = self.masks[name].data.byte() | self.fired_masks[name].data.byte()
898 | ntotal_fired_weights += float(self.fired_masks[name].sum().item())
899 | ntotal_weights += float(self.fired_masks[name].numel())
900 | layer_fired_weights[name] = float(self.fired_masks[name].sum().item())/float(self.fired_masks[name].numel())
901 | print('Layerwise percentage of the fired weights of', name, 'is:', layer_fired_weights[name])
902 | total_fired_weights = ntotal_fired_weights/ntotal_weights
903 | print('The percentage of the total fired weights is:', total_fired_weights)
904 | return layer_fired_weights, total_fired_weights
905 |
--------------------------------------------------------------------------------