├── datasets └── .gitkeep ├── .gitignore ├── environment.yml ├── main.py ├── argument.py ├── README.md ├── data.py ├── experiment.py └── kcn.py /datasets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .*.swp 2 | .*.swo 3 | *.out 4 | *.err 5 | *.npz 6 | *.slurm 7 | __pycache__ 8 | .ipynb_checkpoints/ 9 | 10 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: kcn_env 2 | 3 | dependencies: 4 | - python=3.10 5 | - numpy 6 | - pytorch::pytorch=2.0 7 | - pyg::pyg 8 | - conda-forge::scikit-learn 9 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from argument import parse_opt 4 | from experiment import run_kcn 5 | 6 | 7 | # repeat the experiment in the paper 8 | def random_runs(args): 9 | test_errors = [] 10 | for args.random_seed in range(10): 11 | np.random.seed(args.random_seed) 12 | torch.manual_seed(args.random_seed) 13 | 14 | err = run_kcn(args) 15 | test_errors.append(err) 16 | 17 | test_errors = np.array(test_errors) 18 | return test_errors 19 | 20 | 21 | 22 | if __name__ == "__main__": 23 | 24 | args = parse_opt() 25 | print(args) 26 | 27 | # set random seeds 28 | np.random.seed(args.random_seed) 29 | torch.manual_seed(args.random_seed) 30 | 31 | # run experiment on one train-test split 32 | err = run_kcn(args) 33 | print('Model: {}, test error: {}\n'.format(args.model, err)) 34 | 35 | 36 | ## run all experiments on one dataset 37 | #model_error = dict() 38 | #for args.model in ["kcn", "kcn_gat", "kcn_sage"]: 39 | # test_errors = random_runs(args) 40 | # model_error[args.model] = (np.mean(test_errors), np.std(test_errors)) 41 | # print(model_error) 42 | 43 | 44 | -------------------------------------------------------------------------------- /argument.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | def parse_opt(): 5 | 6 | # Settings 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--random_seed', type=int, default="5", help="The random seed") 9 | parser.add_argument('--dataset', type=str, default="bird_count", help="The dataset name: currently can only be 'bird_count'") 10 | parser.add_argument('--data_path', type=str, default="./datasets", help="The folder containing the data file. The default file is './data/{dataset}.pkl'") 11 | parser.add_argument('--use_default_test_set', type=bool, default=False, help='Use the default test set from the data') 12 | 13 | parser.add_argument('--model', type=str, default='kcn', help='One of three model types, kcn, kcn_gat, kcn_sage, which use GCN, GAT, and GraphSAGE respectively') 14 | parser.add_argument('--n_neighbors', type=int, default=5, help='Number of neighbors') 15 | parser.add_argument('--length_scale', default="auto", help='Length scale for RBF kernel. If set to "auto", then it will be set to the median of neighbor distances') 16 | parser.add_argument('--hidden_sizes', type=list, default=[8, 8, 8], help='Number of units in hidden layers, also decide the number of layers') 17 | parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate (1 - keep probability).') 18 | parser.add_argument('--last_activation', type=str, default='none', help='Activation for the last layer') 19 | 20 | parser.add_argument('--loss_type', type=str, default='squared_error', help='Loss type') 21 | parser.add_argument('--validation_size', type=int, default=5000, help='Validation size') 22 | 23 | parser.add_argument('--lr', type=float, default=5e-3, help='Learning rate.') 24 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay for the optimizer.') 25 | parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs.') 26 | parser.add_argument('--es_patience', type=int, default=20, help='Patience for early stopping.') 27 | parser.add_argument('--batch_size', type=int, default=64, help='Batch size') 28 | 29 | parser.add_argument('--device', type=str, default="auto", help='Computation device.') 30 | 31 | args, unknowns = parser.parse_known_args() 32 | 33 | if args.device == "auto": 34 | args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | else: 36 | args.device = torch.device(args.device) 37 | 38 | return args 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A PyTorch Implementation of Kriging Convolutional Networks 2 | 3 | **Wenrui Zhao1, Ruiqi Xu1, and Li-Ping Liu1** 4 | 5 | 1 Department of Computer Science, Tufts University 6 | 7 | ### Overview 8 | 9 | This repo contains the PyTorch implementation of Kriging Convolutional Networks (KCNs) [1]. The original code release was an implementation with TensorFlow 1.x, which is not well supported now. In this repo, we update the implementation with PyTorch and PyTorch Geometric. We hope this new implementation will help researchers in this field. If you have any questions about this repo, please feel free to raise issues. 10 | 11 | ### Data 12 | 13 | The model works with datasets in the `SpatialDataset` format. In particular, it contains three fields: 14 | 15 | * `dataset.coords`: a tensor with shape `[n, 2]`, each row is a 2-D coordinate of an instance. Coordinates with other dimensions can also be handled by KCN. 16 | * `dataset.features`: a tensor with shape `[n, d]`, each row is a `d`-dimensional feature vector of an instance. 17 | * `dataset.y`: a tensor with shape `[n, l]`, each row is a vector of `l` labels of an instance. The current example only work for one-dimensional continuous labels. 18 | 19 | The current repo provides a running example of KCN on a dataset of bird counts (counts of wood thrush reported in Eeatern US during June 2014) [[download link](https://tufts.box.com/v/kcn-bird-count-dataset)]. 20 | 21 | ### Model 22 | 23 | A KCN predicts a data point's label based on data points in its neighborhood. The KCN model stores a training set internally. To make a prediction for a data point, it looks up neighbors for the data point and construct, forms an attributed graph over data points in the neighborhood, and then uses a Graph Neural Network (GNN) to predict the data point's label. During training, these graphs are computed before training to save repeated graph constructions. The general structure of KCN is similar to a k-nearest-neighbor classifier, though the former one employs a much more flexible predictive function than simple averaging. 24 | 25 | In the implementation, a KCN model is a PyTorch module. It is initialized with a `SpatialDataset`. In the `forward` function, it takes coordinates and features of a batch of data points and then predicts their labels. 26 | 27 | ### Run the code 28 | 29 | #### Requirements 30 | The code has been tested on a linux platform with the following packages installed: `python=3.10, numpy=1.24.3, torch=2.01, scikit-learn=1.2.2, pyg=2.3.0`. You can install the environment from `environment.yml` with `conda env create -f environment.yml`. 31 | 32 | If you want to run the Jupyter notebook `demo.ipynb`, you need also install `geoplot` and `geopandas`. 33 | 34 | You can try the KCN model on a single train-test split by running `python main.py`. If you want to make changes to experiment settings and model parameter, you can provide more arguments to the command according to `args.py`, or you can directly edit default values of arguments in `args.py`. 35 | 36 | ### Reference 37 | [1] Gabriel Appleby, Linfeng Liu, and Li-Ping Liu. "Kriging convolutional networks." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 34. No. 04. 2020. 38 | [2] Sullivan, B.L., C.L. Wood, M.J. Iliff, R.E. Bonney, D. Fink, and S. Kelling. 2009. eBird: a citizen-based bird observation network in the biological sciences. Biological Conservation 142: 2282-2292. 39 | 40 | 41 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | class SpatialDataset(torch.utils.data.Dataset): 6 | """A dataset class for spatial data.""" 7 | 8 | def __init__(self, coords, features, y): 9 | """ 10 | Args: 11 | coords: tensor with shape `(n, 2)`, coordinates of `n` instances 12 | features: tensor with shape `(n, d)`, `d` dimensional feature vectors of `n` instances 13 | y: tensor with shape `(n, )`, labels of `n` instances. Please provide zeros if unknown. 14 | neighbors: tensor with shape `(n, num_neighbors)`, neighbors in an external training set. 15 | It can be none and computed later. 16 | """ 17 | super(SpatialDataset, self).__init__() 18 | 19 | if coords.shape[0] != features.shape[0] or features.shape[0] != y.shape[0]: 20 | raise Exception(f"Coordinates, features, and labels have different numbers of instances: \ 21 | coords.shape[0]={coords.shape[0]}, features.shape[0]={features.shape[0]}, \ 22 | y.shape[0]={y.shape[0]}") 23 | 24 | 25 | self.coords = torch.Tensor(coords) 26 | self.features = torch.Tensor(features) 27 | self.y = torch.Tensor(y) 28 | 29 | 30 | def __len__(self): 31 | return self.coords.shape[0] 32 | 33 | def __getitem__(self, idx): 34 | 35 | ins = (self.coords[idx], self.features[idx], self.y[idx]) 36 | 37 | return ins 38 | 39 | 40 | 41 | 42 | def load_bird_count_data(args): 43 | """ 44 | Load data for training and testing 45 | 46 | Args 47 | ---- 48 | args : will use three fields, args.dataset, args.data_path, args.random_seed 49 | 50 | Returns 51 | ------- 52 | coords : np.ndarray, shape (N, 2), coordinates of the data points 53 | features : np.ndarray, shape (N, D), features of the data points 54 | y : np.ndarray, shape (N, 1), labels of the data points 55 | num_total_train : int, number of training data points. The first `num_total_train` 56 | of instances from three other return values should form the training set 57 | """ 58 | 59 | # data file path 60 | datafile = os.path.join(args.data_path, args.dataset + ".npz") 61 | 62 | # download data if not finding it 63 | if not os.path.isfile(datafile): 64 | raise Exception(f"Data file {datafile} not found. Please download the dataset from https://tufts.box.com/v/kcn-bird-count-dataset and save it to ./datasets/bird_count.npz") 65 | 66 | # load the data 67 | data = np.load(datafile) 68 | X_train = np.ndarray.astype(data['Xtrain'], np.float32) 69 | Y_train = data['Ytrain'].astype(np.float32) 70 | Y_train = Y_train[:, None] 71 | X_test = np.ndarray.astype(data['Xtest'], np.float32) 72 | Y_test = data['Ytest'].astype(np.float32) 73 | Y_test = Y_test[:, None] 74 | 75 | 76 | 77 | num_total_train = X_train.shape[0] 78 | 79 | # check and record shapes 80 | assert (X_train.shape[0] == Y_train.shape[0]) 81 | assert (X_test.shape[0] == Y_test.shape[0]) 82 | 83 | if args.use_default_test_set: 84 | print("Using the default test set from the data") 85 | trainset = SpatialDataset(coords=X_train[:, 0:2], features=X_train, y=Y_train) 86 | testset = SpatialDataset(coords=X_test[:, 0:2], features=X_test, y=Y_test) 87 | else: 88 | X = np.concatenate([X_train, X_test], axis=0) 89 | Y = np.concatenate([Y_train, Y_test], axis=0) 90 | 91 | perm = np.random.RandomState(seed=args.random_seed).permutation(X.shape[0]) 92 | 93 | # include coordinates in features 94 | trainset = SpatialDataset(coords=X[perm[0:num_total_train], 0:2], features=X[perm[0:num_total_train]], y=Y[perm[0:num_total_train]]) 95 | testset = SpatialDataset(coords=X[perm[num_total_train:], 0:2], features=X[perm[num_total_train:]], y=Y[perm[num_total_train:]]) 96 | 97 | # feature normalization 98 | feature_mean = torch.mean(trainset.features, axis=0, keepdims=True) 99 | feature_std = torch.std(trainset.features, axis=0, keepdims=True) 100 | 101 | trainset.features = (trainset.features - feature_mean) / (feature_std + 0.01) 102 | testset.features = (testset.features - feature_mean) / (feature_std + 0.01) 103 | 104 | return trainset, testset 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import kcn 4 | import data 5 | from tqdm import tqdm 6 | 7 | def run_kcn(args): 8 | """ Train and test a KCN model on a train-test split 9 | 10 | Args 11 | ---- 12 | args : argparse.Namespace object, which contains the following attributes: 13 | - 'model' : str, which is one of 'gcn', 'gcn_gat', 'gcn_sage' 14 | - 'n_neighbors' : int, number of neighbors 15 | - 'hidden1' : int, number of units in hidden layer 1 16 | - 'dropout' : float, the dropout rate in a dropout layer 17 | - 'lr' : float, learning rate of the Adam optimizer 18 | - 'epochs' : int, number of training epochs 19 | - 'es_patience' : int, patience for early stopping 20 | - 'batch_size' : int, batch size 21 | - 'dataset' : str, path to the data file 22 | - 'last_activation' : str, activation for the last layer 23 | - 'weight_decay' : float, weight decay for the Adam optimizer 24 | - 'length_scale' : float, length scale for RBF kernel 25 | - 'loss_type' : str, which is one of 'squared_error', 'nll_error' 26 | - 'validation_size' : int, validation size 27 | - 'gcn_kriging' : bool, whether to use gcn kriging 28 | - 'sparse_input' : bool, whether to use sparse matrices 29 | - 'device' : torch.device, which is either 'cuda' or 'cpu' 30 | 31 | """ 32 | # This function has the following three steps: 33 | # 1) loading data; 2) spliting the data into training and test subsets; 3) normalizing data 34 | if args.dataset == "bird_count": 35 | trainset, testset = data.load_bird_count_data(args) 36 | else: 37 | raise Exception(f"The repo does not support this dataset yet: args.dataset={args.dataset}") 38 | 39 | print(f"The {args.dataset} dataset has {len(trainset)} training instances and {len(testset)} test instances.") 40 | 41 | num_total_train = len(trainset) 42 | num_valid = args.validation_size 43 | num_train = num_total_train - args.validation_size 44 | 45 | # initialize a kcn model 46 | # 1) the entire training set including validation points are recorded by the model and will 47 | # be looked up in neighbor searches 48 | # 2) the model will pre-compute neighbors for a training or validation instance to avoid repeated neighbor search 49 | # 3) if a data point appears in training set and validation set, its neighbors does not include itself 50 | model = kcn.KCN(trainset, args) 51 | model = model.to(args.device) 52 | 53 | loss_func = torch.nn.MSELoss(reduction='mean') 54 | 55 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 56 | 57 | epoch_train_error = [] 58 | epoch_valid_error = [] 59 | 60 | # the training loop 61 | model.train() 62 | 63 | 64 | for epoch in range(args.epochs): 65 | 66 | batch_train_error = [] 67 | 68 | # use training indices directly because it will be used to retrieve pre-computed neighbors 69 | for i in tqdm(range(0, num_train, args.batch_size)): 70 | 71 | # fetch a batch of data 72 | batch_ind = range(i, min(i + args.batch_size, num_train)) 73 | batch_coords, batch_features, batch_y = model.trainset[batch_ind] 74 | 75 | # make predictions and compute the average loss 76 | pred = model(batch_coords, batch_features, batch_ind) 77 | loss = loss_func(pred, batch_y.to(args.device)) 78 | 79 | # update parameters 80 | optimizer.zero_grad() 81 | loss.backward() 82 | optimizer.step() 83 | 84 | # record the training error 85 | batch_train_error.append(loss.item()) 86 | 87 | train_error = sum(batch_train_error) / len(batch_train_error) 88 | epoch_train_error.append(train_error) 89 | 90 | # fetch the validation set 91 | valid_ind = range(num_train, num_total_train) 92 | valid_coords, valid_features, valid_y = model.trainset[valid_ind] 93 | 94 | # make predictions and calculate the error 95 | valid_pred = model(valid_coords, valid_features, valid_ind) 96 | valid_error = loss_func(valid_pred, valid_y.to(args.device)) 97 | 98 | epoch_valid_error.append(valid_error.item()) 99 | 100 | print(f"Epoch: {epoch},", f"train error: {train_error},", f"validation error: {valid_error}") 101 | 102 | # check whether to stop 103 | if (epoch > args.es_patience) and \ 104 | (np.mean(np.array(epoch_valid_error[-3:])) > 105 | np.mean(np.array(epoch_valid_error[-(args.es_patience + 3):-3]))): 106 | print("\nEarly stopping at epoch {}".format(epoch)) 107 | break 108 | 109 | # test the model 110 | model.eval() 111 | 112 | test_preds = model(testset.coords, testset.features) 113 | test_error = loss_func(test_preds, testset.y.to(args.device)) 114 | test_error = torch.mean(test_error).item() 115 | 116 | print(f"Test error is {test_error}") 117 | 118 | return test_error 119 | -------------------------------------------------------------------------------- /kcn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn 3 | import sklearn.neighbors 4 | import torch 5 | import torch_geometric 6 | 7 | 8 | class KCN(torch.nn.Module): 9 | """ Creates a KCN model with the given parameters.""" 10 | 11 | def __init__(self, trainset, args) -> None: 12 | super(KCN, self).__init__() 13 | 14 | self.trainset = trainset 15 | 16 | # set neighbor relationships within the training set 17 | self.n_neighbors = args.n_neighbors 18 | self.knn = sklearn.neighbors.NearestNeighbors(n_neighbors=self.n_neighbors).fit(self.trainset.coords) 19 | distances, self.train_neighbors = self.knn.kneighbors(None, return_distance=True) 20 | 21 | if args.length_scale == "auto": 22 | self.length_scale = np.median(distances.flatten()) 23 | print(f"Length scale is set to {self.length_scale}") 24 | else: 25 | if not isinstance(args.length_scale, float): 26 | raise Exception(f"If the provided length scale is not 'auto', then it should be a float number: args.length_scale={args.length_scale}") 27 | self.length_scale = args.length_scale 28 | 29 | with torch.no_grad(): 30 | self.graph_inputs = [] 31 | for i in range(self.trainset.coords.shape[0]): 32 | att_graph = self.form_input_graph(self.trainset.coords[i], self.trainset.features[i], self.train_neighbors[i]) 33 | self.graph_inputs.append(att_graph) 34 | 35 | # initialize model 36 | # input dimensions should be feature dimensions, a label dimension and an indicator dimension 37 | input_dim = trainset.features.shape[1] + 2 38 | output_dim = trainset.y.shape[1] 39 | 40 | self.gnn = GNN(input_dim, args) 41 | 42 | # the last linear layer 43 | self.linear = torch.nn.Linear(args.hidden_sizes[-1], output_dim, bias=False) 44 | 45 | # the last activation function 46 | if args.last_activation == 'relu': 47 | self.last_activation = torch.nn.ReLU() 48 | elif args.last_activation == 'sigmoid': 49 | self.last_activation = torch.nn.Sigmoid() 50 | elif args.last_activation == 'tanh': 51 | self.last_activation = torch.nn.Tanh() 52 | elif args.last_activation == 'softplus': 53 | self.last_activation = torch.nn.Softplus() 54 | elif args.last_activation == 'none': 55 | self.last_activation = lambda _: _ 56 | else: 57 | raise Exception(f"No such choice of activation for the output: args.last_activation={args.last_activation}") 58 | 59 | 60 | self.collate_fn = torch_geometric.loader.dataloader.Collater(None, None) 61 | 62 | self.device = args.device 63 | self.gnn = self.gnn.to(self.device) 64 | 65 | 66 | def forward(self, coords, features, train_indices=None): 67 | 68 | if train_indices is not None: 69 | 70 | # if from training set, then read in pre-computed graphs 71 | batch_inputs = [] 72 | for i in train_indices: 73 | batch_inputs.append(self.graph_inputs[i]) 74 | 75 | batch_inputs = self.collate_fn(batch_inputs) 76 | 77 | 78 | else: 79 | 80 | # if new instances, then need to find neighbors and form input graphs 81 | neighbors = self.knn.kneighbors(coords, return_distance=False) 82 | 83 | with torch.no_grad(): 84 | batch_inputs = [] 85 | for i in range(len(coords)): 86 | att_graph = self.form_input_graph(coords[i], features[i], neighbors[i]) 87 | batch_inputs.append(att_graph) 88 | 89 | batch_inputs = self.collate_fn(batch_inputs) 90 | 91 | batch_inputs = batch_inputs.to(self.device) 92 | 93 | # run gnn on the graph input 94 | output = self.gnn(batch_inputs.x, batch_inputs.edge_index, batch_inputs.edge_attr) 95 | 96 | # take representations only corresponding to center nodes 97 | output = torch.reshape(output, [-1, (self.n_neighbors + 1), output.shape[1]]) 98 | center_output = output[:, 0] 99 | pred = self.last_activation(self.linear(center_output)) 100 | 101 | return pred 102 | 103 | def form_input_graph(self, coord, feature, neighbors): 104 | 105 | output_dim = self.trainset.y.shape[1] 106 | 107 | # label inputs 108 | y = torch.concat([torch.zeros([1, output_dim]), self.trainset.y[neighbors]], axis=0) 109 | 110 | # indicator 111 | indicator = torch.zeros([neighbors.shape[0] + 1]) 112 | indicator[0] = 1.0 113 | 114 | # feature inputs 115 | features = torch.concat([feature[None, :], self.trainset.features[neighbors]], axis=0) 116 | 117 | # form graph features 118 | graph_features = torch.concat([features, y, indicator[:, None]], axis=1) 119 | 120 | 121 | # compute a weighted graph from an rbf kernel 122 | all_coords = torch.concat([coord[None, :], self.trainset.coords[neighbors]], axis=0) 123 | 124 | # K(x, y) = exp(-gamma ||x-y||^2) 125 | kernel = sklearn.metrics.pairwise.rbf_kernel(all_coords.numpy(), gamma=1 / (2 * self.length_scale ** 2)) 126 | ## the implementation here is the same as sklearn.metrics.pairwise.rbf_kernel 127 | #row_norm = torch.sum(torch.square(all_coords), dim=1) 128 | #dist = row_norm[:, None] - 2 * torch.matmul(all_coords, all_coords.t()) + row_norm[None, :] 129 | #kernel = torch.exp(-self.length_scale * dist) 130 | 131 | adj = torch.from_numpy(kernel) 132 | # one choice is to normalize the adjacency matrix 133 | #curr_adj = normalize_adj(curr_adj + np.eye(curr_adj.shape[0])) 134 | 135 | # create a graph from it 136 | nz = adj.nonzero(as_tuple=True) 137 | edges = torch.stack(nz, dim=0) 138 | edge_weights = adj[nz] 139 | 140 | # form the graph 141 | attributed_graph = torch_geometric.data.Data(x=graph_features, edge_index=edges, edge_attr=edge_weights, y=None) 142 | 143 | return attributed_graph 144 | 145 | def _normalize_adj(self, adj): 146 | """Symmetrically normalize adjacency matrix.""" 147 | 148 | row_sum = np.array(adj.sum(1)) 149 | d_inv_sqrt = np.power(row_sum, -0.5).flatten() 150 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 151 | 152 | adj_normalized = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :] 153 | 154 | return adj_normalized 155 | 156 | 157 | 158 | 159 | class GNN(torch.nn.Module): 160 | """ Creates a KCN model with the given parameters.""" 161 | 162 | def __init__(self, input_dim, args) -> None: 163 | super().__init__() 164 | 165 | self.hidden_sizes = args.hidden_sizes 166 | self.dropout = args.dropout 167 | self.model_type = args.model 168 | 169 | if self.model_type == 'kcn': 170 | conv_layer = torch_geometric.nn.GCNConv (input_dim, self.hidden_sizes[0], bias=False, add_self_loops=True) 171 | elif self.model_type == 'kcn_gat': 172 | conv_layer = torch_geometric.nn.GATConv (input_dim, self.hidden_sizes[0]) 173 | elif self.model_type == 'kcn_sage': 174 | conv_layer = torch_geometric.nn.SAGEConv(input_dim, self.hidden_sizes[0], aggr='max', normalize=True) 175 | else: 176 | raise Exception(f"No such model choice: args.model={args.model}") 177 | 178 | self.add_module("layer0", conv_layer) 179 | 180 | 181 | for ilayer in range(1, len(self.hidden_sizes)): 182 | if self.model_type == 'kcn': 183 | conv_layer = torch_geometric.nn.GCNConv (self.hidden_sizes[ilayer - 1], self.hidden_sizes[ilayer], bias=False, add_self_loops=True) 184 | elif self.model_type == 'kcn_gat': 185 | conv_layer = torch_geometric.nn.GATConv (self.hidden_sizes[ilayer - 1], self.hidden_sizes[ilayer]) 186 | elif self.model_type == 'kcn_sage': 187 | conv_layer = torch_geometric.nn.SAGEConv(self.hidden_sizes[ilayer - 1], self.hidden_sizes[ilayer], aggr='max', normalize=True) 188 | 189 | self.add_module("layer"+str(ilayer), conv_layer) 190 | 191 | def forward(self, x, edge_index, edge_weight): 192 | 193 | for conv_layer in self.children(): 194 | 195 | if self.model_type == 'kcn': 196 | x = conv_layer(x, edge_index, edge_weight=edge_weight) 197 | 198 | elif self.model_type == 'kcn_gat': 199 | x, (edge_index, attention_weights) = conv_layer(x, edge_index, edge_attr=edge_weight, return_attention_weights=True) 200 | #edge_weight = ttention_weights 201 | 202 | elif self.model_type == 'kcn_sage': 203 | x = conv_layer(x, edge_index) 204 | 205 | x = torch.nn.functional.relu(x) 206 | x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training) 207 | 208 | return x 209 | 210 | 211 | --------------------------------------------------------------------------------