├── datasets
│   └── .gitkeep
├── .gitignore
├── environment.yml
├── main.py
├── argument.py
├── README.md
├── data.py
├── experiment.py
└── kcn.py
/datasets/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .*.swp
2 | .*.swo
3 | *.out
4 | *.err
5 | *.npz
6 | *.slurm
7 | __pycache__
8 | .ipynb_checkpoints/
9 |
10 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: kcn_env
2 |
3 | dependencies:
4 | - python=3.10
5 | - numpy
6 | - pytorch::pytorch=2.0
7 | - pyg::pyg
8 | - conda-forge::scikit-learn
9 | - conda-forge::tqdm
10 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from argument import parse_opt
4 | from experiment import run_kcn
5 |
6 |
# repeat the experiment in the paper
def random_runs(args, num_runs=10):
    """Run the KCN experiment once per random seed and collect the test errors.

    Seeds ``0 .. num_runs - 1`` are used in turn. Each seed is applied to
    numpy and torch, and is also stored as ``random_seed`` on the namespace
    passed to ``run_kcn`` (which uses it, e.g., for the train/test split).

    Args:
        args: argparse.Namespace from ``parse_opt``. It is not modified; every
            run works on a shallow copy whose ``random_seed`` is overwritten.
        num_runs: number of seeds to try (default 10, as in the original loop).

    Returns:
        np.ndarray of shape ``(num_runs,)`` holding the test error of each run.
    """
    import copy

    test_errors = []
    for seed in range(num_runs):
        # work on a copy so the caller's args.random_seed is left untouched
        run_args = copy.copy(args)
        run_args.random_seed = seed

        np.random.seed(seed)
        torch.manual_seed(seed)

        test_errors.append(run_kcn(run_args))

    return np.array(test_errors)
19 |
20 |
21 |
if __name__ == "__main__":

    args = parse_opt()
    print(args)

    # seed numpy and torch with the requested --random_seed before the single run
    seed = args.random_seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # train and evaluate one model on one train-test split
    test_error = run_kcn(args)
    summary = 'Model: {}, test error: {}\n'.format(args.model, test_error)
    print(summary)

    ## to run all experiments on one dataset (10 seeds per model), use instead:
    #model_error = dict()
    #for args.model in ["kcn", "kcn_gat", "kcn_sage"]:
    #    test_errors = random_runs(args)
    #    model_error[args.model] = (np.mean(test_errors), np.std(test_errors))
    #print(model_error)
42 |
43 |
44 |
--------------------------------------------------------------------------------
/argument.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` cannot be used for this: ``bool("False")`` is True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _length_scale(value):
    """Parse --length_scale: either the literal 'auto' or a positive float."""
    if value == "auto":
        return value
    try:
        scale = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"length scale must be 'auto' or a positive float, got {value!r}") from None
    if not scale > 0:  # also rejects NaN
        raise argparse.ArgumentTypeError(f"length scale must be positive, got {value!r}")
    return scale


def parse_opt(argv=None):
    """Parse the experiment and model settings.

    Args:
        argv: list of argument strings to parse; ``None`` (default) parses
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with all settings. ``device`` is converted to a
        ``torch.device`` ("auto" picks CUDA when available, otherwise CPU).

    Unrecognized arguments are ignored rather than rejected (``parse_known_args``);
    presumably so the function also works where extra flags are injected,
    e.g. a Jupyter kernel — NOTE(review): a mistyped option is silently dropped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--random_seed', type=int, default=5, help="The random seed")
    parser.add_argument('--dataset', type=str, default="bird_count", help="The dataset name: currently can only be 'bird_count'")
    parser.add_argument('--data_path', type=str, default="./datasets", help="The folder containing the data file. The data file is '{data_path}/{dataset}.npz'")
    parser.add_argument('--use_default_test_set', type=_str2bool, default=False, help='Use the default test set from the data (true/false)')

    parser.add_argument('--model', type=str, default='kcn', choices=['kcn', 'kcn_gat', 'kcn_sage'], help='One of three model types, kcn, kcn_gat, kcn_sage, which use GCN, GAT, and GraphSAGE respectively')
    parser.add_argument('--n_neighbors', type=int, default=5, help='Number of neighbors')
    parser.add_argument('--length_scale', type=_length_scale, default="auto", help='Length scale for RBF kernel: a positive float, or "auto" to use the median of neighbor distances')
    parser.add_argument('--hidden_sizes', type=int, nargs='+', default=[8, 8, 8], help='Number of units in hidden layers, also decide the number of layers (e.g. --hidden_sizes 16 16)')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate (1 - keep probability).')
    parser.add_argument('--last_activation', type=str, default='none', choices=['none', 'relu', 'sigmoid', 'tanh', 'softplus'], help='Activation for the last layer')

    parser.add_argument('--loss_type', type=str, default='squared_error', help='Loss type')
    parser.add_argument('--validation_size', type=int, default=5000, help='Validation size')

    parser.add_argument('--lr', type=float, default=5e-3, help='Learning rate.')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay for the optimizer.')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs.')
    parser.add_argument('--es_patience', type=int, default=20, help='Patience for early stopping.')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size')

    parser.add_argument('--device', type=str, default="auto", help='Computation device.')

    args, _unknowns = parser.parse_known_args(argv)

    if args.device == "auto":
        args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        args.device = torch.device(args.device)

    return args
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A PyTorch Implementation of Kriging Convolutional Networks
2 |
3 | **Wenrui Zhao<sup>1</sup>, Ruiqi Xu<sup>1</sup>, and Li-Ping Liu<sup>1</sup>**
4 |
5 | <sup>1</sup> Department of Computer Science, Tufts University
6 |
7 | ### Overview
8 |
9 | This repo contains the PyTorch implementation of Kriging Convolutional Networks (KCNs) [1]. The original code release was an implementation with TensorFlow 1.x, which is not well supported now. In this repo, we update the implementation with PyTorch and PyTorch Geometric. We hope this new implementation will help researchers in this field. If you have any questions about this repo, please feel free to raise issues.
10 |
11 | ### Data
12 |
13 | The model works with datasets in the `SpatialDataset` format. In particular, it contains three fields:
14 |
15 | * `dataset.coords`: a tensor with shape `[n, 2]`, each row is a 2-D coordinate of an instance. Coordinates with other dimensions can also be handled by KCN.
16 | * `dataset.features`: a tensor with shape `[n, d]`, each row is a `d`-dimensional feature vector of an instance.
17 | * `dataset.y`: a tensor with shape `[n, l]`, each row is a vector of `l` labels of an instance. The current example only works for one-dimensional continuous labels.
18 |
19 | The current repo provides a running example of KCN on a dataset of bird counts (counts of wood thrush reported in the Eastern US during June 2014) [[download link](https://tufts.box.com/v/kcn-bird-count-dataset)].
20 |
21 | ### Model
22 |
23 | A KCN predicts a data point's label based on the data points in its neighborhood. The KCN model stores a training set internally. To make a prediction for a data point, it looks up the point's nearest neighbors in the training set, forms an attributed graph over the point and its neighbors, and then uses a Graph Neural Network (GNN) to predict the point's label. The graphs for training points are computed once, before training, to avoid repeated graph construction. The general structure of KCN is similar to that of a k-nearest-neighbor model, though KCN employs a much more flexible predictive function than simple averaging.
24 |
25 | In the implementation, a KCN model is a PyTorch module. It is initialized with a `SpatialDataset`. In the `forward` function, it takes coordinates and features of a batch of data points and then predicts their labels.
26 |
27 | ### Run the code
28 |
29 | #### Requirements
30 | The code has been tested on a Linux platform with the following packages installed: `python=3.10, numpy=1.24.3, torch=2.0.1, scikit-learn=1.2.2, pyg=2.3.0`. The training loop in `experiment.py` also uses `tqdm` for progress bars. You can install the environment from `environment.yml` with `conda env create -f environment.yml`.
31 |
32 | If you want to run the Jupyter notebook `demo.ipynb`, you also need to install `geoplot` and `geopandas`.
33 |
34 | You can try the KCN model on a single train-test split by running `python main.py`. If you want to change experiment settings or model parameters, you can pass more arguments on the command line as defined in `argument.py`, or you can directly edit the default values of the arguments in `argument.py`.
35 |
36 | ### Reference
37 | [1] Gabriel Appleby, Linfeng Liu, and Li-Ping Liu. "Kriging convolutional networks." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 34. No. 04. 2020.
38 | [2] Sullivan, B.L., C.L. Wood, M.J. Iliff, R.E. Bonney, D. Fink, and S. Kelling. 2009. eBird: a citizen-based bird observation network in the biological sciences. Biological Conservation 142: 2282-2292.
39 |
40 |
41 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
class SpatialDataset(torch.utils.data.Dataset):
    """A dataset of spatially located instances: coordinates, features and labels.

    Indexing returns a ``(coords, features, y)`` tuple. Any index accepted by
    torch tensor indexing (an int, a slice, a list of indices) can be used,
    so a whole batch can be fetched with ``dataset[batch_indices]``.
    """

    def __init__(self, coords, features, y):
        """
        Args:
            coords: array-like with shape `(n, c)`, coordinates of `n` instances
                (the bird-count loader passes `c = 2`)
            features: array-like with shape `(n, d)`, `d` dimensional feature vectors of `n` instances
            y: array-like with shape `(n, l)`, labels of `n` instances; use `(n, 1)` for
                scalar labels. Please provide zeros if unknown.

        All three are converted with `torch.Tensor` (float32 under torch's default dtype).

        Raises:
            ValueError: if the three inputs have different numbers of instances.
        """
        super().__init__()

        n_coords, n_features, n_y = coords.shape[0], features.shape[0], y.shape[0]
        if not (n_coords == n_features == n_y):
            raise ValueError(
                "Coordinates, features, and labels have different numbers of instances: "
                f"coords.shape[0]={n_coords}, features.shape[0]={n_features}, y.shape[0]={n_y}"
            )

        self.coords = torch.Tensor(coords)
        self.features = torch.Tensor(features)
        self.y = torch.Tensor(y)

    def __len__(self):
        """Return the number of instances."""
        return self.coords.shape[0]

    def __getitem__(self, idx):
        """Return the ``(coords, features, y)`` entries selected by ``idx``."""
        return self.coords[idx], self.features[idx], self.y[idx]
38 |
39 |
40 |
41 |
def load_bird_count_data(args):
    """
    Load the bird-count data and build normalized training and test datasets.

    Args
    ----
    args : uses four fields, args.dataset, args.data_path, args.random_seed,
           args.use_default_test_set

    Returns
    -------
    trainset : SpatialDataset, training instances
    testset  : SpatialDataset, test instances

    The first two input columns are used as coordinates; all input columns
    (coordinates included) are used as features. Features of both sets are
    standardized with the training-set mean and standard deviation.

    If ``args.use_default_test_set`` is true, the train/test split stored in the
    file is used. Otherwise the stored training and test instances are pooled
    and re-split with a permutation seeded by ``args.random_seed``; the new
    training set has as many instances as the stored one.

    Raises
    ------
    FileNotFoundError : if ``{data_path}/{dataset}.npz`` does not exist
    ValueError : if inputs and labels in the file have different lengths
    """

    # data file path
    datafile = os.path.join(args.data_path, args.dataset + ".npz")

    if not os.path.isfile(datafile):
        raise FileNotFoundError(
            f"Data file {datafile} not found. Please download the dataset from "
            f"https://tufts.box.com/v/kcn-bird-count-dataset and save it to {datafile}")

    # np.load on an .npz returns an NpzFile that keeps the file open; close it after reading
    with np.load(datafile) as data:
        X_train = data['Xtrain'].astype(np.float32)
        Y_train = data['Ytrain'].astype(np.float32)[:, None]
        X_test = data['Xtest'].astype(np.float32)
        Y_test = data['Ytest'].astype(np.float32)[:, None]

    # explicit checks instead of asserts, which are stripped under `python -O`
    if X_train.shape[0] != Y_train.shape[0] or X_test.shape[0] != Y_test.shape[0]:
        raise ValueError(
            f"Inputs and labels have different lengths in {datafile}: "
            f"Xtrain={X_train.shape[0]}, Ytrain={Y_train.shape[0]}, "
            f"Xtest={X_test.shape[0]}, Ytest={Y_test.shape[0]}")

    num_total_train = X_train.shape[0]

    if args.use_default_test_set:
        print("Using the default test set from the data")
        train_X, train_Y, test_X, test_Y = X_train, Y_train, X_test, Y_test
    else:
        X = np.concatenate([X_train, X_test], axis=0)
        Y = np.concatenate([Y_train, Y_test], axis=0)

        perm = np.random.RandomState(seed=args.random_seed).permutation(X.shape[0])
        train_idx, test_idx = perm[:num_total_train], perm[num_total_train:]

        train_X, train_Y = X[train_idx], Y[train_idx]
        test_X, test_Y = X[test_idx], Y[test_idx]

    # coordinates are the first two columns and are also kept among the features
    trainset = SpatialDataset(coords=train_X[:, 0:2], features=train_X, y=train_Y)
    testset = SpatialDataset(coords=test_X[:, 0:2], features=test_X, y=test_Y)

    # feature normalization with training statistics; +0.01 avoids division by a zero std
    feature_mean = torch.mean(trainset.features, dim=0, keepdim=True)
    feature_std = torch.std(trainset.features, dim=0, keepdim=True)

    trainset.features = (trainset.features - feature_mean) / (feature_std + 0.01)
    testset.features = (testset.features - feature_mean) / (feature_std + 0.01)

    return trainset, testset
105 |
106 |
107 |
108 |
109 |
--------------------------------------------------------------------------------
/experiment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import kcn
4 | import data
5 | from tqdm import tqdm
6 |
def _evaluate(model, loss_func, coords, features, y, indices=None):
    """Return the loss of `model` on one batch with dropout off and no autograd graph.

    `indices`, if given, are passed to the model as training-set indices so it
    reuses its precomputed graphs. The model's previous train/eval mode is
    restored before returning.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = model(coords, features, indices)
            error = loss_func(pred, y.to(pred.device)).item()
    finally:
        model.train(was_training)
    return error


def run_kcn(args):
    """ Train and test a KCN model on a train-test split

    Args
    ----
    args : argparse.Namespace object (as built by argument.parse_opt) with:
        - 'dataset' : str, dataset name; only 'bird_count' is supported
        - 'data_path', 'random_seed', 'use_default_test_set' : read by the data loader
        - 'model' : str, one of 'kcn', 'kcn_gat', 'kcn_sage'
        - 'n_neighbors' : int, number of neighbors
        - 'length_scale' : 'auto' or float, length scale for the RBF kernel
        - 'hidden_sizes' : list of int, units of each GNN layer
        - 'dropout' : float, the dropout rate
        - 'last_activation' : str, activation for the last layer
        - 'validation_size' : int, number of training instances (the last ones) held out for validation
        - 'lr' : float, learning rate of the Adam optimizer
        - 'weight_decay' : float, weight decay for the Adam optimizer
        - 'epochs' : int, maximum number of training epochs
        - 'es_patience' : int, patience for early stopping
        - 'batch_size' : int, batch size
        - 'device' : torch.device used for training and prediction
      'loss_type' is not read here: training and evaluation always use the mean squared error.

    Returns
    -------
    test_error : float, mean squared error of the final model on the test set

    Raises
    ------
    ValueError : if validation_size does not leave both a non-empty validation set
                 and a non-empty training set

    Mini-batches are taken in index order every epoch (no shuffling). Training
    stops early once the mean of the last 3 validation errors exceeds the mean of
    the up to `es_patience` errors before them; the model from the last epoch
    (not the best-validation one) is then evaluated on the test set.
    """
    # 1) load the data; the loader also splits it into training and test subsets and normalizes it
    if args.dataset == "bird_count":
        trainset, testset = data.load_bird_count_data(args)
    else:
        raise Exception(f"The repo does not support this dataset yet: args.dataset={args.dataset}")

    print(f"The {args.dataset} dataset has {len(trainset)} training instances and {len(testset)} test instances.")

    num_total_train = len(trainset)
    num_train = num_total_train - args.validation_size
    if args.validation_size <= 0 or num_train <= 0:
        raise ValueError(
            f"validation_size must be between 1 and {num_total_train - 1} "
            f"(the training set has {num_total_train} instances): "
            f"args.validation_size={args.validation_size}")

    # initialize a kcn model
    # 1) the entire training set including validation points are recorded by the model and will
    #    be looked up in neighbor searches
    # 2) the model will pre-compute neighbors for a training or validation instance to avoid repeated neighbor search
    # 3) if a data point appears in training set and validation set, its neighbors does not include itself
    model = kcn.KCN(trainset, args).to(args.device)

    loss_func = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    epoch_train_error = []
    epoch_valid_error = []

    # the validation set is the tail of the training set; fetch it once
    valid_ind = range(num_train, num_total_train)
    valid_coords, valid_features, valid_y = model.trainset[valid_ind]

    # the training loop
    model.train()
    for epoch in range(args.epochs):

        batch_train_error = []

        # use training indices directly because they are used to retrieve pre-computed neighbors
        for i in tqdm(range(0, num_train, args.batch_size)):

            # fetch a batch of data
            batch_ind = range(i, min(i + args.batch_size, num_train))
            batch_coords, batch_features, batch_y = model.trainset[batch_ind]

            # make predictions and compute the average loss
            pred = model(batch_coords, batch_features, batch_ind)
            loss = loss_func(pred, batch_y.to(args.device))

            # update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_train_error.append(loss.item())

        train_error = sum(batch_train_error) / len(batch_train_error)
        epoch_train_error.append(train_error)

        # validation runs in eval mode (no dropout) and without building a graph
        valid_error = _evaluate(model, loss_func, valid_coords, valid_features, valid_y, valid_ind)
        epoch_valid_error.append(valid_error)

        print(f"Epoch: {epoch},", f"train error: {train_error},", f"validation error: {valid_error}")

        # check whether to stop
        if (epoch > args.es_patience) and \
                (np.mean(np.array(epoch_valid_error[-3:])) >
                 np.mean(np.array(epoch_valid_error[-(args.es_patience + 3):-3]))):
            print("\nEarly stopping at epoch {}".format(epoch))
            break

    # test the model on new instances (neighbors are searched in the stored training set)
    test_error = _evaluate(model, loss_func, testset.coords, testset.features, testset.y)

    print(f"Test error is {test_error}")

    return test_error
119 |
--------------------------------------------------------------------------------
/kcn.py:
--------------------------------------------------------------------------------
import numpy as np
import sklearn
import sklearn.metrics
import sklearn.neighbors
import torch
import torch_geometric
6 |
7 |
class KCN(torch.nn.Module):
    """Kriging Convolutional Network.

    The model stores the training set. To predict an instance it looks up the
    instance's `n_neighbors` nearest training instances by coordinates, builds
    an attributed graph over the instance and those neighbors, runs a GNN on
    that graph, and maps the center node's representation to a prediction.
    The graphs of all training instances are precomputed in `__init__`.
    """

    def __init__(self, trainset, args) -> None:
        """
        Args:
            trainset: dataset with `coords`, `features` and 2-D `y` tensors;
                `y.shape[1]` is used as the output dimension.
            args: namespace; reads `n_neighbors`, `length_scale` ('auto' or a
                positive number, numeric strings accepted), `last_activation`,
                `device`, and the fields read by `GNN` (`hidden_sizes`,
                `dropout`, `model`).

        Raises:
            ValueError: if the length scale is not positive or not a number,
                or if `last_activation` is unknown.
        """
        super(KCN, self).__init__()

        self.trainset = trainset

        # neighbor relationships within the training set; querying with None makes
        # sklearn exclude each training point from its own neighbor list
        self.n_neighbors = args.n_neighbors
        self.knn = sklearn.neighbors.NearestNeighbors(n_neighbors=self.n_neighbors).fit(self.trainset.coords)
        distances, self.train_neighbors = self.knn.kneighbors(None, return_distance=True)

        if args.length_scale == "auto":
            self.length_scale = np.median(distances)
            print(f"Length scale is set to {self.length_scale}")
        else:
            try:
                self.length_scale = float(args.length_scale)
            except (TypeError, ValueError):
                raise ValueError(f"If the provided length scale is not 'auto', then it should be a float number: args.length_scale={args.length_scale}") from None

        # a zero scale (e.g. duplicated coordinates with 'auto') would give gamma=inf and NaN kernels
        if not self.length_scale > 0:
            raise ValueError(f"The length scale must be positive: length_scale={self.length_scale}")

        # precompute the input graph of every training instance
        with torch.no_grad():
            self.graph_inputs = [
                self.form_input_graph(coord, feature, neighbors)
                for coord, feature, neighbors in zip(self.trainset.coords, self.trainset.features, self.train_neighbors)
            ]

        # input dimensions are the feature dimensions, the label dimensions and one indicator dimension
        input_dim = trainset.features.shape[1] + trainset.y.shape[1] + 1
        output_dim = trainset.y.shape[1]

        self.gnn = GNN(input_dim, args)

        # the last linear layer
        self.linear = torch.nn.Linear(args.hidden_sizes[-1], output_dim, bias=False)

        # the last activation function; Identity (rather than a lambda) keeps the model picklable
        activations = {
            'relu': torch.nn.ReLU,
            'sigmoid': torch.nn.Sigmoid,
            'tanh': torch.nn.Tanh,
            'softplus': torch.nn.Softplus,
            'none': torch.nn.Identity,
        }
        if args.last_activation not in activations:
            raise ValueError(f"No such choice of activation for the output: args.last_activation={args.last_activation}")
        self.last_activation = activations[args.last_activation]()

        self.collate_fn = torch_geometric.loader.dataloader.Collater(None, None)

        self.device = args.device
        self.gnn = self.gnn.to(self.device)

    def forward(self, coords, features, train_indices=None):
        """Predict labels for a batch of instances.

        Args:
            coords, features: coordinates and features of the batch; only used
                when `train_indices` is None.
            train_indices: indices into the stored training set. When given, the
                precomputed graphs of those training instances are used.

        Returns:
            Tensor of shape `(batch, output_dim)` on `self.device`.
        """
        if train_indices is not None:
            # training/validation instances: read in pre-computed graphs
            graphs = [self.graph_inputs[i] for i in train_indices]
        else:
            # new instances: find neighbors in the training set and form input graphs
            neighbors = self.knn.kneighbors(coords, return_distance=False)
            with torch.no_grad():
                graphs = [
                    self.form_input_graph(coord, feature, nbrs)
                    for coord, feature, nbrs in zip(coords, features, neighbors)
                ]

        batch_inputs = self.collate_fn(graphs).to(self.device)

        # run gnn on the graph input
        output = self.gnn(batch_inputs.x, batch_inputs.edge_index, batch_inputs.edge_attr)

        # every graph has n_neighbors + 1 nodes with the center node first; keep only center nodes
        output = torch.reshape(output, [-1, (self.n_neighbors + 1), output.shape[1]])
        center_output = output[:, 0]
        pred = self.last_activation(self.linear(center_output))

        return pred

    def form_input_graph(self, coord, feature, neighbors):
        """Build the attributed graph for one instance and its training neighbors.

        Node 0 is the instance itself; nodes 1..k are the training instances
        listed in `neighbors`. Node attributes are [features, label, indicator]:
        the center node's label input is zero and its indicator is 1, while
        neighbors carry their training labels and indicator 0. Every node pair
        (self-pairs included) with a non-zero RBF kernel value
        exp(-||x - y||^2 / (2 * length_scale^2)) becomes a weighted edge.
        """
        output_dim = self.trainset.y.shape[1]

        # label inputs
        y = torch.cat([torch.zeros([1, output_dim]), self.trainset.y[neighbors]], dim=0)

        # indicator of the center node
        indicator = torch.zeros([neighbors.shape[0] + 1])
        indicator[0] = 1.0

        # feature inputs
        features = torch.cat([feature[None, :], self.trainset.features[neighbors]], dim=0)

        # form graph features
        graph_features = torch.cat([features, y, indicator[:, None]], dim=1)

        # weighted graph from an RBF kernel over coordinates: K(x, y) = exp(-gamma ||x-y||^2)
        all_coords = torch.cat([coord[None, :], self.trainset.coords[neighbors]], dim=0)
        gamma = 1 / (2 * self.length_scale ** 2)
        kernel = sklearn.metrics.pairwise.rbf_kernel(all_coords.numpy(), gamma=gamma)
        adj = torch.from_numpy(kernel)
        # one choice is to normalize the adjacency matrix with self._normalize_adj

        # create the edge list from non-zero kernel entries
        nz = adj.nonzero(as_tuple=True)
        edges = torch.stack(nz, dim=0)
        edge_weights = adj[nz]

        return torch_geometric.data.Data(x=graph_features, edge_index=edges, edge_attr=edge_weights, y=None)

    def _normalize_adj(self, adj):
        """Symmetrically normalize an adjacency matrix: D^-1/2 A D^-1/2.

        Rows with zero degree get a zero scaling factor instead of inf.
        Not called by this class at the moment.
        """
        row_sum = np.array(adj.sum(1))
        d_inv_sqrt = np.power(row_sum, -0.5).flatten()
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.

        adj_normalized = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]

        return adj_normalized
155 |
156 |
157 |
158 |
class GNN(torch.nn.Module):
    """Stack of graph convolution layers used as the KCN backbone.

    One convolution per entry of `args.hidden_sizes`, of the type selected by
    `args.model` ('kcn': GCN, 'kcn_gat': GAT, 'kcn_sage': GraphSAGE). Every
    layer is followed by ReLU and dropout. Layers are registered as `layer0`,
    `layer1`, ... and applied in that order.
    """

    def __init__(self, input_dim, args) -> None:
        """
        Args:
            input_dim: number of input node attributes.
            args: namespace; reads `hidden_sizes` (non-empty list of int),
                `dropout` (float) and `model` (one of 'kcn', 'kcn_gat', 'kcn_sage').

        Raises:
            ValueError: for an unknown model type or empty `hidden_sizes`.
        """
        super().__init__()

        self.hidden_sizes = args.hidden_sizes
        self.dropout = args.dropout
        self.model_type = args.model

        if self.model_type not in ('kcn', 'kcn_gat', 'kcn_sage'):
            raise ValueError(f"No such model choice: args.model={args.model}")
        if len(self.hidden_sizes) == 0:
            raise ValueError("args.hidden_sizes must contain at least one layer size")

        # layer i maps in_sizes[i] -> hidden_sizes[i]; names layer0, layer1, ... keep state_dict keys stable
        in_sizes = [input_dim] + list(self.hidden_sizes[:-1])
        for ilayer, (in_dim, out_dim) in enumerate(zip(in_sizes, self.hidden_sizes)):
            self.add_module("layer" + str(ilayer), self._make_conv(in_dim, out_dim))

    def _make_conv(self, in_dim, out_dim):
        """Build one convolution layer of the configured model type."""
        if self.model_type == 'kcn':
            return torch_geometric.nn.GCNConv(in_dim, out_dim, bias=False, add_self_loops=True)
        if self.model_type == 'kcn_gat':
            return torch_geometric.nn.GATConv(in_dim, out_dim)
        return torch_geometric.nn.SAGEConv(in_dim, out_dim, aggr='max', normalize=True)

    def forward(self, x, edge_index, edge_weight):
        """Apply every layer, each followed by ReLU and dropout; returns node representations."""
        for conv_layer in self.children():

            if self.model_type == 'kcn':
                # GCN uses the kernel values as edge weights
                x = conv_layer(x, edge_index, edge_weight=edge_weight)

            elif self.model_type == 'kcn_gat':
                # NOTE(review): GATConv is built without `edge_dim`, so PyG presumably ignores
                # edge_attr here and the kernel weights do not reach the attention — confirm
                x, (edge_index, attention_weights) = conv_layer(x, edge_index, edge_attr=edge_weight, return_attention_weights=True)

            elif self.model_type == 'kcn_sage':
                # GraphSAGE with max aggregation does not take edge weights
                x = conv_layer(x, edge_index)

            x = torch.nn.functional.relu(x)
            x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)

        return x
209 |
210 |
211 |
--------------------------------------------------------------------------------