├── LICENSE
├── README.md
├── election.PNG
├── mil_benchmark
    ├── ds_net_main.py
    ├── evaluate.py
    ├── models_all.py
    ├── modules.py
    ├── rFF_pool_main.py
    ├── res_pool_main.py
    ├── run_trials.py
    ├── set_transformer_main.py
    └── utils.py
├── mil_election_classification
    ├── accuracy_deepset_num_neib_0_r_0_vanilla.csv
    ├── accuracy_deepset_num_neib_5_r_0_GCN.csv
    ├── accuracy_deepset_num_neib_5_r_1_BGCN.csv
    ├── data_cleaning.py
    ├── deepset_main.py
    ├── election.pdf
    ├── evaluate.py
    ├── evaluate_nd.py
    ├── models_all_deepset.py
    ├── models_all_set_transformer.py
    ├── modules.py
    ├── run_trials.py
    ├── set_transformer_main.py
    ├── utils.py
    └── visualization.py
├── mil_rental_data
    ├── data
    │   ├── adj_nbhd.txt
    │   ├── districts_data_cleaned.csv
    │   └── neighbourhood_data.csv
    ├── deepset_main.py
    ├── evaluate.py
    ├── mae_deepset.csv
    ├── mae_set_transformer.csv
    ├── mape_deepset.csv
    ├── mape_set_transformer.csv
    ├── models_all_deepset.py
    ├── models_all_set_transformer.py
    ├── modules.py
    ├── rmse_deepset.csv
    ├── rmse_set_transformer.csv
    ├── set_transformer_main.py
    └── utils.py
├── mil_text
    ├── evaluate.py
    ├── evaluate_all.py
    ├── models_all.py
    ├── modules.py
    ├── rFF_pool_main.py
    ├── rank_box.pdf
    ├── rank_box_results.txt
    ├── rank_plot_all.py
    ├── res_pool_main.py
    ├── run_trials.py
    ├── set_transformer_main.py
    └── utils.py
└── requirements.txt

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Antonios Valkanas
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bag-Graph-MIL
2 | 
3 | ## Bag Graph: Multiple Instance Learning using Bayesian Graph Neural Networks
4 | 
5 | This repository contains the code to replicate the results reported in our AAAI 2022 paper: *Bag Graph:
6 | Multiple Instance Learning using Bayesian Graph Neural Networks*.
7 | 
8 | ## Contributors:
9 | Antonios Valkanas & Soumyasundar Pal (equal first authors), Florence Regol, and Mark Coates (Professor, McGill University)
10 | 
11 | 
12 | ## Getting Started
13 | 
14 | Install the dependencies from the command line with `pip`:
15 | 
16 | ```sh
17 | pip install -r requirements.txt --progress-bar off
18 | ```
19 | 
20 | Note: when a dataset is small enough, we include it in the repository; otherwise we point to the original source from which it can be downloaded. We provide all pre-processing code and the plotting code for the diagrams in the paper. [Google Drive secondary repo with the full datasets for experiments 1, 2, and 4.](https://drive.google.com/drive/folders/1tpzJivhFtRCxeCOYnpoqpXKSAuCyoPFG?usp=sharing)
21 | 
22 | ## Training
23 | Each experiment has its own folder. To run a specific experiment, go to the corresponding directory and follow the instructions below. Unless otherwise specified, the datasets are included in the experiment subfolders.
24 | ### Running Experiments & Reproducing Results
25 | We perform classification experiments on 5 benchmark MIL datasets, 20 text datasets from the 20 Newsgroups corpus,
26 | and the 2016 US election data. In addition, we consider a distribution regression task of predicting neighborhood
27 | property rental prices in New York City.
28 | #### 1. Benchmark MIL Datasets
29 | Go to the `mil_benchmark` folder and run:
30 | ```sh
31 | python run_trials.py
32 | ```
33 | #### 2. Text Categorization
34 | Navigate to the `mil_text` folder and run:
35 | ```sh
36 | python run_trials.py
37 | ```
38 | #### 3. Electoral Results Prediction
39 | To get the data, go to [the source repository](https://github.com/flaxter/us2016/tree/master/data) and download `centroids_cartesian_10.csv` and `results-2016-election.csv`. Create a `data` subdirectory and copy these files there.
40 | Continue by downloading the `set_features` and `cleaned_set_features` subdirectories from our [drive link](https://drive.google.com/drive/folders/1Qb5us6pu0RUGD20UaFKPy1I8OdPmKDrQ?usp=sharing) and saving them under the same `data` directory.
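For reference, after these steps the `data` directory should look roughly like this (a sketch inferred from the paths read by `data_cleaning.py` and `deepset_main.py`; note that the code loads the per-county features from `data/cleaned_features`, so either rename the downloaded `cleaned_set_features` folder accordingly or regenerate it with `python data_cleaning.py`):

```sh
mil_election_classification/data/
├── centroids_cartesian_10.csv
├── results-2016-election.csv
├── set_features/        # per-county parquet files (from the drive link)
└── cleaned_features/    # cleaned per-county feature arrays (.pkl)
```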
41 | 
42 | In the `mil_election_classification` folder, run:
43 | ```sh
44 | python run_trials.py
45 | ```
46 | After the runs have completed, you can reproduce the figure from the paper by running:
47 | ```sh
48 | python visualization.py
49 | ```
50 | You should get the following figure (the top two methods are baselines, the bottom left is ours, and the bottom right is the ground truth):
51 | 
52 | ![failed to load](election.PNG "Election Plot")
53 | 
54 | #### 4. Rental Price Prediction
55 | Go to the `mil_rental_data` folder and run:
56 | ```sh
57 | python deepset_main.py # runs DeepSet backbone trials
58 | python set_transformer_main.py # runs Set Transformer backbone trials
59 | python evaluate.py # statistical significance tests
60 | ```
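For reference, the `evaluate.py` scripts in this repository compare methods with paired Wilcoxon signed-rank tests over per-trial error files (here `mae_*.csv`, `rmse_*.csv`, `mape_*.csv`). A minimal sketch of that comparison, assuming each CSV holds a header row followed by one error value per trial (the shipped CSVs may hold several columns, one per method):

```python
# Sketch only: paired Wilcoxon signed-rank test between two methods'
# per-trial errors, mirroring the wilcoxon(...) calls used in evaluate.py.
import numpy as np
from scipy.stats import wilcoxon

mae_a = np.loadtxt(open('mae_deepset.csv', 'rb'), delimiter=',', skiprows=1)
mae_b = np.loadtxt(open('mae_set_transformer.csv', 'rb'), delimiter=',', skiprows=1)
if mae_a.ndim > 1:  # keep one method column if the CSV has several
    mae_a, mae_b = mae_a[:, 0], mae_b[:, 0]

# Paired test: the i-th entry of each array comes from the same trial.
_, p = wilcoxon(mae_a, mae_b, zero_method='wilcox', correction=False)
print('MAE, DeepSet vs Set Transformer: p = {:.4f}'.format(p))
```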
61 | 
62 | ## Cite
63 | 
64 | Please cite our paper if you use this code in your own work:
65 | 
66 | ```
67 | @inproceedings{pal_valkanas2022,
68 | author={S. Pal and A. Valkanas and F. Regol and M. Coates},
69 | title = {Bag graph: {M}ultiple instance learning using {B}ayesian graph neural networks},
70 | booktitle={Proc. AAAI Conf. Artificial Intell.},
71 | month = {Feb.},
72 | year = {2022},
73 | address = {Online Conference}
74 | }
75 | ```
--------------------------------------------------------------------------------
/election.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/networkslab/BagGraph/5ecbc365aeecd6e495a500cea11deb2412f12311/election.PNG
--------------------------------------------------------------------------------
/mil_benchmark/evaluate.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import numpy as np
3 | 
4 | data_all = ['musk1', 'musk2', 'fox', 'tiger', 'elephant']
5 | for data_index in range(len(data_all)):
6 |     data_name = data_all[data_index]
7 | 
8 |     if data_index == 3:
9 |         pool_ = 'mean'
10 |     else:
11 |         pool_ = 'max'
12 | 
13 |     k_gcn_list = [2, 3, 3, 4, 1]
14 |     k_bgcn_list = [2, 3, 3, 4, 1]
15 |     r_list = [1, 10, 5, 10, 10]
16 | 
17 |     prnt_str_head = ''
18 |     prnt_str = ''
19 | 
20 |     alg_all = ['vanilla']
21 |     k_all = [0]
22 |     r_all = [0]
23 | 
24 |     alg_all.append('GCN')
25 |     k_all.append(k_gcn_list[data_index])
26 |     r_all.append(0)
27 | 
28 |     alg_all.append('BGCN')
29 |     k_all.append(k_bgcn_list[data_index])
30 |     r_all.append(r_list[data_index])
31 | 
32 |     for idx, alg_ in enumerate(alg_all):
33 | 
34 |         file_name = 'accuracy_' + data_name + '_ds_net_pool_num_neib_' + str(k_all[idx]) + '_' + pool_ + '_r_' + str(r_all[idx]) + '_' + alg_ + '.csv'
35 | 
36 |         raw_acc = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1)
37 |         # print(raw_acc.shape)
38 |         if raw_acc.ndim == 1:
39 |             raw_acc = np.expand_dims(raw_acc, axis=1)
40 | 
41 |         num_alg = raw_acc.shape[1]
42 | 
43 |         for i in range(num_alg):
44 |             # print('----------------------------------------------------------------------')
45 |             # print('Algorithm : ' + algorithms[i])
46 |             acc = np.squeeze(raw_acc[:, i])
47 |             acc = np.reshape(acc, [10, 10])
48 |             mu_ = np.mean(acc) * 100
49 |             sigma_ = np.std(np.mean(acc, axis=0)) * 100
50 |             new_str = '&' + "{:.1f}".format(mu_) + '$\pm$' + "{:.1f}".format(sigma_) + ' '
51 |             prnt_str = prnt_str + new_str
52 | 
53 |         new_str_ = alg_ + '_k_' + str(k_all[idx]) + '_r_' + str(r_all[idx])
54 |         prnt_str_head = prnt_str_head + new_str_.rjust(16) + ' '
55 | 
56 |     print(prnt_str_head)
57 |     print(prnt_str)
58 | 
59 | 
60 | 
61 | 
--------------------------------------------------------------------------------
/mil_benchmark/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import scipy.sparse as sp
6 | from torch.nn.parameter import Parameter
7 | from torch.nn.modules.module import Module
8 | 
9 | 
10 | # Set transformer layers
11 | class MAB(nn.Module):
12 |     def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
13 |         super(MAB, self).__init__()
14 |         self.dim_V = dim_V
15 |         self.num_heads = num_heads
16 |         self.fc_q = nn.Linear(dim_Q, dim_V)
17 |         self.fc_k = nn.Linear(dim_K, dim_V)
18 |         self.fc_v = nn.Linear(dim_K, dim_V)
19 |         if ln:
20 |             self.ln0 = nn.LayerNorm(dim_V)
21 |             self.ln1 = nn.LayerNorm(dim_V)
22 |         self.fc_o = nn.Linear(dim_V, dim_V)
23 |         self.reset_parameters()
24 | 
25 |     def reset_parameters(self):
26 |         for module in self.children():
27 |             reset_op = getattr(module, "reset_parameters", None)
28 |             if callable(reset_op):
29 |                 reset_op()
30 | 
31 |     def forward(self, Q, K):
32 |         Q = self.fc_q(Q)
33 |         K, V = self.fc_k(K), self.fc_v(K)
34 | 
35 |         dim_split = 
self.dim_V // self.num_heads 36 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 37 | K_ = torch.cat(K.split(dim_split, 2), 0) 38 | V_ = torch.cat(V.split(dim_split, 2), 0) 39 | 40 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 41 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 42 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 43 | O = O + F.relu(self.fc_o(O)) 44 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 45 | return O 46 | 47 | 48 | class SAB(nn.Module): 49 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 50 | super(SAB, self).__init__() 51 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 52 | 53 | def forward(self, X): 54 | return self.mab(X, X) 55 | 56 | 57 | class ISAB(nn.Module): 58 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 59 | super(ISAB, self).__init__() 60 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 61 | nn.init.xavier_uniform_(self.I) 62 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 63 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 64 | 65 | def forward(self, X): 66 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 67 | return self.mab1(X, H) 68 | 69 | 70 | class PMA(nn.Module): 71 | def __init__(self, dim, num_heads, num_seeds, ln=False): 72 | super(PMA, self).__init__() 73 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 74 | nn.init.xavier_uniform_(self.S) 75 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 76 | 77 | def forward(self, X): 78 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) 79 | 80 | 81 | -------------------------------------------------------------------------------- /mil_benchmark/run_trials.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import multiprocessing as mp 4 | import sys 5 | import time 6 | from res_pool_main import run_res_pool_one_dataset 7 | from rFF_pool_main import run_rFF_pool_one_dataset 8 | from ds_net_main import run_ds_net_one_dataset 9 | indices = np.arange(5) 10 | 11 | 12 | def run_all_methods(blank, index): 13 | 14 | if index == 3: 15 | pool_ = 'mean' 16 | else: 17 | pool_ = 'max' 18 | 19 | run_ds_net_one_dataset(blank, index, num_neib=0, pooling=pool_, r=0, alg_name='vanilla') 20 | 21 | if index == 0: 22 | k_gcn = 2 23 | k_bgcn = 2 24 | r_ = 1 25 | elif index == 1: 26 | k_gcn = 3 27 | k_bgcn = 3 28 | r_ = 10 29 | elif index == 2: 30 | k_gcn = 3 31 | k_bgcn = 3 32 | r_ = 5 33 | elif index == 3: 34 | k_gcn = 4 35 | k_bgcn = 4 36 | r_ = 10 37 | elif index == 4: 38 | k_gcn = 1 39 | k_bgcn = 1 40 | r_ = 10 41 | 42 | run_ds_net_one_dataset(blank, index, num_neib=k_gcn, pooling=pool_, r=0, alg_name='GCN') 43 | run_ds_net_one_dataset(blank, index, num_neib=k_bgcn, pooling=pool_, r=r_, alg_name='BGCN') 44 | 45 | 46 | if __name__ == '__main__': 47 | 48 | pool = mp.Pool(processes=5) 49 | pool_results = [pool.apply_async(run_all_methods, (1, index)) for index in indices] 50 | pool.close() 51 | pool.join() 52 | for pr in pool_results: 53 | dict_results = pr.get() 54 | -------------------------------------------------------------------------------- /mil_benchmark/set_transformer_main.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch import optim 5 | import seaborn as sns 6 | from scipy import stats 7 | import matplotlib as mpl 8 | import scipy.sparse as sp 9 | import 
networkx as nx 10 | import time 11 | import os 12 | import glob 13 | import csv 14 | import torch 15 | import torchvision 16 | from models_all import * 17 | from utils import * 18 | from sklearn import preprocessing 19 | from scipy.io import loadmat 20 | from sklearn.neighbors import kneighbors_graph 21 | mpl.rcParams['figure.dpi'] = 600 22 | color = sns.color_palette() 23 | #%matplotlib inline 24 | pd.options.mode.chained_assignment = None # default='warn' 25 | 26 | 27 | def run_set_transformer_one_dataset(blank, data_index, r=5): 28 | datasets = ['fox', 'tiger', 'elephant'] 29 | 30 | data_name = datasets[data_index] 31 | print(data_name) 32 | 33 | torch.manual_seed(0) 34 | np.random.seed(0) 35 | 36 | num_neib = 3 37 | EPOCHS = 200 38 | MC_samples = 20 39 | 40 | num_trial = 10 41 | num_fold = 10 42 | 43 | if data_name == 'fox': 44 | lr_ = 1e-4 45 | elif data_name == 'tiger': 46 | lr_ = 5e-4 47 | elif data_name == 'elephant': 48 | lr_ = 1e-4 49 | else: 50 | print('Wrong data name!!!') 51 | exit(0) 52 | 53 | log_dir = 'log_st_r_' + str(r) + '/' + data_name 54 | 55 | check = os.path.isdir(log_dir) 56 | if not check: 57 | os.makedirs(log_dir) 58 | print("created folder : ", log_dir) 59 | else: 60 | print(log_dir, " folder already exists.") 61 | 62 | for f in os.listdir(log_dir): 63 | os.remove(os.path.join(log_dir, f)) 64 | 65 | mat = loadmat('data/' + data_name + '_100x100_matlab.mat') # load mat-file 66 | 67 | bag_ids = np.array(mat['bag_ids']).flatten() # bags 68 | bag_features = np.array(mat['features']).flatten() # features 69 | 70 | df_features = pd.DataFrame(bag_features[0].todense()) 71 | df_bag_ids = pd.DataFrame(bag_ids, columns=['bag_id']) 72 | labels = np.concatenate([np.ones(100), np.zeros(100)], axis=0) 73 | 74 | df = pd.concat([df_bag_ids, df_features], axis=1) 75 | 76 | print(df_bag_ids.shape) 77 | print(df_features.shape) 78 | print(labels.shape) 79 | print(df.shape) 80 | 81 | print(df.isna().sum().max()) 82 | print(df['bag_id'].value_counts()) 83 | 84 | x = df.iloc[:, :] 85 | print(x.shape) 86 | 87 | groups = x.groupby('bag_id').mean() 88 | print(groups.shape) 89 | 90 | grouped_data = groups.values[:, :] 91 | y = labels 92 | print(y.shape) 93 | print(grouped_data.shape) 94 | 95 | scaled_features = x.copy() 96 | col_names = list(x) 97 | features = scaled_features[col_names[1:]].values 98 | 99 | mean_fea = np.mean(features, axis=0, keepdims=True) + 1e-6 100 | std_fea = np.std(features, axis=0, keepdims=True) + 1e-6 101 | features = np.divide(features - mean_fea, std_fea) 102 | print(features.shape) 103 | 104 | scaled_features[col_names[1:]] = features 105 | scaled_features.head() 106 | print(scaled_features.shape) 107 | 108 | groups = scaled_features.groupby('bag_id') 109 | 110 | # Iterate over each group 111 | set_list = [] 112 | for group_name, df_group in groups: 113 | single_set = [] 114 | for row_index, row in df_group.iterrows(): 115 | single_set.append(row[1:].values) 116 | set_list.append(single_set) 117 | 118 | print(len(set_list)) 119 | 120 | target = set_list # target = Set of sets, row = set, 121 | max_cols = max([len(row) for batch in target for row in batch]) 122 | max_rows = max([len(batch) for batch in target]) 123 | print(max_cols) 124 | print(max_rows) 125 | 126 | for i in range(len(set_list)): 127 | set_list[i] = np.array(set_list[i], dtype=float) 128 | 129 | y_ = y.copy().reshape(-1, 1) 130 | labels_copy = y.copy() 131 | print(labels_copy.shape) 132 | print(y_.shape) 133 | print(y.shape) 134 | 135 | def get_performance_test(trial_i, fold_j): 136 | 
torch.manual_seed(0) 137 | np.random.seed(0) 138 | 139 | y = y_ 140 | 141 | models = [ 142 | SetTransformer(in_features=230, num_heads=4, ln=False), 143 | STGCN(in_features=230, num_heads=4, ln=False) 144 | ] 145 | 146 | def weights_init(m): 147 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 148 | torch.nn.init.xavier_uniform_(m.weight) 149 | if m.bias is not None: 150 | torch.nn.init.zeros_(m.bias) 151 | 152 | for model in models: 153 | model.apply(weights_init) 154 | 155 | # produce a split for training, validation and testing 156 | mat_fold = loadmat('data/fold_animal/' + str(trial_i) + '/index' + str(fold_j) + '.mat') # load mat-file 157 | 158 | idx_train = np.array(mat_fold['trainIndex']).flatten() - 1 # matlab index from 1, python from 0 159 | idx_test = np.array(mat_fold['testIndex']).flatten() - 1 160 | 161 | bce = nn.BCELoss() 162 | 163 | features = [torch.FloatTensor(set_) for set_ in set_list] 164 | labels = torch.FloatTensor(y) 165 | idx_train = torch.LongTensor(idx_train) 166 | idx_test = torch.LongTensor(idx_test) 167 | 168 | acc = [] 169 | 170 | def train(epoch, adj_=None): 171 | t = time.time() 172 | model.train() 173 | optimizer.zero_grad() 174 | if adj_ is not None: 175 | output, _ = model(features, adj_) 176 | else: 177 | output, _ = model(features) 178 | loss_train = bce(output[idx_train], labels[idx_train]) 179 | loss_train.backward() 180 | optimizer.step() 181 | 182 | acc_train = accuracy(labels[idx_train], output[idx_train]) 183 | acc_test = accuracy(labels[idx_test], output[idx_test]) 184 | # if (epoch+1)%10 == 0: 185 | # print('Epoch: {:04d}'.format(epoch+1), 186 | # 'loss_train: {:.4f}'.format(loss_train.item()), 187 | # 'acc_train: {:.4f}'.format(acc_train.item()), 188 | # 'acc_test: {:.4f}'.format(acc_test.item()), 189 | # 'time: {:.4f}s'.format(time.time() - t)) 190 | 191 | return acc_test.item() 192 | 193 | for network in (models): 194 | 195 | # Model and optimizer 196 | model = network 197 | 198 | adj = None 199 | adj_norm = None 200 | 201 | if isinstance(model, (STGCN)): 202 | embedding = np.loadtxt(log_dir + '/decoding_set_transformer_' + 203 | data_name + '_trial_' + str(trial_i) + '_fold_' + str(fold_j), delimiter=",") 204 | 205 | A = kneighbors_graph(embedding, num_neib, mode='connectivity', include_self=True) 206 | G = nx.from_scipy_sparse_matrix(A) 207 | 208 | adj = nx.to_numpy_array(G) 209 | adj_norm = normalize(adj) 210 | adj_norm = torch.FloatTensor(adj_norm) 211 | 212 | optimizer = optim.Adam(model.parameters(), lr=lr_) 213 | 214 | # Train model 215 | t_total = time.time() 216 | for epoch in range(EPOCHS): 217 | if isinstance(model, SetTransformer): 218 | value = train(epoch) 219 | else: 220 | value = train(epoch, adj_norm) 221 | 222 | model.eval() 223 | with torch.no_grad(): 224 | if isinstance(model, SetTransformer): 225 | output, decoding = model(features) 226 | np.savetxt(log_dir + '/decoding_set_transformer_' + data_name + 227 | '_trial_' + str(trial_i) + '_fold_' + str(fold_j), decoding, delimiter=',') 228 | else: 229 | output, _ = model(features, adj_norm) 230 | 231 | acc.append(accuracy(labels[idx_test], output[idx_test])) 232 | model.apply(weights_init) 233 | 234 | adj = None 235 | adj_norm = None 236 | return acc 237 | 238 | tests = [] 239 | print('SetTransformer', 'STGCN') 240 | for i in range(num_trial): 241 | for j in range(num_fold): 242 | t = time.time() 243 | acc_all = get_performance_test(i+1, j+1) 244 | tests.append(acc_all) 245 | print('run : ' + str(num_fold*i+j+1) + ', accuracy : ' + 
str(np.array(acc_all)) + ', run time: {:.4f}s'.format(time.time() - t)) 246 | 247 | tests = np.array(tests) 248 | df_test = pd.DataFrame(tests, columns=['SetTransformer', 'STGCN']) 249 | 250 | print(df_test.mean(axis=0)) 251 | 252 | def get_performance_test_bayesian(trial_i, fold_j): 253 | torch.manual_seed(0) 254 | np.random.seed(0) 255 | 256 | y = y_ 257 | 258 | models = [ 259 | STGCN(in_features=230, num_heads=4, ln=False) 260 | ] 261 | 262 | def weights_init(m): 263 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 264 | torch.nn.init.xavier_uniform_(m.weight) 265 | if m.bias is not None: 266 | torch.nn.init.zeros_(m.bias) 267 | 268 | for model in models: 269 | model.apply(weights_init) 270 | 271 | # produce a split for training, validation and testing 272 | mat_fold = loadmat('data/fold_animal/' + str(trial_i) + '/index' + str(fold_j) + '.mat') # load mat-file 273 | 274 | idx_train = np.array(mat_fold['trainIndex']).flatten() - 1 # matlab index from 1, python from 0 275 | idx_test = np.array(mat_fold['testIndex']).flatten() - 1 276 | 277 | bce = nn.BCELoss() 278 | 279 | features = [torch.FloatTensor(set_) for set_ in set_list] 280 | labels = torch.FloatTensor(y) 281 | idx_train = torch.LongTensor(idx_train) 282 | idx_test = torch.LongTensor(idx_test) 283 | 284 | acc = [] 285 | 286 | def train_(epoch, adj_=None): 287 | t = time.time() 288 | model.train() 289 | optimizer.zero_grad() 290 | 291 | output, _ = model(features, adj_) 292 | 293 | loss_train = bce(output[idx_train], labels[idx_train]) 294 | loss_train.backward() 295 | optimizer.step() 296 | 297 | acc_train = accuracy(labels[idx_train], output[idx_train]) 298 | acc_test = accuracy(labels[idx_test], output[idx_test]) 299 | # if (epoch+1)%10 == 0: 300 | # print('Epoch: {:04d}'.format(epoch+1), 301 | # 'loss_train: {:.4f}'.format(loss_train.item()), 302 | # 'acc_train: {:.4f}'.format(acc_train.item()), 303 | # 'acc_test: {:.4f}'.format(acc_test.item()), 304 | # 'time: {:.4f}s'.format(time.time() - t)) 305 | 306 | return acc_test.item(), output 307 | 308 | for network in (models): 309 | 310 | # Model and optimizer 311 | model = network 312 | 313 | adj = None 314 | adj_norm = None 315 | 316 | if isinstance(model, (STGCN)): 317 | embedding = np.loadtxt(log_dir + '/decoding_set_transformer_' + 318 | data_name + '_trial_' + str(trial_i) + '_fold_' + str(fold_j), delimiter=",") 319 | 320 | adj_np = MAP_inference(embedding, num_neib, r) 321 | 322 | adj_np_norm = normalize(adj_np) 323 | adj_norm = adj_np_norm 324 | adj_norm = torch.FloatTensor(adj_norm) 325 | 326 | optimizer = optim.Adam(model.parameters(), lr=lr_) 327 | 328 | # Train model 329 | output_ = 0.0 330 | t_total = time.time() 331 | for epoch in range(EPOCHS): 332 | value, output = train_(epoch, adj_norm) 333 | 334 | if epoch >= EPOCHS - MC_samples: 335 | output_ += output 336 | 337 | output = output_ / np.float32(MC_samples) 338 | 339 | acc.append(accuracy(labels[idx_test], output[idx_test])) 340 | model.apply(weights_init) 341 | 342 | adj = None 343 | adj_norm = None 344 | return acc 345 | 346 | tests_bayesian = [] 347 | print('B-STGCN') 348 | for i in range(num_trial): 349 | for j in range(num_fold): 350 | t = time.time() 351 | acc_bayes = get_performance_test_bayesian(i+1, j+1) 352 | tests_bayesian.append(acc_bayes) 353 | print('run : ' + str(num_fold*i+j+1) + ', accuracy : ' + str(np.array(acc_bayes)) + ', run time: {:.4f}s'.format(time.time() - t)) 354 | 355 | tests_bayesian = np.array(tests_bayesian) 356 | df_test_bayesian = 
pd.DataFrame(tests_bayesian, columns=['B-STGCN']) 357 | 358 | df_concat = pd.concat([df_test, df_test_bayesian], axis=1) 359 | print(df_concat.mean(axis=0)) 360 | df_concat.to_csv('accuracy_' + data_name + '_set_transformer_r_' + str(r) + '.csv', index=False) 361 | 362 | -------------------------------------------------------------------------------- /mil_benchmark/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import matlib 3 | from scipy.stats import t 4 | import scipy.sparse as sp 5 | from math import sqrt 6 | from statistics import stdev 7 | 8 | 9 | def normalize(mx): 10 | """Row-normalize sparse matrix""" 11 | rowsum = np.array(mx.sum(1)) 12 | r_inv = np.power(rowsum, -1).flatten() 13 | r_inv[np.isinf(r_inv)] = 0. 14 | r_inv_sqrt = np.sqrt(r_inv) 15 | r_mat_inv_sqrt = np.diag(r_inv_sqrt) 16 | mx = r_mat_inv_sqrt.dot(mx) 17 | mx = mx.dot(r_mat_inv_sqrt) 18 | return mx 19 | 20 | 21 | def accuracy(labels, output): 22 | preds = (output > 0.5).type_as(labels) 23 | correct = preds.eq(labels).double() 24 | correct = correct.sum() 25 | return correct / len(labels) 26 | 27 | 28 | def compute_distance(embed): 29 | N = embed.shape[0] 30 | p = np.dot(embed, np.transpose(embed)) 31 | q = np.matlib.repmat(np.diag(p), N, 1) 32 | dist = q + np.transpose(q) - 2 * p 33 | dist[dist < 1e-8] = 1e-8 34 | return dist 35 | 36 | 37 | def estimate_graph(gamma, epsilon, dist, max_iter, k, r): 38 | np.random.seed(0) 39 | 40 | N = dist.shape[0] 41 | dist += 1e10 * np.eye(N) 42 | 43 | deg_exp = np.minimum(int(N-1), int(k * r)) 44 | 45 | dist_sort_col_idx = np.argsort(dist, axis=0) 46 | dist_sort_col_idx = np.transpose(dist_sort_col_idx[0:deg_exp, :]) 47 | 48 | dist_sort_row_idx = np.matlib.repmat(np.arange(N).reshape(N, 1), 1, deg_exp) 49 | 50 | dist_sort_col_idx = np.reshape(dist_sort_col_idx, int(N * deg_exp)).astype(int) 51 | dist_sort_row_idx = np.reshape(dist_sort_row_idx, int(N * deg_exp)).astype(int) 52 | 53 | dist_idx = np.zeros((int(N * deg_exp), 2)).astype(int) 54 | dist_idx[:, 0] = dist_sort_col_idx 55 | dist_idx[:, 1] = dist_sort_row_idx 56 | dist_idx = np.sort(dist_idx, axis=1) 57 | dist_idx = np.unique(dist_idx, axis=0) 58 | dist_sort_col_idx = dist_idx[:, 0] 59 | dist_sort_row_idx = dist_idx[:, 1] 60 | 61 | num_edges = len(dist_sort_col_idx) 62 | 63 | w_init = np.random.uniform(0, 1, size=(num_edges, 1)) 64 | d_init = k * np.random.uniform(0, 1, size=(N, 1)) 65 | 66 | w_current = w_init 67 | d_current = d_init 68 | 69 | dist_sorted = np.sort(dist, axis=0) 70 | 71 | B_k = np.sum(dist_sorted[0:k, :], axis=0) 72 | dist_sorted_k = dist_sorted[k-1, :] 73 | dist_sorted_k_plus_1 = dist_sorted[k, :] 74 | 75 | theta_lb = 1 / np.sqrt(k * dist_sorted_k_plus_1 ** 2 - B_k * dist_sorted_k_plus_1) 76 | theta_lb = theta_lb[~np.isnan(theta_lb)] 77 | theta_lb = theta_lb[~np.isinf(theta_lb)] 78 | theta_lb = np.mean(theta_lb) 79 | 80 | theta_ub = 1 / np.sqrt(k * dist_sorted_k ** 2 - B_k * dist_sorted_k) 81 | theta_ub = theta_ub[~np.isnan(theta_ub)] 82 | theta_ub = theta_ub[~np.isinf(theta_ub)] 83 | if len(theta_ub) > 0: 84 | theta_ub = np.mean(theta_ub) 85 | else: 86 | theta_ub = theta_lb 87 | 88 | theta = (theta_lb + theta_ub) / 2 89 | 90 | dist = theta * dist 91 | 92 | z = dist[dist_sort_row_idx, dist_sort_col_idx] 93 | z.shape = (num_edges, 1) 94 | 95 | for iter in range(max_iter): 96 | 97 | # print('Graph inference epoch : ' + str(iter)) 98 | 99 | St_times_d = d_current[dist_sort_row_idx] + d_current[dist_sort_col_idx] 100 | y_current 
= w_current - gamma * (2 * w_current + St_times_d) 101 | 102 | adj_current = np.zeros((N, N)) 103 | adj_current[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(w_current) 104 | adj_current = adj_current + np.transpose(adj_current) 105 | S_times_w = np.sum(adj_current, axis=1) 106 | S_times_w.shape = (N, 1) 107 | y_bar_current = d_current + gamma * S_times_w 108 | 109 | p_current = np.maximum(0, np.abs(y_current) - 2 * gamma * z) 110 | p_bar_current = (y_bar_current - np.sqrt(y_bar_current * y_bar_current + 4 * gamma)) / 2 111 | 112 | St_times_p_bar = p_bar_current[dist_sort_row_idx] + p_bar_current[dist_sort_col_idx] 113 | q_current = p_current - gamma * (2 * p_current + St_times_p_bar) 114 | 115 | p_matrix_current = np.zeros((N, N)) 116 | p_matrix_current[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(p_current) 117 | p_matrix_current = p_matrix_current + np.transpose(p_matrix_current) 118 | S_times_p = np.sum(p_matrix_current, axis=1) 119 | S_times_p.shape = (N, 1) 120 | q_bar_current = p_bar_current + gamma * S_times_p 121 | 122 | w_updated = np.abs(w_current - y_current + q_current) 123 | d_updated = np.abs(d_current - y_bar_current + q_bar_current) 124 | 125 | if (np.linalg.norm(w_updated - w_current) / np.linalg.norm(w_current) < epsilon) and \ 126 | (np.linalg.norm(d_updated - d_current) / np.linalg.norm(d_current) < epsilon): 127 | break 128 | else: 129 | w_current = w_updated 130 | d_current = d_updated 131 | 132 | upper_tri_index = np.triu_indices(N, k=1) 133 | 134 | z = dist[upper_tri_index[0], upper_tri_index[1]] 135 | z.shape = (int(N * (N - 1) / 2), 1) 136 | z = z * np.max(w_current) 137 | 138 | w_current = w_current / np.max(w_current) 139 | 140 | inferred_graph = np.zeros((N, N)) 141 | inferred_graph[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(w_current) 142 | inferred_graph = inferred_graph + np.transpose(inferred_graph) + np.eye(N) 143 | 144 | return inferred_graph 145 | 146 | 147 | def MAP_inference(x, num_neib, r): 148 | N = x.shape[0] 149 | k = int(num_neib) 150 | 151 | dist = compute_distance(x) 152 | 153 | inferred_graph = estimate_graph(0.01, 0.001, dist, 1000, k, r) 154 | 155 | return inferred_graph 156 | -------------------------------------------------------------------------------- /mil_election_classification/accuracy_deepset_num_neib_0_r_0_vanilla.csv: -------------------------------------------------------------------------------- 1 | DeepSet 2 | 0.7528795811518325 3 | 0.7528795811518325 4 | 0.7549738219895288 5 | 0.7078534031413612 6 | 0.6732984293193718 7 | 0.7476439790575916 8 | 0.7465968586387435 9 | 0.6984293193717277 10 | 0.6387434554973822 11 | 0.7193717277486911 12 | 0.762303664921466 13 | 0.7455497382198953 14 | 0.7057591623036649 15 | 0.7602094240837697 16 | 0.7382198952879581 17 | 0.743455497382199 18 | 0.7329842931937173 19 | 0.7455497382198953 20 | 0.7476439790575916 21 | 0.7560209424083769 22 | 0.694240837696335 23 | 0.6952879581151833 24 | 0.7267015706806282 25 | 0.7151832460732984 26 | 0.7403141361256544 27 | 0.7591623036649214 28 | 0.7382198952879581 29 | 0.7717277486910995 30 | 0.7392670157068063 31 | 0.6963350785340314 32 | 0.650261780104712 33 | 0.6931937172774869 34 | 0.7392670157068063 35 | 0.7675392670157068 36 | 0.6575916230366492 37 | 0.7204188481675393 38 | 0.6890052356020943 39 | 0.7099476439790576 40 | 0.6554973821989529 41 | 0.749738219895288 42 | 0.7706806282722513 43 | 0.7643979057591623 44 | 0.7225130890052356 45 | 0.7518324607329843 46 | 0.7361256544502618 47 | 0.7424083769633508 48 | 0.7518324607329843 49 | 
0.7717277486910995 50 | 0.669109947643979 51 | 0.762303664921466 52 | 0.6994764397905759 53 | 0.7392670157068063 54 | 0.7465968586387435 55 | 0.7706806282722513 56 | 0.7361256544502618 57 | 0.7507853403141361 58 | 0.743455497382199 59 | 0.7570680628272252 60 | 0.7445026178010471 61 | 0.7120418848167539 62 | 0.6984293193717277 63 | 0.7455497382198953 64 | 0.749738219895288 65 | 0.7769633507853403 66 | 0.7225130890052356 67 | 0.7654450261780105 68 | 0.6816753926701571 69 | 0.7204188481675393 70 | 0.7539267015706806 71 | 0.7403141361256544 72 | 0.7570680628272252 73 | 0.6701570680628273 74 | 0.6167539267015707 75 | 0.7528795811518325 76 | 0.7445026178010471 77 | 0.7664921465968586 78 | 0.7591623036649214 79 | 0.7717277486910995 80 | 0.7308900523560209 81 | 0.7643979057591623 82 | 0.7403141361256544 83 | 0.6921465968586388 84 | 0.7350785340314137 85 | 0.7413612565445026 86 | 0.7465968586387435 87 | 0.7654450261780105 88 | 0.7225130890052356 89 | 0.7089005235602094 90 | 0.7581151832460733 91 | 0.7005235602094241 92 | 0.7277486910994765 93 | 0.7361256544502618 94 | 0.7382198952879581 95 | 0.7392670157068063 96 | 0.743455497382199 97 | 0.7329842931937173 98 | 0.7518324607329843 99 | 0.7392670157068063 100 | 0.7675392670157068 101 | 0.7581151832460733 102 | -------------------------------------------------------------------------------- /mil_election_classification/accuracy_deepset_num_neib_5_r_0_GCN.csv: -------------------------------------------------------------------------------- 1 | DSGCN 2 | 0.7581151832460733 3 | 0.7643979057591623 4 | 0.7560209424083769 5 | 0.7382198952879581 6 | 0.7706806282722513 7 | 0.7675392670157068 8 | 0.7780104712041885 9 | 0.6534031413612565 10 | 0.7654450261780105 11 | 0.724607329842932 12 | 0.7916230366492146 13 | 0.6869109947643979 14 | 0.6816753926701571 15 | 0.7811518324607329 16 | 0.7696335078534031 17 | 0.762303664921466 18 | 0.7591623036649214 19 | 0.7727748691099476 20 | 0.774869109947644 21 | 0.7162303664921466 22 | 0.7486910994764397 23 | 0.643979057591623 24 | 0.5738219895287958 25 | 0.669109947643979 26 | 0.6848167539267016 27 | 0.7486910994764397 28 | 0.7549738219895288 29 | 0.768586387434555 30 | 0.7842931937172775 31 | 0.7633507853403141 32 | 0.7130890052356021 33 | 0.7738219895287958 34 | 0.7560209424083769 35 | 0.7382198952879581 36 | 0.7078534031413612 37 | 0.7654450261780105 38 | 0.5842931937172775 39 | 0.7581151832460733 40 | 0.6544502617801047 41 | 0.7696335078534031 42 | 0.7109947643979058 43 | 0.7664921465968586 44 | 0.7926701570680629 45 | 0.7853403141361257 46 | 0.7738219895287958 47 | 0.7947643979057591 48 | 0.7738219895287958 49 | 0.7696335078534031 50 | 0.7151832460732984 51 | 0.7633507853403141 52 | 0.6282722513089005 53 | 0.7403141361256544 54 | 0.7350785340314137 55 | 0.7371727748691099 56 | 0.7842931937172775 57 | 0.7193717277486911 58 | 0.7581151832460733 59 | 0.7727748691099476 60 | 0.7675392670157068 61 | 0.7287958115183246 62 | 0.7842931937172775 63 | 0.7340314136125654 64 | 0.7361256544502618 65 | 0.7005235602094241 66 | 0.7329842931937173 67 | 0.7801047120418848 68 | 0.7267015706806282 69 | 0.7445026178010471 70 | 0.7706806282722513 71 | 0.7392670157068063 72 | 0.7675392670157068 73 | 0.694240837696335 74 | 0.6712041884816754 75 | 0.7214659685863875 76 | 0.7979057591623037 77 | 0.7727748691099476 78 | 0.7549738219895288 79 | 0.7801047120418848 80 | 0.7738219895287958 81 | 0.7549738219895288 82 | 0.7162303664921466 83 | 0.7183246073298429 84 | 0.6910994764397905 85 | 0.787434554973822 86 | 0.7424083769633508 87 | 
0.7916230366492146 88 | 0.7633507853403141 89 | 0.7769633507853403 90 | 0.768586387434555 91 | 0.7654450261780105 92 | 0.7382198952879581 93 | 0.7717277486910995 94 | 0.7413612565445026 95 | 0.7518324607329843 96 | 0.650261780104712 97 | 0.7518324607329843 98 | 0.7905759162303665 99 | 0.6649214659685864 100 | 0.6261780104712041 101 | 0.7507853403141361 102 | -------------------------------------------------------------------------------- /mil_election_classification/accuracy_deepset_num_neib_5_r_1_BGCN.csv: -------------------------------------------------------------------------------- 1 | B-DSGCN 2 | 0.7675392670157068 3 | 0.749738219895288 4 | 0.762303664921466 5 | 0.7350785340314137 6 | 0.7151832460732984 7 | 0.762303664921466 8 | 0.7675392670157068 9 | 0.7267015706806282 10 | 0.6963350785340314 11 | 0.7298429319371728 12 | 0.768586387434555 13 | 0.7549738219895288 14 | 0.7738219895287958 15 | 0.7633507853403141 16 | 0.7371727748691099 17 | 0.7570680628272252 18 | 0.7549738219895288 19 | 0.7602094240837697 20 | 0.7821989528795812 21 | 0.7539267015706806 22 | 0.7413612565445026 23 | 0.7403141361256544 24 | 0.6869109947643979 25 | 0.7706806282722513 26 | 0.7717277486910995 27 | 0.7486910994764397 28 | 0.7267015706806282 29 | 0.7717277486910995 30 | 0.7780104712041885 31 | 0.7413612565445026 32 | 0.7392670157068063 33 | 0.7539267015706806 34 | 0.7340314136125654 35 | 0.7643979057591623 36 | 0.6973821989528796 37 | 0.7581151832460733 38 | 0.6638743455497382 39 | 0.7256544502617801 40 | 0.6617801047120419 41 | 0.762303664921466 42 | 0.7518324607329843 43 | 0.7926701570680629 44 | 0.7057591623036649 45 | 0.7602094240837697 46 | 0.7769633507853403 47 | 0.7675392670157068 48 | 0.7560209424083769 49 | 0.7727748691099476 50 | 0.7183246073298429 51 | 0.7591623036649214 52 | 0.6837696335078534 53 | 0.7392670157068063 54 | 0.7382198952879581 55 | 0.7549738219895288 56 | 0.7696335078534031 57 | 0.7581151832460733 58 | 0.749738219895288 59 | 0.743455497382199 60 | 0.7612565445026178 61 | 0.7120418848167539 62 | 0.7057591623036649 63 | 0.7298429319371728 64 | 0.7225130890052356 65 | 0.7256544502617801 66 | 0.6973821989528796 67 | 0.7675392670157068 68 | 0.7078534031413612 69 | 0.7371727748691099 70 | 0.7643979057591623 71 | 0.7424083769633508 72 | 0.7675392670157068 73 | 0.7225130890052356 74 | 0.669109947643979 75 | 0.7602094240837697 76 | 0.7099476439790576 77 | 0.768586387434555 78 | 0.7602094240837697 79 | 0.7769633507853403 80 | 0.7507853403141361 81 | 0.762303664921466 82 | 0.7350785340314137 83 | 0.5916230366492147 84 | 0.7528795811518325 85 | 0.7612565445026178 86 | 0.7633507853403141 87 | 0.7549738219895288 88 | 0.7696335078534031 89 | 0.7549738219895288 90 | 0.762303664921466 91 | 0.7141361256544503 92 | 0.7769633507853403 93 | 0.7539267015706806 94 | 0.7298429319371728 95 | 0.7455497382198953 96 | 0.7308900523560209 97 | 0.7287958115183246 98 | 0.7842931937172775 99 | 0.7539267015706806 100 | 0.675392670157068 101 | 0.7706806282722513 102 | -------------------------------------------------------------------------------- /mil_election_classification/data_cleaning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pickle as pk 4 | import glob 5 | import os 6 | 7 | county_list = os.listdir('data/set_features') 8 | print(county_list) 9 | 10 | to_remove = [] 11 | 12 | for id, file_name_ in enumerate(county_list[0:1]): 13 | print(id) 14 | 15 | s = pd.read_parquet('data/set_features/' + file_name_, 
engine='pyarrow') 16 | # for i in s.index: 17 | # print(s.loc[i].tolist()) 18 | nan_list = s.columns[s.isna().any()].tolist() 19 | str_list = [] 20 | for c in s.columns: 21 | # if s[c].dtype == object: 22 | # print('damn') 23 | # print(c) 24 | # print(s[c]) 25 | if isinstance(s.iloc[0][c], str): 26 | str_list.append(c) 27 | 28 | non_numer_list = [] 29 | for c in s.columns: 30 | if not pd.to_numeric(s[c], errors='coerce').notnull().all(): 31 | non_numer_list.append(c) 32 | 33 | # print(len(nan_list)) 34 | # print(len(str_list)) 35 | # print(len(non_numer_list)) 36 | 37 | to_remove = list(set.union(set(to_remove), set(nan_list), set(str_list), set(non_numer_list))) 38 | 39 | 40 | print(to_remove) 41 | print(len(to_remove)) 42 | 43 | for file_name_ in county_list: 44 | s = pd.read_parquet('data/set_features/' + file_name_, engine='pyarrow') 45 | s = s.drop(columns=to_remove).to_numpy() 46 | s = np.float32(s) 47 | with open('data/cleaned_features/' + file_name_[:-3] + '.pkl', 'wb') as f: 48 | pk.dump(s, f) 49 | -------------------------------------------------------------------------------- /mil_election_classification/deepset_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np # linear algebra 2 | import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) 3 | import matplotlib.pyplot as plt 4 | from numpy import matlib 5 | import seaborn as sns 6 | from scipy import stats 7 | from sklearn.preprocessing import StandardScaler 8 | from torch import optim 9 | import torch.nn as nn 10 | import matplotlib as mpl 11 | import networkx as nx 12 | import scipy.sparse as sp 13 | import pickle as pk 14 | import time 15 | import torch 16 | import sys 17 | import glob 18 | import os 19 | from models_all_deepset import * 20 | from sklearn.neighbors import kneighbors_graph 21 | from scipy.stats import wilcoxon 22 | from utils import * 23 | from sklearn import preprocessing 24 | from subprocess import call 25 | mpl.rcParams['figure.dpi'] = 600 26 | color = sns.color_palette() 27 | pd.options.mode.chained_assignment = None # default='warn' 28 | 29 | 30 | def run_deepset(num_neib, r, alg_name='vanilla'): 31 | num_trials = 100 32 | lr_ = 1e-3 33 | weight_decay_ = 1e-4 34 | EPOCHS = 200 35 | MC_samples = 5 36 | 37 | log_dir = 'log_ds' 38 | 39 | check = os.path.isdir(log_dir) 40 | if not check: 41 | os.makedirs(log_dir) 42 | print("created folder : ", log_dir) 43 | else: 44 | print(log_dir, " folder already exists.") 45 | 46 | file_name = 'data/centroids_cartesian_10.csv' 47 | location = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1, usecols=range(1, 3)) 48 | print(location.shape) 49 | 50 | file_name = 'data/results-2016-election.csv' 51 | votes = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1, usecols=range(1, 4)) 52 | votes = votes[:, [0, 2]] # keeping republican and democrat 53 | votes = votes[:, [0]] / votes.sum(axis=1, keepdims=True) # republican vote probability 54 | print(votes.shape) 55 | 56 | county_list = os.listdir('data/cleaned_features') 57 | # print(county_list) 58 | 59 | set_list = [] 60 | 61 | for file_name_ in county_list: 62 | with open('data/cleaned_features/' + file_name_, 'rb') as f: 63 | set_ = pk.load(f) 64 | set_list.append(set_) 65 | 66 | def sample_dataset(n=100, seed=0): 67 | np.random.seed(seed) 68 | 69 | X = [] 70 | for set_ in set_list: 71 | row_id = np.random.choice(set_.shape[0], n, replace=False) 72 | X.append(set_[row_id, :]) 73 | 74 | X = np.array(X) 75 | X = (X - np.mean(X, 
axis=(0, 1), keepdims=True))/np.std(X, axis=(0, 1), keepdims=True) 76 | return X 77 | 78 | def get_performance_test(trial_i): 79 | torch.manual_seed(trial_i) 80 | np.random.seed(trial_i) 81 | 82 | indices = np.arange(0, votes.shape[0]) 83 | np.random.shuffle(indices) 84 | num_train = int(0.025 * votes.shape[0]) 85 | idx_train = indices[:num_train] 86 | idx_test = indices[num_train:] 87 | np.savetxt(log_dir + '/idx_test' + '_trial_' + str(trial_i) + '.txt', idx_test, delimiter=',') 88 | 89 | X = sample_dataset(n=100, seed=trial_i) 90 | X = np.array(X, dtype='float') 91 | 92 | features = torch.FloatTensor(X) 93 | labels = torch.FloatTensor(votes) 94 | 95 | acc = [] 96 | 97 | if alg_name == 'vanilla': 98 | models = [ 99 | DeepSet(in_features=94) 100 | ] 101 | elif alg_name == 'GCN': 102 | models = [ 103 | DSGCN(in_features=94) 104 | ] 105 | 106 | A = kneighbors_graph(location, num_neib, mode='connectivity', include_self=True) 107 | G = nx.from_scipy_sparse_matrix(A) 108 | adj = nx.to_numpy_array(G) 109 | adj_norm = normalize(adj) 110 | adj_norm = torch.FloatTensor(adj_norm) 111 | 112 | def weight_reset(m): 113 | reset_parameters = getattr(m, "reset_parameters", None) 114 | if callable(reset_parameters): 115 | m.reset_parameters() 116 | 117 | for model in models: 118 | model.apply(weight_reset) 119 | 120 | bce = nn.BCELoss() 121 | 122 | def train(epoch, adj_=None): 123 | t = time.time() 124 | model.train() 125 | optimizer.zero_grad() 126 | if adj_ is not None: 127 | output, _ = model(features, adj_) 128 | else: 129 | output, _ = model(features) 130 | loss_train = bce(output[idx_train], labels[idx_train]) 131 | loss_train.backward() 132 | optimizer.step() 133 | 134 | loss_test = bce(output[idx_test], labels[idx_test]) 135 | acc_train = accuracy(labels[idx_train], output[idx_train]) 136 | acc_test = accuracy(labels[idx_test], output[idx_test]) 137 | # if (epoch+1) % 10 == 0: 138 | # print('Epoch: {:04d}'.format(epoch+1), 139 | # 'loss_train: {:.4f}'.format(loss_train.item()), 140 | # 'loss_test: {:.4f}'.format(loss_test.item()), 141 | # 'acc_train: {:.4f}'.format(acc_train.item()), 142 | # 'acc_test: {:.4f}'.format(acc_test.item()), 143 | # 'time: {:.4f}s'.format(time.time() - t)) 144 | return acc_test.item() 145 | 146 | for network in (models): 147 | 148 | t_total = time.time() 149 | # Model and optimizer 150 | model = network 151 | 152 | no_decay = list() 153 | decay = list() 154 | for m in model.modules(): 155 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 156 | decay.append(m.weight) 157 | no_decay.append(m.bias) 158 | 159 | optimizer = optim.Adam([{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': weight_decay_}], lr=lr_) 160 | 161 | # Train model 162 | t_total = time.time() 163 | for epoch in range(EPOCHS): 164 | if isinstance(model, DeepSet): 165 | value = train(epoch) 166 | else: 167 | value = train(epoch, adj_norm) 168 | 169 | model.eval() 170 | with torch.no_grad(): 171 | if isinstance(model, (DeepSet)): 172 | output, decoding = model(features) 173 | np.savetxt(log_dir + '/decoding_deepset' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', decoding, delimiter=',') 174 | np.savetxt(log_dir + '/output_deepset' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', output, delimiter=',') 175 | 176 | if isinstance(model, (DSGCN)): 177 | output, decoding = model(features, adj_norm) 178 | np.savetxt(log_dir + '/decoding_deepset_gcn' + '_trial_' + str(trial_i) + '_num_neib_' + 
str(num_neib) + '_r_' + str(r) + '.txt', decoding, delimiter=',') 179 | np.savetxt(log_dir + '/output_deepset_gcn' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', output, delimiter=',') 180 | 181 | acc.append(accuracy(labels[idx_test], output[idx_test])) 182 | model.apply(weight_reset) 183 | return acc 184 | 185 | if alg_name == 'vanilla' or alg_name == 'GCN': 186 | tests = [] 187 | if alg_name == 'vanilla': 188 | print('DeepSet') 189 | elif alg_name == 'GCN': 190 | print('DSGCN') 191 | for i in range(num_trials): 192 | t = time.time() 193 | acc_all = get_performance_test(i) 194 | tests.append(acc_all) 195 | print('run : ' + str(i+1) + ', accuracy : ' + str(np.array(acc_all)) + ', run time: {:.2f}s'.format(time.time() - t)) 196 | 197 | tests = np.array(tests) 198 | 199 | if alg_name == 'vanilla': 200 | df_test = pd.DataFrame(tests, columns=['DeepSet']) 201 | elif alg_name == 'GCN': 202 | df_test = pd.DataFrame(tests, columns=['DSGCN']) 203 | 204 | print(df_test.mean(axis=0)) 205 | 206 | def get_performance_test_bayesian(trial_i): 207 | torch.manual_seed(trial_i) 208 | np.random.seed(trial_i) 209 | 210 | indices = np.arange(0, votes.shape[0]) 211 | np.random.shuffle(indices) 212 | num_train = int(0.025 * votes.shape[0]) 213 | idx_train = indices[:num_train] 214 | idx_test = indices[num_train:] 215 | 216 | X = sample_dataset(n=100, seed=trial_i) 217 | X = np.array(X, dtype='float') 218 | 219 | features = torch.FloatTensor(X) 220 | labels = torch.FloatTensor(votes) 221 | 222 | acc = [] 223 | 224 | if alg_name == 'BGCN': 225 | models = [ 226 | DSGCN(in_features=94) 227 | ] 228 | 229 | def weight_reset(m): 230 | reset_parameters = getattr(m, "reset_parameters", None) 231 | if callable(reset_parameters): 232 | m.reset_parameters() 233 | 234 | for model in models: 235 | model.apply(weight_reset) 236 | 237 | bce = nn.BCELoss() 238 | 239 | def train_(epoch, adj_=None): 240 | t = time.time() 241 | model.train() 242 | optimizer.zero_grad() 243 | if adj_ is not None: 244 | output, _ = model(features, adj_) 245 | else: 246 | output, _ = model(features) 247 | loss_train = bce(output[idx_train], labels[idx_train]) 248 | loss_train.backward() 249 | optimizer.step() 250 | 251 | loss_test = bce(output[idx_test], labels[idx_test]) 252 | acc_train = accuracy(labels[idx_train], output[idx_train]) 253 | acc_test = accuracy(labels[idx_test], output[idx_test]) 254 | return acc_test.item(), output 255 | 256 | for network in (models): 257 | 258 | t_total = time.time() 259 | # Model and optimizer 260 | model = network 261 | 262 | if isinstance(model, (DSGCN)): 263 | embedding = np.loadtxt(log_dir + '/decoding_deepset_gcn' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_0.txt', delimiter=",") 264 | adj_np = MAP_inference(embedding, num_neib, r) 265 | adj_np_norm = normalize(adj_np) 266 | adj_norm = torch.FloatTensor(adj_np_norm) 267 | 268 | no_decay = list() 269 | decay = list() 270 | for m in model.modules(): 271 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 272 | decay.append(m.weight) 273 | no_decay.append(m.bias) 274 | 275 | optimizer = optim.Adam([{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': weight_decay_}], lr=lr_) 276 | 277 | # Train model 278 | t_total = time.time() 279 | output_ = 0.0 280 | for epoch in range(EPOCHS): 281 | if isinstance(model, DeepSet): 282 | value, output = train_(epoch) 283 | else: 284 | value, output = train_(epoch, adj_norm) 285 | 286 | if epoch >= EPOCHS - MC_samples: 287 | 
output_ += output 288 | 289 | output = output_ / np.float32(MC_samples) 290 | np.savetxt(log_dir + '/output_deepset_bgcn' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', output.detach().numpy(), delimiter=',') 291 | 292 | acc.append(accuracy(labels[idx_test], output[idx_test])) 293 | model.apply(weight_reset) 294 | adj_norm = None 295 | return acc 296 | 297 | if alg_name == 'BGCN': 298 | tests_bayesian = [] 299 | print('B-DSGCN') 300 | for i in range(num_trials): 301 | t = time.time() 302 | acc_bayes = get_performance_test_bayesian(i) 303 | tests_bayesian.append(acc_bayes) 304 | print('run : ' + str(i+1) + ', accuracy : ' + str(np.array(acc_bayes)) + ', run time: {:.2f}s'.format(time.time() - t)) 305 | 306 | tests_bayesian = np.array(tests_bayesian) 307 | 308 | df_test_bayesian = pd.DataFrame(tests_bayesian, columns=['B-DSGCN']) 309 | 310 | print(df_test_bayesian.mean(axis=0)) 311 | 312 | if alg_name == 'BGCN': 313 | df_test_bayesian.to_csv('accuracy_deepset_num_neib_' + str(num_neib) + '_r_' + str(r) + '_' + alg_name + '.csv', index=False) 314 | 315 | else: 316 | df_test.to_csv('accuracy_deepset_num_neib_' + str(num_neib) + '_r_' + str(r) + '_' + alg_name + '.csv', index=False) 317 | -------------------------------------------------------------------------------- /mil_election_classification/election.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/networkslab/BagGraph/5ecbc365aeecd6e495a500cea11deb2412f12311/mil_election_classification/election.pdf -------------------------------------------------------------------------------- /mil_election_classification/evaluate.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | from scipy.stats import wilcoxon 4 | 5 | alg_name = 'deepset' 6 | 7 | k_gcn_list = [5] 8 | k_bgcn_list = [5] 9 | r_bgcn_list = [1] 10 | 11 | prnt_str_head = '' 12 | prnt_str = '' 13 | 14 | alg_all = ['vanilla'] 15 | k_all = [0] 16 | r_all = [0] 17 | 18 | for k in k_gcn_list: 19 | alg_all.append('GCN') 20 | k_all.append(k) 21 | r_all.append(0) 22 | 23 | for k_ in k_bgcn_list: 24 | for r_ in r_bgcn_list: 25 | alg_all.append('BGCN') 26 | k_all.append(k_) 27 | r_all.append(r_) 28 | 29 | acc_all = [] 30 | 31 | for idx, alg_ in enumerate(alg_all): 32 | 33 | file_name = 'accuracy_' + alg_name + '_num_neib_' + str(k_all[idx]) + '_r_' + str(r_all[idx]) + '_' + alg_ + '.csv' 34 | 35 | raw_acc = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1) 36 | acc_all.append(raw_acc) 37 | 38 | mu_ = np.mean(raw_acc) * 100 39 | sigma_ = np.std(raw_acc) * 100 40 | new_str = '&' + "{:.2f}".format(mu_) + '$\pm$' + "{:.2f}".format(sigma_) + ' ' 41 | prnt_str = prnt_str + new_str.rjust(15) 42 | 43 | new_str_ = alg_ + '_k_' + str(k_all[idx]) + '_r_' + str(r_all[idx]) 44 | prnt_str_head = prnt_str_head + new_str_.rjust(15) + ' ' 45 | 46 | print(prnt_str_head) 47 | print(prnt_str) 48 | 49 | 50 | print('----------------statistical test-----------') 51 | _, p = wilcoxon(acc_all[0], acc_all[1], zero_method='wilcox', correction=False) 52 | print('vanilla vs GCN') 53 | print(p) 54 | _, p = wilcoxon(acc_all[0], acc_all[2], zero_method='wilcox', correction=False) 55 | print('vanilla vs BGCN') 56 | print(p) 57 | _, p = wilcoxon(acc_all[1], acc_all[2], zero_method='wilcox', correction=False) 58 | print('GCN vs BGCN') 59 | print(p) 60 | -------------------------------------------------------------------------------- 
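The next script, `evaluate_nd.py`, scores each trial by normalized deviation (ND): the mean absolute error over the test counties divided by the mean absolute prediction, times 100. A minimal sketch of the metric on dummy arrays (values are illustrative only):

```python
# ND = 100 * mean(|y_true - y_pred|) / mean(|y_pred|), computed on test indices.
import numpy as np

y_true = np.array([0.70, 0.55, 0.62])  # dummy republican vote shares
y_pred = np.array([0.68, 0.60, 0.59])  # dummy model outputs
nd = np.mean(np.abs(y_true - y_pred)) / np.mean(np.abs(y_pred)) * 100
print(nd)  # ~5.3
```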
/mil_election_classification/evaluate_nd.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import numpy as np
3 | from scipy.stats import wilcoxon
4 | 
5 | alg_name = 'deepset'
6 | log_dir = 'log_ds'
7 | num_trials = 100
8 | 
9 | file_name = 'data/results-2016-election.csv'
10 | votes = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1, usecols=range(1, 4))
11 | votes = votes[:, [0, 2]]  # keeping republican and democrat
12 | votes = votes[:, [0]] / votes.sum(axis=1, keepdims=True)  # republican vote probability
13 | votes = np.squeeze(votes)
14 | 
15 | k_gcn_list = [5]
16 | k_bgcn_list = [5]
17 | r_bgcn_list = [1]
18 | 
19 | alg_all = ['Vanilla']
20 | k_all = [0]
21 | r_all = [0]
22 | 
23 | for k in k_gcn_list:
24 |     alg_all.append('GCN')
25 |     k_all.append(k)
26 |     r_all.append(0)
27 | 
28 | for k_ in k_bgcn_list:
29 |     for r_ in r_bgcn_list:
30 |         alg_all.append('BGCN')
31 |         k_all.append(k_)
32 |         r_all.append(r_)
33 | 
34 | nd_all = np.zeros([num_trials, len(alg_all)])  # rows: trials, columns: algorithms
35 | 
36 | for idx, alg_ in enumerate(alg_all):
37 | 
38 |     for i in range(num_trials):
39 | 
40 |         idx_test = np.squeeze(np.loadtxt(log_dir + '/idx_test_trial_' + str(i) + '.txt')).astype(int)
41 | 
42 |         if idx == 0:
43 |             filename = log_dir + '/output_' + alg_name + '_trial_' + str(i) + '_num_neib_' + str(k_all[idx]) + '_r_' + str(r_all[idx]) + '.txt'
44 |         else:
45 |             filename = log_dir + '/output_' + alg_name + '_' + alg_.lower() + '_trial_' + str(i) + '_num_neib_' + str(k_all[idx]) + '_r_' + str(r_all[idx]) + '.txt'
46 | 
47 |         output = np.squeeze(np.loadtxt(filename))
48 | 
49 |         err = np.abs(votes - output)
50 |         nd_all[i, idx] = np.mean(err[idx_test]) / np.mean(np.abs(output[idx_test])) * 100
51 | 
52 | 
53 | print(nd_all)
54 | 
55 | print('mean ND')
56 | print(nd_all.mean(axis=0))
57 | print('std. dev. of ND')
58 | print(nd_all.std(axis=0))
59 | 
60 | print('----------------statistical test-----------')
61 | # compare algorithms column-wise: one ND value per trial for each algorithm
62 | _, p = wilcoxon(nd_all[:, 0], nd_all[:, 1], zero_method='wilcox', correction=False)
63 | print('vanilla vs GCN')
64 | print(p)
65 | _, p = wilcoxon(nd_all[:, 0], nd_all[:, 2], zero_method='wilcox', correction=False)
66 | print('vanilla vs BGCN')
67 | print(p)
68 | _, p = wilcoxon(nd_all[:, 1], nd_all[:, 2], zero_method='wilcox', correction=False)
69 | print('GCN vs BGCN')
70 | print(p)
71 | 
--------------------------------------------------------------------------------
/mil_election_classification/models_all_deepset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import scipy.sparse as sp
6 | from torch.nn.parameter import Parameter
7 | from torch.nn.modules.module import Module
8 | from modules import *
9 | 
10 | 
11 | # GCN model
12 | class GraphConvolution(Module):
13 |     """
14 |     Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
15 |     """
16 | 
17 |     def __init__(self, in_features, out_features, bias=True):
18 |         super(GraphConvolution, self).__init__()
19 |         self.in_features = in_features
20 |         self.out_features = out_features
21 |         self.weight = Parameter(torch.FloatTensor(in_features, out_features))
22 |         if bias:
23 |             self.bias = Parameter(torch.FloatTensor(out_features))
24 |         else:
25 |             self.register_parameter('bias', None)
26 |         self.reset_parameters()
27 | 
28 |     def reset_parameters(self):
29 |         stdv = 1. 
/ math.sqrt(self.in_features) 30 | self.weight.data.uniform_(-stdv, stdv) 31 | if self.bias is not None: 32 | self.bias.data.uniform_(-stdv, stdv) 33 | 34 | def forward(self, input, adj): 35 | support = torch.mm(input, self.weight) 36 | output = torch.mm(adj, support) 37 | if self.bias is not None: 38 | return output + self.bias 39 | else: 40 | return output 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ + ' (' \ 44 | + str(self.in_features) + ' -> ' \ 45 | + str(self.out_features) + ')' 46 | 47 | 48 | #Deep Set Model 49 | class DeepSet(nn.Module): 50 | 51 | def __init__(self, in_features=10, set_features=128, nhid=64, dropout=0.5): 52 | super(DeepSet, self).__init__() 53 | self.in_features = in_features 54 | self.out_features = set_features 55 | self.fc = nn.Linear(nhid, 1) 56 | self.dropout = dropout 57 | self.feature_extractor = nn.Sequential( 58 | nn.Linear(in_features, set_features), 59 | nn.ReLU(), 60 | nn.Linear(set_features, set_features), 61 | nn.ReLU() 62 | ) 63 | 64 | self.regressor = nn.Sequential( 65 | nn.Linear(set_features, nhid), 66 | nn.ReLU() 67 | ) 68 | self.reset_parameters() 69 | 70 | self.add_module('0', self.feature_extractor) 71 | self.add_module('1', self.regressor) 72 | 73 | def reset_parameters(self): 74 | for module in self.children(): 75 | reset_op = getattr(module, "reset_parameters", None) 76 | if callable(reset_op): 77 | reset_op() 78 | 79 | def forward(self, input): 80 | x = input 81 | x = self.feature_extractor(x) 82 | x = x.sum(dim=1) 83 | x = self.regressor(x) 84 | embedding = x.cpu().detach().numpy() 85 | x = F.dropout(x, self.dropout) 86 | x = torch.sigmoid(self.fc(x)) 87 | return x, embedding 88 | 89 | 90 | # Graph Deep Set Model 91 | class DSGCN(nn.Module): 92 | 93 | def __init__(self, in_features=10, set_features=128, nhid=64, dropout=0.5): 94 | super(DSGCN, self).__init__() 95 | self.in_features = in_features 96 | self.out_features = set_features 97 | self.gc = GraphConvolution(nhid, 1) 98 | self.dropout = dropout 99 | self.feature_extractor = nn.Sequential( 100 | nn.Linear(in_features, set_features), 101 | nn.ReLU(), 102 | nn.Linear(set_features, set_features), 103 | nn.ReLU() 104 | ) 105 | 106 | self.regressor = nn.Sequential( 107 | nn.Linear(set_features, nhid), 108 | nn.ReLU() 109 | ) 110 | self.reset_parameters() 111 | 112 | self.add_module('0', self.feature_extractor) 113 | self.add_module('1', self.regressor) 114 | 115 | def reset_parameters(self): 116 | for module in self.children(): 117 | reset_op = getattr(module, "reset_parameters", None) 118 | if callable(reset_op): 119 | reset_op() 120 | 121 | def forward(self, input, adj): 122 | x = input 123 | x = self.feature_extractor(x) 124 | x = x.sum(dim=1) 125 | x = self.regressor(x) 126 | embedding = x.cpu().detach().numpy() 127 | x = F.dropout(x, self.dropout) 128 | x = torch.sigmoid(self.gc(x, adj)) 129 | return x, embedding 130 | 131 | -------------------------------------------------------------------------------- /mil_election_classification/models_all_set_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import scipy.sparse as sp 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from modules import * 9 | 10 | 11 | # GCN model 12 | class GraphConvolution(Module): 13 | """ 14 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 15 | """ 16 | 17 | def __init__(self, 
in_features, out_features, bias=True): 18 | super(GraphConvolution, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 22 | if bias: 23 | self.bias = Parameter(torch.FloatTensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1. / math.sqrt(self.in_features) 30 | self.weight.data.uniform_(-stdv, stdv) 31 | if self.bias is not None: 32 | self.bias.data.uniform_(-stdv, stdv) 33 | 34 | def forward(self, input, adj): 35 | support = torch.mm(input, self.weight) 36 | output = torch.mm(adj, support) 37 | if self.bias is not None: 38 | return output + self.bias 39 | else: 40 | return output 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ + ' (' \ 44 | + str(self.in_features) + ' -> ' \ 45 | + str(self.out_features) + ')' 46 | 47 | 48 | # Set Transformer model 49 | class SetTransformer(nn.Module): 50 | def __init__(self, in_features=200, num_heads=4, dropout=0.5, ln=True): 51 | super(SetTransformer, self).__init__() 52 | self.enc = nn.Sequential( 53 | SAB(dim_in=in_features, dim_out=32, num_heads=num_heads, ln=ln), 54 | # SAB(dim_in=32, dim_out=32, num_heads=num_heads, ln=ln) 55 | ) 56 | self.dec = nn.Sequential( 57 | PMA(dim=32, num_heads=num_heads, num_seeds=1, ln=ln), 58 | # PMA(dim=32, num_heads=num_heads, num_seeds=1, ln=ln) 59 | ) 60 | self.last_layer = nn.Linear(in_features=32, out_features=1) 61 | self.dropout = dropout 62 | 63 | self.reset_parameters() 64 | 65 | def reset_parameters(self): 66 | for module in self.children(): 67 | reset_op = getattr(module, "reset_parameters", None) 68 | if callable(reset_op): 69 | reset_op() 70 | 71 | def forward(self, x): 72 | x = self.enc(x) 73 | x = self.dec(x).squeeze() 74 | embedding = x.cpu().detach().numpy() 75 | x = F.dropout(x, self.dropout) 76 | x = torch.sigmoid(self.last_layer(x)) 77 | return x, embedding 78 | 79 | 80 | # Set Transformer GCN model 81 | class STGCN(nn.Module): 82 | def __init__(self, in_features=200, num_heads=4, dropout=0.5, ln=True): 83 | super(STGCN, self).__init__() 84 | self.enc = nn.Sequential( 85 | SAB(dim_in=in_features, dim_out=32, num_heads=num_heads, ln=ln), 86 | # SAB(dim_in=32, dim_out=32, num_heads=num_heads, ln=ln) 87 | ) 88 | self.dec = nn.Sequential( 89 | PMA(dim=32, num_heads=num_heads, num_seeds=1, ln=ln), 90 | # PMA(dim=32, num_heads=num_heads, num_seeds=1, ln=ln) 91 | ) 92 | self.last_layer = GraphConvolution(in_features=32, out_features=1) 93 | self.dropout = dropout 94 | 95 | self.reset_parameters() 96 | 97 | def reset_parameters(self): 98 | for module in self.children(): 99 | reset_op = getattr(module, "reset_parameters", None) 100 | if callable(reset_op): 101 | reset_op() 102 | 103 | def forward(self, x, adj): 104 | x = self.enc(x) 105 | x = self.dec(x).squeeze() 106 | embedding = x.cpu().detach().numpy() 107 | x = F.dropout(x, self.dropout) 108 | x = torch.sigmoid(self.last_layer(x, adj)) 109 | return x, embedding 110 | 111 | 112 | -------------------------------------------------------------------------------- /mil_election_classification/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import scipy.sparse as sp 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | 9 | 10 | # Set transformer 
layers 11 | class MAB(nn.Module): 12 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 13 | super(MAB, self).__init__() 14 | self.dim_V = dim_V 15 | self.num_heads = num_heads 16 | self.fc_q = nn.Linear(dim_Q, dim_V) 17 | self.fc_k = nn.Linear(dim_K, dim_V) 18 | self.fc_v = nn.Linear(dim_K, dim_V) 19 | if ln: 20 | self.ln0 = nn.LayerNorm(dim_V) 21 | self.ln1 = nn.LayerNorm(dim_V) 22 | self.fc_o = nn.Linear(dim_V, dim_V) 23 | 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | for module in self.children(): 28 | reset_op = getattr(module, "reset_parameters", None) 29 | if callable(reset_op): 30 | reset_op() 31 | 32 | def forward(self, Q, K): 33 | Q = self.fc_q(Q) 34 | K, V = self.fc_k(K), self.fc_v(K) 35 | 36 | dim_split = self.dim_V // self.num_heads 37 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 38 | K_ = torch.cat(K.split(dim_split, 2), 0) 39 | V_ = torch.cat(V.split(dim_split, 2), 0) 40 | 41 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 42 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 43 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 44 | O = O + F.relu(self.fc_o(O)) 45 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 46 | return O 47 | 48 | 49 | class SAB(nn.Module): 50 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 51 | super(SAB, self).__init__() 52 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 53 | 54 | def forward(self, X): 55 | return self.mab(X, X) 56 | 57 | 58 | class ISAB(nn.Module): 59 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 60 | super(ISAB, self).__init__() 61 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 62 | nn.init.xavier_uniform_(self.I) 63 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 64 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 65 | 66 | def forward(self, X): 67 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 68 | return self.mab1(X, H) 69 | 70 | 71 | class PMA(nn.Module): 72 | def __init__(self, dim, num_heads, num_seeds, ln=False): 73 | super(PMA, self).__init__() 74 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 75 | nn.init.xavier_uniform_(self.S) 76 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 77 | 78 | def forward(self, X): 79 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) 80 | 81 | 82 | -------------------------------------------------------------------------------- /mil_election_classification/run_trials.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | from deepset_main import run_deepset 3 | from set_transformer_main import run_set_transformer 4 | 5 | 6 | # def run_bgcn(blank, r): 7 | # 8 | # run_set_transformer(num_neib=5, r=r, alg_name='BGCN') 9 | 10 | 11 | if __name__ == '__main__': 12 | run_deepset(num_neib=0, r=0, alg_name='vanilla') 13 | run_deepset(num_neib=5, r=0, alg_name='GCN') 14 | run_deepset(num_neib=5, r=1, alg_name='BGCN') 15 | 16 | 17 | -------------------------------------------------------------------------------- /mil_election_classification/set_transformer_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np # linear algebra 2 | import pandas as pd # data processing, CSV file I/O (e.g. 
pd.read_csv) 3 | import matplotlib.pyplot as plt 4 | from numpy import matlib 5 | import seaborn as sns 6 | from scipy import stats 7 | from sklearn.preprocessing import StandardScaler 8 | from torch import optim 9 | import torch.nn as nn 10 | import matplotlib as mpl 11 | import networkx as nx 12 | import scipy.sparse as sp 13 | import pickle as pk 14 | import time 15 | import torch 16 | import sys 17 | import glob 18 | import os 19 | from models_all_set_transformer import * 20 | from sklearn.neighbors import kneighbors_graph 21 | from scipy.stats import wilcoxon 22 | from utils import * 23 | from sklearn import preprocessing 24 | from subprocess import call 25 | mpl.rcParams['figure.dpi'] = 600 26 | color = sns.color_palette() 27 | pd.options.mode.chained_assignment = None # default='warn' 28 | 29 | 30 | def run_set_transformer(num_neib, r, alg_name='vanilla'): 31 | num_trials = 100 32 | lr_ = 1e-3 33 | weight_decay_ = 1e-4 34 | EPOCHS = 200 35 | MC_samples = 5 36 | 37 | log_dir = 'log_st' 38 | 39 | check = os.path.isdir(log_dir) 40 | if not check: 41 | os.makedirs(log_dir) 42 | print("created folder : ", log_dir) 43 | else: 44 | print(log_dir, " folder already exists.") 45 | 46 | file_name = 'data/centroids_cartesian_10.csv' 47 | location = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1, usecols=range(1, 3)) 48 | print(location.shape) 49 | 50 | file_name = 'data/results-2016-election.csv' 51 | votes = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1, usecols=range(1, 4)) 52 | votes = votes[:, [0, 2]] # keeping republican and democrat 53 | votes = votes[:, [0]] / votes.sum(axis=1, keepdims=True) # republican vote probability 54 | print(votes.shape) 55 | 56 | county_list = os.listdir('data/cleaned_features') 57 | # print(county_list) 58 | 59 | set_list = [] 60 | 61 | for file_name_ in county_list: 62 | with open('data/cleaned_features/' + file_name_, 'rb') as f: 63 | set_ = pk.load(f) 64 | set_list.append(set_) 65 | 66 | def sample_dataset(n=100, seed=0): 67 | np.random.seed(seed) 68 | 69 | X = [] 70 | for set_ in set_list: 71 | row_id = np.random.choice(set_.shape[0], n, replace=False) 72 | X.append(set_[row_id, :]) 73 | 74 | X = np.array(X) 75 | X = (X - np.mean(X, axis=(0, 1), keepdims=True))/np.std(X, axis=(0, 1), keepdims=True) 76 | return X 77 | 78 | def get_performance_test(trial_i): 79 | torch.manual_seed(trial_i) 80 | np.random.seed(trial_i) 81 | 82 | indices = np.arange(0, votes.shape[0]) 83 | np.random.shuffle(indices) 84 | num_train = int(0.025 * votes.shape[0]) 85 | idx_train = indices[:num_train] 86 | idx_test = indices[num_train:] 87 | np.savetxt(log_dir + '/idx_test' + '_trial_' + str(trial_i) + '.txt', idx_test, delimiter=',') 88 | 89 | X = sample_dataset(n=100, seed=trial_i) 90 | X = np.array(X, dtype='float') 91 | 92 | features = torch.FloatTensor(X) 93 | labels = torch.FloatTensor(votes) 94 | 95 | acc = [] 96 | 97 | if alg_name == 'vanilla': 98 | models = [ 99 | SetTransformer(in_features=94) 100 | ] 101 | elif alg_name == 'GCN': 102 | models = [ 103 | STGCN(in_features=94) 104 | ] 105 | 106 | A = kneighbors_graph(location, num_neib, mode='connectivity', include_self=True) 107 | G = nx.from_scipy_sparse_matrix(A) 108 | adj = nx.to_numpy_array(G) 109 | adj_norm = normalize(adj) 110 | adj_norm = torch.FloatTensor(adj_norm) 111 | 112 | def weight_reset(m): 113 | reset_parameters = getattr(m, "reset_parameters", None) 114 | if callable(reset_parameters): 115 | m.reset_parameters() 116 | 117 | for model in models: 118 | model.apply(weight_reset) 119 
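# Note: kneighbors_graph returns a directed kNN graph; the round trip through networkx symmetrizes it, and utils.normalize then rescales it as D^(-1/2) A D^(-1/2) (the GCN propagation rule of Kipf & Welling, 2017), with include_self=True supplying the self-loops. BCELoss, rather than BCEWithLogitsLoss, is the right criterion below because every model already ends in torch.sigmoid.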
| 120 | bce = nn.BCELoss() 121 | 122 | def train(epoch, adj_=None): 123 | t = time.time() 124 | model.train() 125 | optimizer.zero_grad() 126 | if adj_ is not None: 127 | output, _ = model(features, adj_) 128 | else: 129 | output, _ = model(features) 130 | loss_train = bce(output[idx_train], labels[idx_train]) 131 | loss_train.backward() 132 | optimizer.step() 133 | 134 | loss_test = bce(output[idx_test], labels[idx_test]) 135 | acc_train = accuracy(labels[idx_train], output[idx_train]) 136 | acc_test = accuracy(labels[idx_test], output[idx_test]) 137 | # if (epoch+1) % 10 == 0: 138 | # print('Epoch: {:04d}'.format(epoch+1), 139 | # 'loss_train: {:.4f}'.format(loss_train.item()), 140 | # 'loss_test: {:.4f}'.format(loss_test.item()), 141 | # 'acc_train: {:.4f}'.format(acc_train.item()), 142 | # 'acc_test: {:.4f}'.format(acc_test.item()), 143 | # 'time: {:.4f}s'.format(time.time() - t)) 144 | return acc_test.item() 145 | 146 | for network in (models): 147 | 148 | t_total = time.time() 149 | # Model and optimizer 150 | model = network 151 | 152 | no_decay = list() 153 | decay = list() 154 | for m in model.modules(): 155 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 156 | decay.append(m.weight) 157 | no_decay.append(m.bias) 158 | 159 | optimizer = optim.Adam([{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': weight_decay_}], lr=lr_) 160 | 161 | # Train model 162 | t_total = time.time() 163 | for epoch in range(EPOCHS): 164 | if isinstance(model, SetTransformer): 165 | value = train(epoch) 166 | else: 167 | value = train(epoch, adj_norm) 168 | 169 | model.eval() 170 | with torch.no_grad(): 171 | if isinstance(model, (SetTransformer)): 172 | output, decoding = model(features) 173 | np.savetxt(log_dir + '/decoding_set_transformer' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', decoding, delimiter=',') 174 | np.savetxt(log_dir + '/output_set_transformer' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', output, delimiter=',') 175 | 176 | if isinstance(model, (STGCN)): 177 | output, decoding = model(features, adj_norm) 178 | np.savetxt(log_dir + '/decoding_set_transformer_gcn' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', decoding, delimiter=',') 179 | np.savetxt(log_dir + '/output_set_transformer_gcn' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', output, delimiter=',') 180 | 181 | acc.append(accuracy(labels[idx_test], output[idx_test])) 182 | model.apply(weight_reset) 183 | return acc 184 | 185 | if alg_name == 'vanilla' or alg_name == 'GCN': 186 | tests = [] 187 | if alg_name == 'vanilla': 188 | print('SetTransformer') 189 | elif alg_name == 'GCN': 190 | print('STGCN') 191 | for i in range(num_trials): 192 | t = time.time() 193 | acc_all = get_performance_test(i) 194 | tests.append(acc_all) 195 | print('run : ' + str(i+1) + ', accuracy : ' + str(np.array(acc_all)) + ', run time: {:.2f}s'.format(time.time() - t)) 196 | 197 | tests = np.array(tests) 198 | 199 | if alg_name == 'vanilla': 200 | df_test = pd.DataFrame(tests, columns=['SetTransformer']) 201 | elif alg_name == 'GCN': 202 | df_test = pd.DataFrame(tests, columns=['STGCN']) 203 | 204 | print(df_test.mean(axis=0)) 205 | 206 | def get_performance_test_bayesian(trial_i): 207 | torch.manual_seed(trial_i) 208 | np.random.seed(trial_i) 209 | 210 | indices = np.arange(0, votes.shape[0]) 211 | np.random.shuffle(indices) 212 | num_train = int(0.025 * 
votes.shape[0]) 213 | idx_train = indices[:num_train] 214 | idx_test = indices[num_train:] 215 | 216 | X = sample_dataset(n=100, seed=trial_i) 217 | X = np.array(X, dtype='float') 218 | 219 | features = torch.FloatTensor(X) 220 | labels = torch.FloatTensor(votes) 221 | 222 | acc = [] 223 | 224 | if alg_name == 'BGCN': 225 | models = [ 226 | STGCN(in_features=94) 227 | ] 228 | 229 | def weight_reset(m): 230 | reset_parameters = getattr(m, "reset_parameters", None) 231 | if callable(reset_parameters): 232 | m.reset_parameters() 233 | 234 | for model in models: 235 | model.apply(weight_reset) 236 | 237 | bce = nn.BCELoss() 238 | 239 | def train_(epoch, adj_=None): 240 | t = time.time() 241 | model.train() 242 | optimizer.zero_grad() 243 | if adj_ is not None: 244 | output, _ = model(features, adj_) 245 | else: 246 | output, _ = model(features) 247 | loss_train = bce(output[idx_train], labels[idx_train]) 248 | loss_train.backward() 249 | optimizer.step() 250 | 251 | loss_test = bce(output[idx_test], labels[idx_test]) 252 | acc_train = accuracy(labels[idx_train], output[idx_train]) 253 | acc_test = accuracy(labels[idx_test], output[idx_test]) 254 | return acc_test.item(), output 255 | 256 | for network in (models): 257 | 258 | t_total = time.time() 259 | # Model and optimizer 260 | model = network 261 | 262 | if isinstance(model, (STGCN)): 263 | embedding = np.loadtxt(log_dir + '/decoding_set_transformer_gcn' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_0.txt', delimiter=",") 264 | adj_np = MAP_inference(embedding, num_neib, r) 265 | adj_np_norm = normalize(adj_np) 266 | adj_norm = torch.FloatTensor(adj_np_norm) 267 | 268 | no_decay = list() 269 | decay = list() 270 | for m in model.modules(): 271 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 272 | decay.append(m.weight) 273 | no_decay.append(m.bias) 274 | 275 | optimizer = optim.Adam([{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': weight_decay_}], lr=lr_) 276 | 277 | # Train model 278 | t_total = time.time() 279 | output_ = 0.0 280 | for epoch in range(EPOCHS): 281 | if isinstance(model, SetTransformer): 282 | value, output = train_(epoch) 283 | else: 284 | value, output = train_(epoch, adj_norm) 285 | 286 | if epoch >= EPOCHS - MC_samples: 287 | output_ += output 288 | 289 | output = output_ / np.float32(MC_samples) 290 | np.savetxt(log_dir + '/output_set_transformer_bgcn' + '_trial_' + str(trial_i) + '_num_neib_' + str(num_neib) + '_r_' + str(r) + '.txt', output.detach().numpy(), delimiter=',') 291 | 292 | acc.append(accuracy(labels[idx_test], output[idx_test])) 293 | model.apply(weight_reset) 294 | adj_norm = None 295 | return acc 296 | 297 | if alg_name == 'BGCN': 298 | tests_bayesian = [] 299 | print('B-STGCN') 300 | for i in range(num_trials): 301 | t = time.time() 302 | acc_bayes = get_performance_test_bayesian(i) 303 | tests_bayesian.append(acc_bayes) 304 | print('run : ' + str(i+1) + ', accuracy : ' + str(np.array(acc_bayes)) + ', run time: {:.2f}s'.format(time.time() - t)) 305 | 306 | tests_bayesian = np.array(tests_bayesian) 307 | 308 | df_test_bayesian = pd.DataFrame(tests_bayesian, columns=['B-STGCN']) 309 | 310 | print(df_test_bayesian.mean(axis=0)) 311 | 312 | if alg_name == 'BGCN': 313 | df_test_bayesian.to_csv('accuracy_set_transformer_num_neib_' + str(num_neib) + '_r_' + str(r) + '_' + alg_name + '.csv', index=False) 314 | 315 | else: 316 | df_test.to_csv('accuracy_set_transformer_num_neib_' + str(num_neib) + '_r_' + str(r) + '_' + alg_name + '.csv', 
index=False) 317 | -------------------------------------------------------------------------------- /mil_election_classification/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import matlib 3 | from scipy.stats import t 4 | import scipy.sparse as sp 5 | from math import sqrt 6 | from statistics import stdev 7 | 8 | 9 | def normalize(mx): 10 | """Row-normalize sparse matrix""" 11 | rowsum = np.array(mx.sum(1)) 12 | r_inv = np.power(rowsum, -1).flatten() 13 | r_inv[np.isinf(r_inv)] = 0. 14 | r_inv_sqrt = np.sqrt(r_inv) 15 | r_mat_inv_sqrt = np.diag(r_inv_sqrt) 16 | mx = r_mat_inv_sqrt.dot(mx) 17 | mx = mx.dot(r_mat_inv_sqrt) 18 | return mx 19 | 20 | 21 | def accuracy(labels, output): 22 | labels = (labels > 0.5).type_as(labels) 23 | preds = (output > 0.5).type_as(labels) 24 | correct = preds.eq(labels).double() 25 | correct = correct.sum() 26 | return correct / len(labels) 27 | 28 | 29 | def compute_distance(embed): 30 | N = embed.shape[0] 31 | p = np.dot(embed, np.transpose(embed)) 32 | q = np.matlib.repmat(np.diag(p), N, 1) 33 | dist = q + np.transpose(q) - 2 * p 34 | dist[dist < 1e-8] = 1e-8 35 | return dist 36 | 37 | 38 | def estimate_graph(gamma, epsilon, dist, max_iter, k, r): 39 | np.random.seed(0) 40 | 41 | N = dist.shape[0] 42 | dist += 1e10 * np.eye(N) 43 | 44 | deg_exp = np.minimum(int(N-1), int(k * r)) 45 | 46 | dist_sort_col_idx = np.argsort(dist, axis=0) 47 | dist_sort_col_idx = np.transpose(dist_sort_col_idx[0:deg_exp, :]) 48 | 49 | dist_sort_row_idx = np.matlib.repmat(np.arange(N).reshape(N, 1), 1, deg_exp) 50 | 51 | dist_sort_col_idx = np.reshape(dist_sort_col_idx, int(N * deg_exp)).astype(int) 52 | dist_sort_row_idx = np.reshape(dist_sort_row_idx, int(N * deg_exp)).astype(int) 53 | 54 | dist_idx = np.zeros((int(N * deg_exp), 2)).astype(int) 55 | dist_idx[:, 0] = dist_sort_col_idx 56 | dist_idx[:, 1] = dist_sort_row_idx 57 | dist_idx = np.sort(dist_idx, axis=1) 58 | dist_idx = np.unique(dist_idx, axis=0) 59 | dist_sort_col_idx = dist_idx[:, 0] 60 | dist_sort_row_idx = dist_idx[:, 1] 61 | 62 | num_edges = len(dist_sort_col_idx) 63 | 64 | w_init = np.random.uniform(0, 1, size=(num_edges, 1)) 65 | d_init = k * np.random.uniform(0, 1, size=(N, 1)) 66 | 67 | w_current = w_init 68 | d_current = d_init 69 | 70 | dist_sorted = np.sort(dist, axis=0) 71 | 72 | B_k = np.sum(dist_sorted[0:k, :], axis=0) 73 | dist_sorted_k = dist_sorted[k-1, :] 74 | dist_sorted_k_plus_1 = dist_sorted[k, :] 75 | 76 | theta_lb = 1 / np.sqrt(k * dist_sorted_k_plus_1 ** 2 - B_k * dist_sorted_k_plus_1) 77 | theta_lb = theta_lb[~np.isnan(theta_lb)] 78 | theta_lb = theta_lb[~np.isinf(theta_lb)] 79 | theta_lb = np.mean(theta_lb) 80 | 81 | theta_ub = 1 / np.sqrt(k * dist_sorted_k ** 2 - B_k * dist_sorted_k) 82 | theta_ub = theta_ub[~np.isnan(theta_ub)] 83 | theta_ub = theta_ub[~np.isinf(theta_ub)] 84 | if len(theta_ub) > 0: 85 | theta_ub = np.mean(theta_ub) 86 | else: 87 | theta_ub = theta_lb 88 | 89 | theta = (theta_lb + theta_ub) / 2 90 | 91 | dist = theta * dist 92 | 93 | z = dist[dist_sort_row_idx, dist_sort_col_idx] 94 | z.shape = (num_edges, 1) 95 | 96 | for iter in range(max_iter): 97 | 98 | # print('Graph inference epoch : ' + str(iter)) 99 | 100 | St_times_d = d_current[dist_sort_row_idx] + d_current[dist_sort_col_idx] 101 | y_current = w_current - gamma * (2 * w_current + St_times_d) 102 | 103 | adj_current = np.zeros((N, N)) 104 | adj_current[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(w_current) 105 | adj_current 
= adj_current + np.transpose(adj_current) 106 | S_times_w = np.sum(adj_current, axis=1) 107 | S_times_w.shape = (N, 1) 108 | y_bar_current = d_current + gamma * S_times_w 109 | 110 | p_current = np.maximum(0, np.abs(y_current) - 2 * gamma * z) 111 | p_bar_current = (y_bar_current - np.sqrt(y_bar_current * y_bar_current + 4 * gamma)) / 2 112 | 113 | St_times_p_bar = p_bar_current[dist_sort_row_idx] + p_bar_current[dist_sort_col_idx] 114 | q_current = p_current - gamma * (2 * p_current + St_times_p_bar) 115 | 116 | p_matrix_current = np.zeros((N, N)) 117 | p_matrix_current[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(p_current) 118 | p_matrix_current = p_matrix_current + np.transpose(p_matrix_current) 119 | S_times_p = np.sum(p_matrix_current, axis=1) 120 | S_times_p.shape = (N, 1) 121 | q_bar_current = p_bar_current + gamma * S_times_p 122 | 123 | w_updated = np.abs(w_current - y_current + q_current) 124 | d_updated = np.abs(d_current - y_bar_current + q_bar_current) 125 | 126 | if (np.linalg.norm(w_updated - w_current) / np.linalg.norm(w_current) < epsilon) and \ 127 | (np.linalg.norm(d_updated - d_current) / np.linalg.norm(d_current) < epsilon): 128 | break 129 | else: 130 | w_current = w_updated 131 | d_current = d_updated 132 | 133 | upper_tri_index = np.triu_indices(N, k=1) 134 | 135 | z = dist[upper_tri_index[0], upper_tri_index[1]] 136 | z.shape = (int(N * (N - 1) / 2), 1) 137 | z = z * np.max(w_current) 138 | 139 | w_current = w_current / np.max(w_current) 140 | 141 | inferred_graph = np.zeros((N, N)) 142 | inferred_graph[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(w_current) 143 | inferred_graph = inferred_graph + np.transpose(inferred_graph) + np.eye(N) 144 | 145 | return inferred_graph 146 | 147 | 148 | def MAP_inference(x, num_neib, r): 149 | N = x.shape[0] 150 | k = int(num_neib) 151 | 152 | dist = compute_distance(x) 153 | 154 | inferred_graph = estimate_graph(0.01, 0.001, dist, 1000, k, r) 155 | 156 | return inferred_graph 157 | -------------------------------------------------------------------------------- /mil_election_classification/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib.patches as mpatches 4 | import pandas as pd 5 | import seaborn as sns 6 | 7 | plt.rcParams["figure.figsize"] = (12,6) 8 | 9 | def super_impose(number: int, left: int = 150, bottom: int = 3500, width: int = 600, height: int = 950): 10 | file_name = 'log_ds/idx_test_trial_{}.txt'.format(number) 11 | test_indices = np.loadtxt(file_name, dtype=np.int) 12 | # print(sorted(test_indices)) 13 | # print(len(test_indices)) 14 | 15 | file_name = 'data/centroids_cartesian_10.csv' 16 | location = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1, usecols=range(1, 3)) 17 | print(location.shape) 18 | 19 | file_name = 'data/results-2016-election.csv' 20 | votes = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1, usecols=range(1, 4)) 21 | print(votes.shape) 22 | votes = votes/votes.sum(axis=1, keepdims=True) 23 | 24 | file_name = 'log_ds/output_deepset_bgcn_trial_{}_num_neib_5_r_1.txt'.format(number) 25 | preds_bgcn = np.loadtxt(open(file_name, "rb"), delimiter=",") 26 | 27 | file_name = 'log_ds/output_deepset_gcn_trial_{}_num_neib_5_r_0.txt'.format(number) 28 | preds_gcn = np.loadtxt(open(file_name, "rb"), delimiter=",") 29 | 30 | file_name = 'log_ds/output_deepset_trial_{}_num_neib_0_r_0.txt'.format(number) 31 | preds_deepset = 
np.loadtxt(open(file_name, "rb"), delimiter=",") 32 | 33 | d = { 34 | 'x': location[:, 0], 35 | 'y': -location[:, 1], 36 | 'preds_deepset': preds_deepset, 37 | 'preds_gcn': preds_gcn, 38 | 'preds_bgcn': preds_bgcn, 39 | 'vote': votes[:, 0], 40 | 'is_test': np.zeros(votes.shape[0]) 41 | } 42 | 43 | df = pd.DataFrame(data=d).convert_dtypes() 44 | 45 | for i in range(len(test_indices)): 46 | df.iloc[test_indices[i], -1] = 1 47 | 48 | df = df.loc[df['x'] > -3000] 49 | df = df.loc[df['y'] > 2000] 50 | 51 | df1 = pd.DataFrame(data=d).convert_dtypes() 52 | 53 | for i in range(len(test_indices)): 54 | df1.iloc[test_indices[i], -1] = 1 55 | 56 | df1 = df1.loc[df1['x'] >= left] 57 | df1 = df1.loc[df1['y'] >= bottom] 58 | df1 = df1.loc[df1['x'] <= left + width] 59 | df1 = df1.loc[df1['y'] <= bottom + height] 60 | 61 | df1['x'] = df1['x'] + 2500 62 | 63 | df1['x'] = (df1['x'] - 950) 64 | df1['y'] = (df1['y'] - 600) 65 | 66 | df_train = df1.loc[df1['is_test'] != 0] 67 | df_test = df1.loc[df1['is_test'] == 0] 68 | 69 | print(df_train.shape) 70 | print(df_test.shape) 71 | 72 | fig, axs = plt.subplots(2, 2) 73 | mynorm = plt.Normalize(vmin=0.3, vmax=0.7) 74 | 75 | 76 | ff = 16 77 | 78 | axs[0, 0].scatter(x=df['x'], y=df['y'], c=df['preds_deepset'], cmap='coolwarm', norm=mynorm, s=20) 79 | axs[0, 0].set_title('Deep Sets', fontsize=ff) 80 | axs[0, 1].scatter(x=df['x'], y=df['y'], c=df['preds_gcn'], cmap='coolwarm', norm=mynorm, s=20) 81 | axs[0, 1].set_title('DS-GCN', fontsize=ff) 82 | axs[1, 0].scatter(x=df['x'], y=df['y'], c=df['preds_bgcn'], cmap='coolwarm', norm=mynorm, s=20) 83 | axs[1, 0].set_title('B-DS-GCN', fontsize=ff) 84 | axs[1, 1].scatter(x=df['x'], y=df['y'], c=df['vote'], cmap='coolwarm', norm=mynorm, s=20) 85 | axs[1, 1].set_title('True Election Results', fontsize=ff) 86 | 87 | 88 | ## Plot train only 89 | axs[0, 0].scatter(x=df_train['x'], y=df_train['y'], c=df_train['preds_deepset'], cmap='coolwarm', norm=mynorm, s=20) 90 | axs[0, 1].scatter(x=df_train['x'], y=df_train['y'], c=df_train['preds_gcn'], cmap='coolwarm', norm=mynorm, s=20) 91 | axs[1, 0].scatter(x=df_train['x'], y=df_train['y'], c=df_train['preds_bgcn'], cmap='coolwarm', norm=mynorm, s=20) 92 | axs[1, 1].scatter(x=df_train['x'], y=df_train['y'], c=df_train['vote'], cmap='coolwarm', norm=mynorm, s=20) 93 | 94 | ## Plot test only 95 | axs[0, 0].scatter(x=df_test['x'], y=df_test['y'], c=df_test['preds_deepset'], cmap='coolwarm', norm=mynorm, s=20) 96 | axs[0, 1].scatter(x=df_test['x'], y=df_test['y'], c=df_test['preds_gcn'], cmap='coolwarm', norm=mynorm, s=20) 97 | axs[1, 0].scatter(x=df_test['x'], y=df_test['y'], c=df_test['preds_bgcn'], cmap='coolwarm', norm=mynorm, s=20) 98 | axs[1, 1].scatter(x=df_test['x'], y=df_test['y'], c=df_test['vote'], cmap='coolwarm', norm=mynorm, s=20) 99 | 100 | 101 | # MID-WEST 102 | # specify the location of (left,bottom),width,height 103 | rect1 = mpatches.Rectangle(((275 + 1500)*1.5-1100, 2800), 835, 1150, 104 | fill=False, 105 | color="purple", 106 | linewidth=2) 107 | rect2 = mpatches.Rectangle(((275 + 1500)*1.5-1100, 2800), 835, 1150, 108 | fill=False, 109 | color="purple", 110 | linewidth=2) 111 | rect3 = mpatches.Rectangle(((275 + 1500)*1.5-1100, 2800), 835, 1150, 112 | fill=False, 113 | color="purple", 114 | linewidth=2) 115 | rect4 = mpatches.Rectangle(((275 + 1500)*1.5-1100, 2800), 835, 1150, 116 | fill=False, 117 | color="purple", 118 | linewidth=2) 119 | 120 | axs[0, 0].add_patch(rect1) 121 | axs[0, 1].add_patch(rect2) 122 | axs[1, 0].add_patch(rect3) 123 | axs[1, 1].add_patch(rect4) 
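# A matplotlib Artist can be attached to only one Axes, so four identical Rectangle instances (rect1-rect4 here, rect11-rect44 below) are needed to draw the same inset box on each of the four subplots.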
124 | 125 | 126 | # specify the location of (left,bottom),width,height 127 | rect11 = mpatches.Rectangle((235, 3650), 385, 700, 128 | fill=False, 129 | color="purple", 130 | linewidth=2) 131 | rect22 = mpatches.Rectangle((235, 3650), 385, 700, 132 | fill=False, 133 | color="purple", 134 | linewidth=2) 135 | rect33 = mpatches.Rectangle((235, 3650), 385, 700, 136 | fill=False, 137 | color="purple", 138 | linewidth=2) 139 | rect44 = mpatches.Rectangle((235, 3650), 385, 700, 140 | fill=False, 141 | color="purple", 142 | linewidth=2) 143 | # facecolor="red") 144 | axs[0, 0].add_patch(rect11) 145 | axs[0, 1].add_patch(rect22) 146 | axs[1, 0].add_patch(rect33) 147 | axs[1, 1].add_patch(rect44) 148 | 149 | 150 | for ax in axs.flat: 151 | ax.set(xlabel='', ylabel='') 152 | 153 | # Hide x labels and tick labels for top plots and y ticks for right plots. 154 | for ax in axs.flat: 155 | ax.label_outer() 156 | 157 | # remove the x and y ticks 158 | for ax in axs.flat: 159 | ax.set_xticks([]) 160 | ax.set_yticks([]) 161 | 162 | fig.tight_layout() 163 | 164 | plt.savefig('election.pdf') 165 | 166 | # plt.subplots_adjust(left=0.125, 167 | # bottom=0.1, 168 | # right=0.9, 169 | # top=0.9, 170 | # wspace=0.2, 171 | # hspace=0.35) 172 | 173 | 174 | plt.show() 175 | 176 | 177 | 178 | if __name__ == "__main__": 179 | super_impose(8) 180 | -------------------------------------------------------------------------------- /mil_rental_data/deepset_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np # linear algebra 2 | import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) 3 | import matplotlib.pyplot as plt 4 | from numpy import matlib 5 | import seaborn as sns 6 | from scipy import stats 7 | from sklearn.preprocessing import StandardScaler 8 | from torch import optim 9 | import torch.nn as nn 10 | import matplotlib as mpl 11 | import scipy.sparse as sp 12 | import time 13 | import torch 14 | import sys 15 | import glob 16 | import os 17 | from models_all_deepset import * 18 | from scipy.stats import wilcoxon 19 | from utils import * 20 | from sklearn import preprocessing 21 | from subprocess import call 22 | mpl.rcParams['figure.dpi'] = 600 23 | color = sns.color_palette() 24 | pd.options.mode.chained_assignment = None # default='warn' 25 | torch.manual_seed(0) 26 | np.random.seed(0) 27 | 28 | num_trials = 100 29 | lr_ = 5e-4 30 | weight_decay_ = 1e-3 31 | EPOCHS = 500 32 | MC_samples = 20 33 | 34 | adj_orig = np.loadtxt('data/adj_nbhd.txt', dtype='float', delimiter=',') 35 | print(adj_orig) 36 | num_neib = int(np.sum(adj_orig)/adj_orig.shape[0]) - 1 37 | print(num_neib) 38 | 39 | files = glob.glob('*.pkl') 40 | for file in files: 41 | os.remove(file) 42 | log_dir = 'log_ds' 43 | 44 | check = os.path.isdir(log_dir) 45 | if not check: 46 | os.makedirs(log_dir) 47 | print("created folder : ", log_dir) 48 | else: 49 | print(log_dir, " folder already exists.") 50 | 51 | for f in os.listdir(log_dir): 52 | os.remove(os.path.join(log_dir, f)) 53 | 54 | train_df = pd.read_csv('data/neighbourhood_data.csv') 55 | districts = pd.read_csv('data/districts_data_cleaned.csv') 56 | names = list(districts['name'].unique()) 57 | train_df = train_df[train_df['neighbourhood'].isin(names)] 58 | counts = train_df['neighbourhood'].value_counts() 59 | count_list = counts[counts > 50].index.tolist() 60 | train_df = train_df[train_df['neighbourhood'].isin(count_list)] 61 | 62 | 63 | train_df = train_df.loc[:, ~train_df.columns.str.contains('^Unnamed')] 64 | 
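# Cast each feature column to an explicit dtype before modelling. 'price' is standardized to zero mean / unit variance for training, while non_standardized keeps the raw prices (clipped at the 95th percentile further down) so that test errors can be converted back to dollars.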
train_df = train_df.astype({'bathrooms': 'int64'}) 65 | train_df = train_df.astype({'bedrooms': 'int64'}) 66 | train_df = train_df.astype({'interest_level': 'category'}) 67 | train_df = train_df.astype({'num_photos': 'int64'}) 68 | train_df = train_df.astype({'num_features': 'int64'}) 69 | train_df = train_df.astype({'num_description_words': 'int64'}) 70 | train_df = train_df.astype({'created_month': 'category'}) 71 | train_df = train_df.astype({'created_day': 'category'}) 72 | train_df = train_df.astype({'neighbourhood': 'str'}) 73 | train_df = train_df.astype({'price': 'float64'}) 74 | non_standardized = train_df.copy() 75 | train_df['price'] = (train_df['price'] - train_df['price'].mean()) / train_df['price'].std() 76 | 77 | train_df['interest_level'] = pd.Categorical(train_df['interest_level'], categories=train_df['interest_level'].unique()).codes 78 | train_df['neighbourhood'] = pd.Categorical(train_df['neighbourhood'], categories=train_df['neighbourhood'].unique()).codes 79 | 80 | 81 | def train_Set(n=25, seed=None): 82 | train_samples = train_df.groupby('neighbourhood').apply(pd.DataFrame.sample, n, replace = False, random_state=seed) 83 | train_samples = train_samples.astype({'price': 'float64'}) 84 | features = [] 85 | for i, hood in enumerate(train_samples['neighbourhood'].unique()): 86 | get_boro = train_samples[train_samples['neighbourhood'] == hood] 87 | sample = get_boro.to_numpy() 88 | features.append(sample) 89 | return np.array(features) 90 | 91 | 92 | non_standardized['price'] = non_standardized['price'].clip(upper = np.percentile(non_standardized['price'].values, 95)) 93 | df = train_df[['price', 'neighbourhood']] 94 | mean = df.groupby('neighbourhood').mean().values 95 | 96 | 97 | def get_performance_test(trial_i): 98 | torch.manual_seed(trial_i) 99 | np.random.seed(trial_i) 100 | 101 | indices = np.arange(0, mean.shape[0]) 102 | np.random.shuffle(indices) 103 | idx_train = indices[:55] 104 | idx_test = indices[55:] 105 | 106 | X = train_Set(n=25, seed=trial_i) 107 | X = np.array(X, dtype='float') 108 | 109 | features = torch.FloatTensor(X) 110 | labels = torch.FloatTensor(mean) 111 | 112 | mserr = [] 113 | maerr = [] 114 | mperr = [] 115 | 116 | adj_norm = normalize(adj_orig) 117 | adj_norm = torch.FloatTensor(adj_norm) 118 | 119 | models = [ 120 | DeepSet(in_features=10), 121 | DSGCN(in_features=10) 122 | ] 123 | 124 | def weight_reset(m): 125 | reset_parameters = getattr(m, "reset_parameters", None) 126 | if callable(reset_parameters): 127 | m.reset_parameters() 128 | 129 | for model in models: 130 | model.apply(weight_reset) 131 | 132 | mse = nn.MSELoss() 133 | mae = nn.L1Loss() 134 | 135 | def train(epoch, adj_=None): 136 | t = time.time() 137 | model.train() 138 | optimizer.zero_grad() 139 | if adj_ is not None: 140 | output, _ = model(features, adj_) 141 | else: 142 | output, _ = model(features) 143 | loss_train = mse(output[idx_train], labels[idx_train]) 144 | loss_train.backward() 145 | optimizer.step() 146 | 147 | loss_test = mse(output[idx_test], labels[idx_test]) 148 | return loss_test.item() 149 | 150 | for network in (models): 151 | 152 | t_total = time.time() 153 | # Model and optimizer 154 | model = network 155 | 156 | no_decay = list() 157 | decay = list() 158 | for m in model.modules(): 159 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 160 | decay.append(m.weight) 161 | no_decay.append(m.bias) 162 | 163 | optimizer = optim.Adam([{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': weight_decay_}], lr=lr_) 
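# Weight decay is applied only to the weight matrices of Linear and GraphConvolution
# layers; their biases go into the zero-decay parameter group above. A minimal
# equivalent sketch using named_parameters (assuming every bias parameter's name
# ends in 'bias'):
#   decay = [p for n, p in model.named_parameters() if not n.endswith('bias')]
#   no_decay = [p for n, p in model.named_parameters() if n.endswith('bias')]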
164 | 165 | # Train model 166 | t_total = time.time() 167 | for epoch in range(EPOCHS): 168 | if isinstance(model, DeepSet): 169 | value = train(epoch) 170 | else: 171 | value = train(epoch, adj_norm) 172 | 173 | model.eval() 174 | with torch.no_grad(): 175 | if isinstance(model, (DeepSet)): 176 | output, decoding = model(features) 177 | np.savetxt(log_dir + '/decoding_deepset_' + '_trial_' + str(trial_i), decoding, delimiter=',') 178 | 179 | if isinstance(model, (DSGCN)): 180 | output, decoding = model(features, adj_norm) 181 | np.savetxt(log_dir + '/decoding_deepset_gcn' + '_trial_' + str(trial_i), decoding, delimiter=',') 182 | 183 | maerr.append(non_standardized['price'].std() * (mae(output[idx_test], labels[idx_test]).detach().cpu().numpy())) 184 | mserr.append(non_standardized['price'].std() * np.sqrt(mse(output[idx_test], labels[idx_test]).detach().numpy())) 185 | 186 | target = (non_standardized['price'].std() * labels[idx_test]).detach().cpu().numpy() + non_standardized['price'].mean() 187 | pred = (non_standardized['price'].std() * output[idx_test].T).detach().cpu().numpy() + non_standardized['price'].mean() 188 | 189 | mperr.append(100 * np.mean(np.abs((target - pred) / target))) 190 | model.apply(weight_reset) 191 | return mserr, maerr, mperr 192 | 193 | 194 | mse = [] 195 | mae = [] 196 | mpe = [] 197 | print('DeepSet', 'DSGCN') 198 | for i in range(num_trials): 199 | t = time.time() 200 | mserr, maerr, mperr = get_performance_test(i) 201 | mse.append(mserr) 202 | mae.append(maerr) 203 | mpe.append(mperr) 204 | print('run : ' + str(i+1) + ', RMSE : ' + str(np.around(np.array(mserr), 2)) + ', MAE : ' + str(np.around(np.array(maerr), 2)) + 205 | ', MAPE : ' + str(np.around(np.array(mperr), 2)) + ', run time: {:.2f}s'.format(time.time() - t)) 206 | 207 | 208 | mse = np.array(mse) 209 | mae = np.array(mae) 210 | mpe = np.array(mpe) 211 | 212 | columns_ = ['DeepSet', 'DSGCN'] 213 | df_mse = pd.DataFrame(mse, columns=columns_) 214 | df_mae = pd.DataFrame(mae, columns=columns_) 215 | df_mpe = pd.DataFrame(mpe, columns=columns_) 216 | 217 | 218 | def get_performance_test_bayesian(trial_i): 219 | torch.manual_seed(trial_i) 220 | np.random.seed(trial_i) 221 | 222 | indices = np.arange(0, mean.shape[0]) 223 | np.random.shuffle(indices) 224 | idx_train = indices[:55] 225 | idx_test = indices[55:] 226 | 227 | X = train_Set(n=25, seed=trial_i) 228 | X = np.array(X, dtype='float') 229 | 230 | features = torch.FloatTensor(X) 231 | labels = torch.FloatTensor(mean) 232 | 233 | mserr = [] 234 | maerr = [] 235 | mperr = [] 236 | 237 | models = [ 238 | DSGCN(in_features=10) 239 | ] 240 | 241 | def weight_reset(m): 242 | reset_parameters = getattr(m, "reset_parameters", None) 243 | if callable(reset_parameters): 244 | m.reset_parameters() 245 | 246 | for model in models: 247 | model.apply(weight_reset) 248 | 249 | mse = nn.MSELoss() 250 | mae = nn.L1Loss() 251 | 252 | def train_(epoch, adj_=None): 253 | t = time.time() 254 | model.train() 255 | optimizer.zero_grad() 256 | if adj_ is not None: 257 | output, _ = model(features, adj_) 258 | else: 259 | output, _ = model(features) 260 | loss_train = mse(output[idx_train], labels[idx_train]) 261 | loss_train.backward() 262 | optimizer.step() 263 | 264 | loss_test = mse(output[idx_test], labels[idx_test]) 265 | return loss_test.item(), output 266 | 267 | for network in (models): 268 | 269 | t_total = time.time() 270 | # Model and optimizer 271 | model = network 272 | 273 | if isinstance(model, (DSGCN)): 274 | embedding = np.loadtxt(log_dir + 
'/decoding_deepset_gcn' + '_trial_' + str(trial_i), delimiter=",") 275 | 276 | adj_np = MAP_inference(embedding, num_neib, 1) 277 | adj_np_norm = normalize(adj_np) 278 | adj_norm = torch.FloatTensor(adj_np_norm) 279 | 280 | no_decay = list() 281 | decay = list() 282 | for m in model.modules(): 283 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 284 | decay.append(m.weight) 285 | no_decay.append(m.bias) 286 | 287 | optimizer = optim.Adam([{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': weight_decay_}], lr=lr_) 288 | 289 | # Train model 290 | t_total = time.time() 291 | output_ = 0.0 292 | for epoch in range(EPOCHS): 293 | if isinstance(model, DeepSet): 294 | value, output = train_(epoch) 295 | else: 296 | value, output = train_(epoch, adj_norm) 297 | 298 | if epoch >= EPOCHS - MC_samples: 299 | output_ += output 300 | 301 | output = output_ / np.float32(MC_samples) 302 | 303 | maerr.append(non_standardized['price'].std() * (mae(output[idx_test], labels[idx_test]).detach().cpu().numpy())) 304 | mserr.append( 305 | non_standardized['price'].std() * np.sqrt(mse(output[idx_test], labels[idx_test]).detach().numpy())) 306 | 307 | target = (non_standardized['price'].std() * labels[idx_test]).detach().cpu().numpy() + non_standardized[ 308 | 'price'].mean() 309 | pred = (non_standardized['price'].std() * output[idx_test].T).detach().cpu().numpy() + non_standardized[ 310 | 'price'].mean() 311 | 312 | mperr.append(100 * np.mean(np.abs((target - pred) / target))) 313 | model.apply(weight_reset) 314 | adj_norm = None 315 | return mserr, maerr, mperr 316 | 317 | 318 | mse_b = [] 319 | mae_b = [] 320 | mpe_b = [] 321 | print('B-DSGCN') 322 | for i in range(num_trials): 323 | t = time.time() 324 | mserr_b, maerr_b, mperr_b = get_performance_test_bayesian(i) 325 | mse_b.append(mserr_b) 326 | mae_b.append(maerr_b) 327 | mpe_b.append(mperr_b) 328 | print('run : ' + str(i+1) + ', RMSE : ' + str(np.around(np.array(mserr_b), 2)) + ', MAE : ' + str(np.around(np.array(maerr_b), 2)) + 329 | ', MAPE : ' + str(np.around(np.array(mperr_b), 2)) + ', run time: {:.2f}s'.format(time.time() - t)) 330 | 331 | 332 | mse_b = np.array(mse_b) 333 | mae_b = np.array(mae_b) 334 | mpe_b = np.array(mpe_b) 335 | 336 | columns_b = ['B-DSGCN'] 337 | df_mse_b = pd.DataFrame(mse_b, columns = columns_b) 338 | df_mae_b = pd.DataFrame(mae_b, columns = columns_b) 339 | df_mpe_b = pd.DataFrame(mpe_b, columns = columns_b) 340 | 341 | mse_concat = pd.concat([df_mse, df_mse_b], axis=1) 342 | mae_concat = pd.concat([df_mae, df_mae_b], axis=1) 343 | mpe_concat = pd.concat([df_mpe, df_mpe_b], axis=1) 344 | 345 | mse_concat.to_csv('rmse_deepset.csv', index=False) 346 | mae_concat.to_csv('mae_deepset.csv', index=False) 347 | mpe_concat.to_csv('mape_deepset.csv', index=False) 348 | 349 | -------------------------------------------------------------------------------- /mil_rental_data/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import wilcoxon 3 | 4 | alg = 'deepset' # 'set_transformer' 5 | 6 | quant = ['rmse', 'mae', 'mape'] 7 | 8 | for quant_ in quant: 9 | file_name = quant_ + '_' + alg + '.csv' 10 | result = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1) 11 | print('Algorithm: ' + alg) 12 | print('---------------' + quant_ + ' mean and std. 
error---------------') 13 | print(np.round(result.mean(axis=0), 2)) 14 | print(np.round(result.std(axis=0), 2)) 15 | print('----------------------------------------------------') 16 | print(quant_ + ': statistical test') 17 | 18 | _, p = wilcoxon(result[:, 0], result[:, 1], zero_method='wilcox', correction=False) 19 | print('vanilla vs GCN') 20 | print(p) 21 | _, p = wilcoxon(result[:, 0], result[:, 2], zero_method='wilcox', correction=False) 22 | print('vanilla vs BGCN') 23 | print(p) 24 | _, p = wilcoxon(result[:, 1], result[:, 2], zero_method='wilcox', correction=False) 25 | print('GCN vs BGCN') 26 | print(p) 27 | -------------------------------------------------------------------------------- /mil_rental_data/mae_deepset.csv: -------------------------------------------------------------------------------- 1 | DeepSet,DSGCN,B-DSGCN 2 | 77.07391620524622,57.43637492812094,47.26009531379419 3 | 56.14650594072809,49.18978002217976,42.52855133450055 4 | 58.37818794964221,57.57949478401765,39.65257789378745 5 | 51.96165974742866,43.0999632753376,40.51071570121238 6 | 64.537223879126,65.9054435281043,40.96726392592674 7 | 68.01812307452371,68.74658873645144,49.81768109233651 8 | 60.96474752503608,80.97492516600111,36.31957007356553 9 | 56.9088841148798,64.69607045678677,58.17331357360949 10 | 65.16949777018542,52.67916763462382,70.81510279175393 11 | 80.3485057104561,63.21863317514539,86.0577892933578 12 | 73.36291917395016,63.4085267812151,49.650670200825694 13 | 41.82642291564483,64.11862145085645,46.1604080242011 14 | 86.96961505247714,70.27863997891954,51.86001481166996 15 | 64.04632586009741,71.66800350308603,62.89172622895718 16 | 86.67277768061373,62.34355454509401,43.42289352675198 17 | 50.598425114279415,49.3703312250033,32.052846667235045 18 | 73.02427763425726,61.64075991411252,50.29830040999624 19 | 43.29526889033625,38.96457368644734,32.696263534853486 20 | 35.852825459405196,54.512958952861496,30.460772027258866 21 | 54.58216784651277,46.144182286404615,34.38194740911782 22 | 38.3643217056334,53.00020471181211,34.84442466452667 23 | 76.16974031045896,62.39446446939141,59.74956631866078 24 | 53.22227199073509,62.32235408045884,41.622901541844335 25 | 75.23989338472546,63.937746144836595,54.085240479133674 26 | 79.01360181226094,59.64027666306518,38.781712605288156 27 | 60.97951737069674,39.38378573739568,40.943606964843305 28 | 40.872710676773124,65.21583423826814,44.18905804935588 29 | 53.774939963385236,88.78102267218833,43.825496575012956 30 | 75.3434983731476,47.371946927989235,51.27844506510077 31 | 75.7531082238432,73.41915879525365,64.94906129623438 32 | 73.39470700966497,65.2423078104624,51.288307062360076 33 | 55.73181847329022,61.43072560941465,43.175875446207904 34 | 45.29831409998304,48.780468552204574,38.74212314263348 35 | 73.91838060748285,66.50158248502214,55.81759264141983 36 | 52.020769997042365,56.854141541673584,40.31649300244904 37 | 126.83750807526432,66.50046612956815,51.376324230388 38 | 48.48466522387228,52.93447864139743,38.751138870435746 39 | 65.55953796121779,50.65558560022004,42.86459490412775 40 | 54.69970927237646,45.10563217752599,43.419493015438256 41 | 63.25842584535967,51.68589422782093,47.42159130647246 42 | 69.98105151075991,71.94374844471304,62.24945658376379 43 | 61.33584568481714,64.8948383170413,58.765167166042716 44 | 59.46118649687266,55.603499330025244,38.053745959392856 45 | 74.57550241025976,65.47386668299177,67.07688623646256 46 | 67.53800791739158,58.05847815221027,46.106154178036945 47 | 
46.65887874013407,55.9512106147198,40.75687979555665 48 | 108.98040970136418,64.14696761929609,66.6551302326479 49 | 78.87536922678487,49.72597017676963,41.11574434580059 50 | 64.95355244052617,55.38662284673849,63.36355360435496 51 | 65.65689238799669,50.42155737000448,47.622519854702425 52 | 63.150700116298665,64.74010219102904,48.33891653093879 53 | 82.05368979369874,43.3259866710537,36.46818939480924 54 | 44.06746790578583,52.807836603561014,33.84093429029708 55 | 55.0216157692501,60.45813307323379,49.986996717687774 56 | 59.95755311427717,56.616316674015465,38.61419240875457 57 | 44.54299417321776,48.66761376144951,45.08213726804072 58 | 54.25143339591092,43.89123807891317,41.595936670360416 59 | 79.87524816539838,65.69286269829252,72.18841034646485 60 | 81.57844132900651,42.80491361554911,35.54521808688328 61 | 67.628674500437,86.39639482158445,55.36367839823744 62 | 61.34191104463389,52.673143430768505,35.773767429732565 63 | 51.71016595607833,59.11017244627279,46.974364907023556 64 | 55.35111554100884,77.84890868157832,58.64173529319837 65 | 64.96653200186432,53.57029194514041,50.06157131981138 66 | 46.87560603306062,47.7279048384567,39.2460341622189 67 | 66.24058681107506,41.83756074770868,41.27155052681021 68 | 71.88453530519574,45.960513519502236,40.87913100675727 69 | 58.71748284022293,51.160363467243215,41.72764346370116 70 | 56.87090745146393,52.66902783462482,46.35750392352225 71 | 112.81172104142574,63.7440661903147,57.219498444834194 72 | 49.842647327443146,49.022326704083504,43.99257434496109 73 | 41.120948002674766,51.0681586801391,28.93760790021354 74 | 71.73295275472857,78.10510454152282,33.26043459872524 75 | 40.81049829656612,52.22063363477015,34.94554228952946 76 | 64.06218633873614,71.82414407628232,56.87480697881007 77 | 79.7710721380113,60.91671337354408,46.651393499647746 78 | 73.39492822295769,61.56736854588023,42.798843111237176 79 | 85.13812332358475,43.85937822026585,39.1491556012441 80 | 75.96292645974353,60.46020116029599,36.158786593469635 81 | 80.56926628760345,65.7738267634292,45.96073473279496 82 | 73.56524702486898,49.86638402820186,37.97180444017205 83 | 51.650392066586456,68.56140234347107,53.65042260205799 84 | 82.69200332209901,58.87298035152177,44.657551791877275 85 | 72.4128286596849,69.19202485157784,53.1175197798879 86 | 85.62067697143202,62.21374349822695,55.5536851832011 87 | 63.1274058421254,62.619176018836704,53.30472795947385 88 | 60.25053725925105,68.31661182933978,57.308199830721 89 | 46.308667230782234,51.99076730115489,41.10734852966747 90 | 68.47018015494626,62.61145413157212,53.6929315657271 91 | 73.9543663512642,74.64755106525018,62.59088643984404 92 | 67.93746253460263,64.99945677101381,46.443946876025215 93 | 54.2510115473062,54.55303971480583,47.40901301575831 94 | 86.58039283617832,70.37244470402452,46.54228390138343 95 | 78.5263512403147,49.11771078920863,34.634668162568495 96 | 71.3314351949505,62.61330100534159,60.16577655667182 97 | 63.80321759588983,51.99677092702949,49.90972125559488 98 | 66.20767233091593,66.17645038967089,56.18768762464086 99 | 51.749058339636164,74.23713867330656,52.38756735885326 100 | 40.71228216684462,75.5820846260923,37.3273303787942 101 | 82.88276120335888,57.590159322524975,50.78111128260249 102 | -------------------------------------------------------------------------------- /mil_rental_data/mae_set_transformer.csv: -------------------------------------------------------------------------------- 1 | SetTransformer,STGCN,B-STGCN 2 | 66.05051314767407,57.6526443609765,48.32488234808882 3 | 
48.36985038045378,45.521893860488824,34.01142800504448 4 | 62.22608514216283,57.12233436437693,50.68164246830476 5 | 53.60106631530483,59.52238541152928,44.765581046153876 6 | 57.02181092856738,53.109211420172834,56.94771476449548 7 | 50.20828717783863,50.1403798414678,61.5296079512619 8 | 59.59995925434259,67.39249072674117,35.06933942153156 9 | 58.43470022919021,61.435283632143786,65.13358919383175 10 | 69.10276786021163,62.2503362924395,71.24017699495943 11 | 68.10036297446493,53.31881358727561,53.46278228487696 12 | 62.6068909643478,52.449882628963856,43.67141894438368 13 | 54.54741678157451,55.49543406428239,57.01134188087688 14 | 52.91819116965879,54.48753485768387,52.922044396548316 15 | 59.15464660710051,65.90290729198074,63.7500698161893 16 | 38.60862863721783,51.81865307042591,42.927823322132724 17 | 26.81714216033978,44.28640218714443,35.27089559817343 18 | 59.136939254692294,53.41804575479508,42.1551304351459 19 | 58.912978801543225,33.09038331056329,38.065519136611385 20 | 49.27295107574849,55.7501894655766,38.58792718861508 21 | 52.7950731110204,37.1678355947408,57.244963695973254 22 | 37.081989403678676,32.921993694344344,41.81233471559547 23 | 74.63027585043706,56.438985925174364,60.50515375917541 24 | 60.209334997357566,55.637735945445534,50.52437010616895 25 | 45.776345736567436,61.41432495878206,44.006762862666456 26 | 51.68477787236696,45.22540116980245,43.91620431401981 27 | 51.25590188722373,41.10154296685728,35.63921315830991 28 | 63.09075131397069,44.37810795821613,43.90298296140821 29 | 43.8737725177784,46.19033355266088,37.001292852291336 30 | 58.625025972855006,51.584357326461,48.5069100210289 31 | 60.807614064270126,58.87790363340866,53.48524829532631 32 | 78.19176356282333,66.12906958906669,71.29295951550222 33 | 52.73338546932172,55.434317461548645,55.25581891230177 34 | 47.70480091060508,45.83942239196461,42.13046258075968 35 | 52.948245310498066,55.71622036390564,51.94123095707044 36 | 58.67316815874578,45.565246521367385,38.18213712559034 37 | 64.12615299179939,53.834857898742136,67.92290875773952 38 | 63.251593955761145,44.78249100180925,36.77831499772157 39 | 56.608893167471294,68.34575025003707,54.03070883022983 40 | 39.399599915577795,42.695919768426336,45.1751343074025 41 | 54.33364242888106,53.1829629030677,58.651720758341995 42 | 64.82875213196405,46.515779462218084,45.89669091230401 43 | 55.23376446146679,69.53352158609542,64.58540722097821 44 | 60.38537962240376,69.21871449256965,45.50930528078432 45 | 67.59188107091244,65.30063095231363,67.71148029484797 46 | 53.73801792148119,58.2787397133252,51.48813983311679 47 | 56.93795051264458,49.39052336858326,47.30954420146058 48 | 57.85602168891196,54.87788371842702,60.30180215371585 49 | 44.55894725276972,50.5071257583269,42.295544286578135 50 | 56.03639831039893,57.902836595046395,50.582014174656464 51 | 57.67232205503851,64.45443866269575,50.078563587389624 52 | 43.677947308766605,39.739046568766305,41.65629703230277 53 | 45.66283245042938,55.833638321885026,53.853640450642885 54 | 52.8760577541378,41.855587058818024,43.307394465474594 55 | 48.16432265353324,59.13671289690439,59.54626615815302 56 | 59.19701152490458,50.83954245885248,44.314614598709404 57 | 64.07431191387448,50.49233533468553,49.02943639642172 58 | 47.717837061390206,51.96257546757063,38.67008477663342 59 | 62.87878267908528,63.46824922575517,67.7994717404 60 | 47.255568158036134,56.38870877378405,46.043509660234854 61 | 69.06715766457837,78.16839212122237,52.30786369503559 62 | 43.9584148969685,44.987010407674575,68.77475999205498 63 
| 69.55686730522048,50.89907970156609,57.88149722904138 64 | 61.56126717459721,45.669520294162865,48.635578989966085 65 | 54.38667702968765,44.25733064488446,45.91641490682263 66 | 76.52542556268192,44.42302454562929,53.84302735708735 67 | 42.86162653040911,46.40902089825085,32.2059699940085 68 | 52.12816647841187,64.37728666848716,40.28342161518693 69 | 63.72035521203188,55.68288403514178,49.09112403812041 70 | 59.427582654359455,49.33797749481874,50.96571120313237 71 | 68.97885754931558,65.67527881376861,52.60659938562027 72 | 47.70944124525708,48.95187284259877,38.09249944158084 73 | 53.66040292270643,64.27144896915708,39.22130457389052 74 | 71.17046394078055,50.34569664408597,43.475542290420094 75 | 48.87745257533056,39.812777473680455,42.74990867308874 76 | 57.91542003025572,46.457641522193335,60.881555893486635 77 | 48.61887995861308,39.187729026100804,45.745427320537985 78 | 51.099236575519114,47.95766313767319,60.95727771803529 79 | 35.30568524682554,47.64256795241736,37.58893567542005 80 | 45.07858242187161,51.87825719157685,38.99450950389748 81 | 55.2346698926184,62.810633551441015,51.934851783047726 82 | 56.84385255131437,41.53911314885406,38.427179719985446 83 | 69.06141640795794,66.95401511359279,51.52990798948003 84 | 48.10254241092132,43.76113894031605,39.811221263888626 85 | 49.13162664866947,52.33573656991871,53.07781971058686 86 | 58.18505331160935,44.924936928837425,41.88001054968321 87 | 64.0528387909948,67.31454133577975,67.21993406942676 88 | 58.8449942977447,62.77593393145455,59.36969165010332 89 | 44.369058791195194,43.64031018203259,43.26584752240407 90 | 61.269615453874856,69.47103654764389,54.52598481465626 91 | 69.87821305211955,73.34932741768564,67.70761163447291 92 | 53.37515095398751,61.33896324889598,50.757693540544906 93 | 51.05621316233205,53.503444374776585,53.67305838084827 94 | 62.824924959049966,53.99954347843176,51.32113923059634 95 | 53.037075308764365,48.992380597643006,29.016897432169255 96 | 60.2030690022288,64.59851539469585,55.78497654198111 97 | 55.73216829896243,60.282885844940424,57.13809195311207 98 | 58.28291704341105,58.372302647156744,56.14803900029161 99 | 59.23618171120212,56.334341748725954,48.107604594178056 100 | 64.84621769309882,44.19021041627611,43.89155189311913 101 | 65.14438748921374,46.94334360109052,53.58029798826475 102 | -------------------------------------------------------------------------------- /mil_rental_data/mape_deepset.csv: -------------------------------------------------------------------------------- 1 | DeepSet,DSGCN,B-DSGCN 2 | 2.5133760645985603,1.9506670534610748,2.0974649116396904 3 | 2.0236145704984665,1.7046360298991203,1.6246119514107704 4 | 2.0412279292941093,2.099205367267132,2.095978334546089 5 | 2.0369166508316994,1.4419964514672756,1.8111327663064003 6 | 2.3417560383677483,2.176060527563095,1.7865262925624847 7 | 2.0590338855981827,2.5286786258220673,1.933790184557438 8 | 2.00178362429142,2.067665569484234,1.4131894335150719 9 | 2.452876977622509,1.796112023293972,1.936088316142559 10 | 2.233998104929924,1.7951460555195808,2.1777356043457985 11 | 2.417987957596779,1.8521016463637352,2.287006564438343 12 | 2.2325078025460243,1.779620349407196,2.0484210923314095 13 | 1.7556840553879738,1.8421443179249763,1.737549714744091 14 | 2.914668433368206,2.1546367555856705,1.8164528533816338 15 | 2.464604936540127,2.395288273692131,2.1946677938103676 16 | 2.7437573298811913,2.0871151238679886,1.7237750813364983 17 | 2.3467320948839188,2.1316908299922943,1.5239385887980461 18 | 
2.309975028038025,1.8118688836693764,1.8912425264716148 19 | 1.7607897520065308,1.7114633694291115,1.6000671312212944 20 | 1.9496569409966469,1.6758350655436516,1.645718701183796 21 | 1.6078617423772812,1.6797887161374092,1.4986857771873474 22 | 1.5202601440250874,1.893448829650879,1.6644679009914398 23 | 2.6839526370167732,1.8487010151147842,2.066147141158581 24 | 2.186994254589081,1.8100781366229057,1.7936434596776962 25 | 3.315049782395363,2.0186463370919228,2.056611329317093 26 | 2.4378789588809013,1.8995368853211403,1.5859264880418777 27 | 2.0626265555620193,1.7295261844992638,1.7737511545419693 28 | 1.7436983063817024,1.8849031999707222,1.6251340508460999 29 | 2.1870240569114685,2.1329907700419426,2.0050738006830215 30 | 2.4100836366415024,1.9054921343922615,1.79057065397501 31 | 2.371329255402088,2.0367972552776337,2.0296702161431313 32 | 2.4579524993896484,2.185629680752754,2.1491149440407753 33 | 2.1259797737002373,2.0528972148895264,1.9345438107848167 34 | 2.0362095907330513,1.9306320697069168,1.747714914381504 35 | 2.337770164012909,1.9168021157383919,1.9752491265535355 36 | 1.8858881667256355,1.8242394551634789,1.6929319128394127 37 | 3.7206921726465225,1.7988963052630424,2.0701590925455093 38 | 2.144017443060875,1.8808389082551003,1.7833059653639793 39 | 2.2341936826705933,2.375739812850952,1.890590786933899 40 | 2.1327564492821693,1.6291053965687752,1.8502747640013695 41 | 2.6038996875286102,1.6469011083245277,1.6499778255820274 42 | 2.307119034230709,2.2032825276255608,1.9522331655025482 43 | 2.432316727936268,2.1627023816108704,2.2458206862211227 44 | 2.4745898321270943,1.8141305074095726,1.7863987013697624 45 | 2.255641110241413,2.256842330098152,2.208918146789074 46 | 2.200762555003166,1.9760208204388618,1.8009185791015625 47 | 1.5706731006503105,1.8349198624491692,1.5731770545244217 48 | 2.6366621255874634,2.0920781418681145,2.001768723130226 49 | 2.4422019720077515,1.6626004129648209,1.674099639058113 50 | 1.994810439646244,1.7055569216609001,1.8825327977538109 51 | 2.09121722728014,1.8593909218907356,1.9630054011940956 52 | 2.1830925717949867,1.9430940970778465,1.9837576895952225 53 | 3.1005900353193283,1.5763355419039726,1.6385257244110107 54 | 1.9866682589054108,1.6579331830143929,1.9120456650853157 55 | 2.2568603977560997,1.8983542919158936,1.8837563693523407 56 | 2.0121220499277115,1.973532885313034,1.4932379126548767 57 | 1.8136149272322655,1.7382113263010979,1.790604367852211 58 | 2.0184116438031197,1.7770051956176758,1.6909334808588028 59 | 2.28441022336483,2.2453289479017258,2.194228768348694 60 | 2.6625696569681168,1.7765020951628685,1.7862090840935707 61 | 2.4489501491189003,2.2434990853071213,1.8958788365125656 62 | 2.071201615035534,1.864437572658062,1.7827440053224564 63 | 1.8617155030369759,1.8554355949163437,1.9976269453763962 64 | 2.173440158367157,2.41999514400959,1.8181707710027695 65 | 1.8823182210326195,1.694541610777378,1.7109919339418411 66 | 1.7735503613948822,1.6721634194254875,1.5683859586715698 67 | 2.2412972524762154,1.5276053920388222,1.6098938882350922 68 | 2.1842719987034798,1.4153828844428062,1.6847465187311172 69 | 2.086610347032547,1.5417967922985554,1.8168851733207703 70 | 2.0458001643419266,1.7103290185332298,1.666388101875782 71 | 3.3619292080402374,2.210371196269989,2.1721454337239265 72 | 1.9426696002483368,1.555402297526598,1.655426062643528 73 | 1.9763931632041931,1.5178769826889038,1.62579994648695 74 | 2.1534616127610207,2.1556925028562546,1.4479530975222588 75 | 1.6862565651535988,1.791476458311081,1.6048813238739967 76 | 
2.240169420838356,2.0379720255732536,1.7540331929922104 77 | 2.119012735784054,1.8400592729449272,1.66791882365942 78 | 2.496592327952385,1.8485337495803833,1.7292940989136696 79 | 2.750883810222149,1.7433695495128632,1.6060255467891693 80 | 2.266561798751354,1.7515560612082481,1.690090261399746 81 | 2.5868048891425133,2.069994807243347,1.8836671486496925 82 | 2.1088751032948494,2.15835589915514,1.7319394275546074 83 | 2.2390635684132576,2.0080041140317917,1.9717298448085785 84 | 2.2627681493759155,1.7686184495687485,1.94675512611866 85 | 2.199380472302437,2.0322972908616066,1.858949288725853 86 | 2.278119884431362,2.2651752457022667,1.8340559676289558 87 | 2.0944198593497276,2.2050155326724052,1.8178319558501244 88 | 1.9687052816152573,1.7969945445656776,1.8931947648525238 89 | 1.7955105751752853,1.616942323744297,1.6407137736678123 90 | 2.315966971218586,2.1038757637143135,2.136179991066456 91 | 2.221093699336052,2.4392927065491676,2.290900982916355 92 | 2.220132201910019,1.9960248842835426,1.9227582961320877 93 | 1.8625231459736824,1.745789684355259,1.6364490613341331 94 | 2.559508942067623,2.08247322589159,1.6996197402477264 95 | 2.387545630335808,1.8663100898265839,1.6536431387066841 96 | 2.206268720328808,2.0540911704301834,2.018336020410061 97 | 1.9128601998090744,1.5296323224902153,1.9582685083150864 98 | 2.391078695654869,2.3309187963604927,2.1952489390969276 99 | 2.1064825356006622,1.9208358600735664,1.8268372863531113 100 | 1.7056964337825775,2.299053408205509,1.6599729657173157 101 | 2.7931297197937965,1.7864042893052101,1.5815604478120804 102 | -------------------------------------------------------------------------------- /mil_rental_data/mape_set_transformer.csv: -------------------------------------------------------------------------------- 1 | SetTransformer,STGCN,B-STGCN 2 | 2.1577855572104454,1.8218668177723885,1.8121963366866112 3 | 1.7241068184375763,1.5124596655368805,1.5932871028780937 4 | 2.4066435173153877,2.3458797484636307,2.1506499499082565 5 | 1.799686998128891,1.9012914970517159,1.6106834635138512 6 | 2.2876590490341187,1.8452463671565056,1.8525226041674614 7 | 2.099371515214443,1.759677566587925,2.0183824002742767 8 | 2.088843658566475,2.396921068429947,1.3878230936825275 9 | 2.0208096131682396,1.9515356048941612,2.092047408223152 10 | 1.9699884578585625,1.902320608496666,2.0979054272174835 11 | 2.170694060623646,1.9457900896668434,1.8132831901311874 12 | 2.1781476214528084,1.7918048426508904,2.0558370277285576 13 | 1.926717348396778,1.5782784670591354,2.1885860711336136 14 | 1.9659969955682755,1.987317018210888,2.168484590947628 15 | 1.9619768485426903,2.0466364920139313,2.1164610981941223 16 | 1.555943675339222,1.7482785508036613,1.6207743436098099 17 | 1.8297994509339333,1.6597816720604897,1.6381708905100822 18 | 1.9701413810253143,1.8144182860851288,1.7820922657847404 19 | 2.0103221759200096,1.597590558230877,1.6314255073666573 20 | 1.9825967028737068,1.800602488219738,1.7469227313995361 21 | 1.6352290287613869,1.5379748307168484,1.6669852659106255 22 | 1.459116954356432,1.5431178733706474,1.5188268385827541 23 | 2.3208491504192352,1.939658261835575,2.0802320912480354 24 | 2.1230073645710945,1.7013616859912872,1.8342986702919006 25 | 2.272232621908188,1.9513465464115143,1.944229006767273 26 | 1.8859507516026497,1.566316932439804,1.6347430646419525 27 | 2.131418325006962,1.538029033690691,1.5980103984475136 28 | 2.636796049773693,1.4701246283948421,1.6378886997699738 29 | 2.1059921011328697,1.9129248335957527,2.155402861535549 30 | 
1.8015176057815552,1.7275109887123108,1.7268877476453781 31 | 2.2217541933059692,1.945529319345951,2.0511964336037636 32 | 2.4807799607515335,2.126690372824669,2.0760444924235344 33 | 2.067367546260357,1.999138854444027,1.8034454435110092 34 | 1.6150100156664848,1.9084496423602104,1.6564078629016876 35 | 1.9163254648447037,1.7236918210983276,1.8378552049398422 36 | 1.8568271771073341,1.6823966056108475,1.4624020084738731 37 | 2.3097455501556396,1.8306806683540344,2.088961936533451 38 | 2.0689113065600395,1.7950965091586113,1.7396142706274986 39 | 2.070232294499874,2.3571282625198364,1.8657289445400238 40 | 1.6783777624368668,1.524160709232092,1.5152718871831894 41 | 1.83615330606699,1.722738891839981,1.7651388421654701 42 | 1.9182488322257996,1.7111703753471375,1.6882073134183884 43 | 2.2299841046333313,2.188808098435402,2.155996486544609 44 | 2.4904653429985046,2.175062708556652,2.0089948549866676 45 | 2.1800320595502853,2.106383442878723,2.1776093170046806 46 | 1.8763994798064232,2.0594915375113487,1.6728421673178673 47 | 2.070634439587593,1.7015736550092697,1.7859738320112228 48 | 2.0188139751553535,2.0265135914087296,2.0004838705062866 49 | 2.042713761329651,1.6064886003732681,1.706375926733017 50 | 2.013128809630871,1.7100952565670013,1.7466498538851738 51 | 2.1336110308766365,2.073190174996853,1.82106364518404 52 | 2.174678258597851,1.58794354647398,1.7735067754983902 53 | 1.874341256916523,1.8511731177568436,1.9521523267030716 54 | 1.9686529412865639,1.4850590378046036,1.839444600045681 55 | 1.9784020259976387,1.8523884937167168,1.917433924973011 56 | 1.8222395330667496,1.7364073544740677,1.6342008486390114 57 | 2.0843664184212685,1.7163338139653206,1.6497081145644188 58 | 1.7806613817811012,1.775108277797699,1.7969606444239616 59 | 2.076714299619198,2.0671840757131577,2.217813767492771 60 | 1.856851577758789,1.56975407153368,1.85406394302845 61 | 2.3598648607730865,2.200615219771862,2.1042028442025185 62 | 1.7843995243310928,2.0095733925700188,2.3819925263524055 63 | 2.2924786433577538,1.754225231707096,1.8324565142393112 64 | 1.9608678296208382,1.6817258670926094,1.7245186492800713 65 | 1.8141502514481544,1.5030168928205967,1.6234049573540688 66 | 2.43787057697773,1.5182465314865112,1.801021583378315 67 | 1.8000205978751183,1.6427021473646164,1.4733819290995598 68 | 1.9415831193327904,1.8815664574503899,1.4958187006413937 69 | 2.103802375495434,1.7432505264878273,1.6066955402493477 70 | 1.8387814983725548,1.6358675435185432,1.6169887036085129 71 | 2.299574390053749,1.979646272957325,1.9792348146438599 72 | 1.9758066162467003,1.5382983721792698,1.5992643311619759 73 | 2.082151360809803,1.8038753420114517,1.7655828967690468 74 | 1.8502892926335335,1.7168805003166199,1.4183958992362022 75 | 1.8665570765733719,1.5585138462483883,1.5822723507881165 76 | 1.8473964184522629,1.6526451334357262,1.7089242115616798 77 | 1.4831321313977242,1.531778834760189,1.5500785782933235 78 | 1.9072655588388443,1.6898272559046745,2.07919143140316 79 | 1.9340049475431442,1.863180473446846,1.5752771869301796 80 | 1.9648993387818336,1.611294411122799,1.8175110220909119 81 | 2.015724405646324,1.8947133794426918,1.9319947808980942 82 | 2.152271755039692,1.7232807353138924,1.736137643456459 83 | 2.3276226595044136,1.9663874059915543,2.0757514983415604 84 | 1.7827734351158142,1.6621392220258713,2.163245528936386 85 | 2.395191974937916,1.8693368881940842,1.835053600370884 86 | 2.179812453687191,1.5329786576330662,1.6595056280493736 87 | 2.1177753806114197,2.0746415480971336,1.9936172291636467 88 | 
1.834098994731903,1.8662093207240105,2.228272147476673 89 | 1.7910901457071304,1.586262695491314,1.7983265221118927 90 | 2.0522700622677803,2.23811361938715,2.0715706050395966 91 | 2.4290017783641815,2.3958520963788033,2.379382774233818 92 | 1.907171867787838,2.008471265435219,1.8993007019162178 93 | 1.7735611647367477,1.6890592873096466,1.6586868092417717 94 | 1.8653925508260727,1.5636662021279335,1.7078746110200882 95 | 2.0586857572197914,1.661829650402069,1.6068913042545319 96 | 2.0140159875154495,2.043093554675579,1.9115518778562546 97 | 1.9604559987783432,1.933271810412407,2.0745983347296715 98 | 2.0628269761800766,2.042781561613083,2.089470811188221 99 | 1.8678417429327965,2.033270336687565,1.7537947744131088 100 | 2.242342196404934,1.572493463754654,1.548153068870306 101 | 2.164452150464058,1.5146086923778057,2.074548602104187 102 | -------------------------------------------------------------------------------- /mil_rental_data/models_all_deepset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import scipy.sparse as sp 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from modules import * 9 | 10 | 11 | # GCN model 12 | class GraphConvolution(Module): 13 | """ 14 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 15 | """ 16 | 17 | def __init__(self, in_features, out_features, bias=True): 18 | super(GraphConvolution, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 22 | if bias: 23 | self.bias = Parameter(torch.FloatTensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1. 
/ math.sqrt(self.in_features) 30 | self.weight.data.uniform_(-stdv, stdv) 31 | if self.bias is not None: 32 | self.bias.data.uniform_(-stdv, stdv) 33 | 34 | def forward(self, input, adj): 35 | support = torch.mm(input, self.weight) 36 | output = torch.mm(adj, support) 37 | if self.bias is not None: 38 | return output + self.bias 39 | else: 40 | return output 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ + ' (' \ 44 | + str(self.in_features) + ' -> ' \ 45 | + str(self.out_features) + ')' 46 | 47 | 48 | # Deep Set Model 49 | class DeepSet(nn.Module): 50 | 51 | def __init__(self, in_features=10, set_features=25, nhid=64, dropout=0.3): 52 | super(DeepSet, self).__init__() 53 | self.in_features = in_features 54 | self.out_features = set_features 55 | self.fc = nn.Linear(nhid, 1) 56 | self.dropout = dropout 57 | self.feature_extractor = nn.Sequential( 58 | nn.Linear(in_features, 25), 59 | nn.ELU(inplace=True), 60 | nn.Linear(25, 25), 61 | nn.ELU(inplace=True), 62 | nn.Linear(25, set_features) 63 | ) 64 | 65 | self.regressor = nn.Sequential( 66 | nn.Linear(set_features, 25), 67 | nn.ELU(inplace=True), 68 | nn.Linear(25, 25), 69 | nn.ELU(inplace=True), 70 | nn.Linear(25, nhid), 71 | nn.ELU(inplace=True), 72 | nn.Linear(nhid, nhid), 73 | ) 74 | self.reset_parameters() 75 | 76 | self.add_module('0', self.feature_extractor) 77 | self.add_module('1', self.regressor) 78 | 79 | def reset_parameters(self): 80 | for module in self.children(): 81 | reset_op = getattr(module, "reset_parameters", None) 82 | if callable(reset_op): 83 | reset_op() 84 | 85 | def forward(self, input): 86 | x = input 87 | x = self.feature_extractor(x) 88 | x = x.sum(dim=1)  # permutation-invariant sum pooling over the instances in each bag 89 | x = self.regressor(x) 90 | embedding = x.cpu().detach().numpy() 91 | x = F.dropout(x, self.dropout, training=self.training)  # apply dropout only in training mode 92 | x = self.fc(x) 93 | return x, embedding 94 | 95 | # def __repr__(self): 96 | # return self.__class__.__name__ + '(' \ 97 | # + 'Feature Extractor=' + str(self.feature_extractor) \ 98 | # + '\n Set Feature' + str(self.regressor) + ')' 99 | 100 | 101 | # Graph Deep Set Model 102 | class DSGCN(nn.Module): 103 | 104 | def __init__(self, in_features=10, set_features=25, nhid=64, dropout=0.3): 105 | super(DSGCN, self).__init__() 106 | self.in_features = in_features 107 | self.out_features = set_features 108 | self.gc = GraphConvolution(nhid, 1) 109 | self.dropout = dropout 110 | self.feature_extractor = nn.Sequential( 111 | nn.Linear(in_features, 25), 112 | nn.ELU(inplace=True), 113 | nn.Linear(25, 25), 114 | nn.ELU(inplace=True), 115 | nn.Linear(25, set_features) 116 | ) 117 | 118 | self.regressor = nn.Sequential( 119 | nn.Linear(set_features, 25), 120 | nn.ELU(inplace=True), 121 | nn.Linear(25, 25), 122 | nn.ELU(inplace=True), 123 | nn.Linear(25, nhid), 124 | nn.ELU(inplace=True), 125 | nn.Linear(nhid, nhid), 126 | ) 127 | self.reset_parameters() 128 | 129 | self.add_module('0', self.feature_extractor) 130 | self.add_module('1', self.regressor) 131 | 132 | def reset_parameters(self): 133 | for module in self.children(): 134 | reset_op = getattr(module, "reset_parameters", None) 135 | if callable(reset_op): 136 | reset_op() 137 | 138 | def forward(self, input, adj): 139 | x = input 140 | x = self.feature_extractor(x) 141 | x = x.sum(dim=1)  # permutation-invariant sum pooling over the instances in each bag 142 | x = self.regressor(x) 143 | embedding = x.cpu().detach().numpy() 144 | x = F.dropout(x, self.dropout, training=self.training)  # apply dropout only in training mode 145 | x = self.gc(x, adj)  # graph convolution over the bag-level graph couples predictions of related bags 146 | return x, embedding 147 | 148 | -------------------------------------------------------------------------------- /mil_rental_data/models_all_set_transformer.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import scipy.sparse as sp 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from modules import * 9 | 10 | 11 | # GCN model 12 | class GraphConvolution(Module): 13 | """ 14 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 15 | """ 16 | 17 | def __init__(self, in_features, out_features, bias=True): 18 | super(GraphConvolution, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 22 | if bias: 23 | self.bias = Parameter(torch.FloatTensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1. / math.sqrt(self.in_features) 30 | self.weight.data.uniform_(-stdv, stdv) 31 | if self.bias is not None: 32 | self.bias.data.uniform_(-stdv, stdv) 33 | 34 | def forward(self, input, adj): 35 | support = torch.mm(input, self.weight) 36 | output = torch.mm(adj, support) 37 | if self.bias is not None: 38 | return output + self.bias 39 | else: 40 | return output 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ + ' (' \ 44 | + str(self.in_features) + ' -> ' \ 45 | + str(self.out_features) + ')' 46 | 47 | 48 | # Set Transformer model 49 | class SetTransformer(nn.Module): 50 | def __init__(self, in_features=200, num_heads=4, ln=True): 51 | super(SetTransformer, self).__init__() 52 | self.enc = nn.Sequential( 53 | SAB(dim_in=in_features, dim_out=64, num_heads=num_heads, ln=ln), 54 | # SAB(dim_in=64, dim_out=64, num_heads=num_heads, ln=ln) 55 | ) 56 | self.dec = nn.Sequential( 57 | PMA(dim=64, num_heads=num_heads, num_seeds=1, ln=ln), 58 | # PMA(dim=64, num_heads=num_heads, num_seeds=1, ln=ln) 59 | ) 60 | self.last_layer = nn.Linear(in_features=64, out_features=1) 61 | 62 | self.reset_parameters() 63 | 64 | def reset_parameters(self): 65 | for module in self.children(): 66 | reset_op = getattr(module, "reset_parameters", None) 67 | if callable(reset_op): 68 | reset_op() 69 | 70 | def forward(self, x): 71 | x = self.enc(x) 72 | x = self.dec(x).squeeze() 73 | embedding = x.cpu().detach().numpy() 74 | x = self.last_layer(x) 75 | return x, embedding 76 | 77 | 78 | # Set Transformer GCN model 79 | class STGCN(nn.Module): 80 | def __init__(self, in_features=200, num_heads=4, ln=True): 81 | super(STGCN, self).__init__() 82 | self.enc = nn.Sequential( 83 | SAB(dim_in=in_features, dim_out=64, num_heads=num_heads, ln=ln), 84 | # SAB(dim_in=64, dim_out=64, num_heads=num_heads, ln=ln) 85 | ) 86 | self.dec = nn.Sequential( 87 | PMA(dim=64, num_heads=num_heads, num_seeds=1, ln=ln), 88 | # PMA(dim=64, num_heads=num_heads, num_seeds=1, ln=ln) 89 | ) 90 | self.last_layer = GraphConvolution(in_features=64, out_features=1) 91 | 92 | self.reset_parameters() 93 | 94 | def reset_parameters(self): 95 | for module in self.children(): 96 | reset_op = getattr(module, "reset_parameters", None) 97 | if callable(reset_op): 98 | reset_op() 99 | 100 | def forward(self, x, adj): 101 | x = self.enc(x) 102 | x = self.dec(x).squeeze() 103 | embedding = x.cpu().detach().numpy() 104 | x = self.last_layer(x, adj) 105 | return x, embedding 106 | 107 | 108 | -------------------------------------------------------------------------------- /mil_rental_data/modules.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import scipy.sparse as sp 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | 9 | 10 | # Set transformer layers 11 | class MAB(nn.Module): 12 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 13 | super(MAB, self).__init__() 14 | self.dim_V = dim_V 15 | self.num_heads = num_heads 16 | self.fc_q = nn.Linear(dim_Q, dim_V) 17 | self.fc_k = nn.Linear(dim_K, dim_V) 18 | self.fc_v = nn.Linear(dim_K, dim_V) 19 | if ln: 20 | self.ln0 = nn.LayerNorm(dim_V) 21 | self.ln1 = nn.LayerNorm(dim_V) 22 | self.fc_o = nn.Linear(dim_V, dim_V) 23 | 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | for module in self.children(): 28 | reset_op = getattr(module, "reset_parameters", None) 29 | if callable(reset_op): 30 | reset_op() 31 | 32 | def forward(self, Q, K): 33 | Q = self.fc_q(Q) 34 | K, V = self.fc_k(K), self.fc_v(K) 35 | 36 | dim_split = self.dim_V // self.num_heads 37 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 38 | K_ = torch.cat(K.split(dim_split, 2), 0) 39 | V_ = torch.cat(V.split(dim_split, 2), 0) 40 | 41 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 42 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 43 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 44 | O = O + F.relu(self.fc_o(O)) 45 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 46 | return O 47 | 48 | 49 | class SAB(nn.Module): 50 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 51 | super(SAB, self).__init__() 52 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 53 | 54 | def forward(self, X): 55 | return self.mab(X, X) 56 | 57 | 58 | class ISAB(nn.Module): 59 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 60 | super(ISAB, self).__init__() 61 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 62 | nn.init.xavier_uniform_(self.I) 63 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 64 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 65 | 66 | def forward(self, X): 67 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 68 | return self.mab1(X, H) 69 | 70 | 71 | class PMA(nn.Module): 72 | def __init__(self, dim, num_heads, num_seeds, ln=False): 73 | super(PMA, self).__init__() 74 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 75 | nn.init.xavier_uniform_(self.S) 76 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 77 | 78 | def forward(self, X): 79 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) 80 | 81 | 82 | -------------------------------------------------------------------------------- /mil_rental_data/rmse_deepset.csv: -------------------------------------------------------------------------------- 1 | DeepSet,DSGCN,B-DSGCN 2 | 119.34078507569322,92.68370873837178,78.80042936550352 3 | 72.17945378035715,62.709600810608684,57.26124087872154 4 | 77.31159702703931,72.92764343680349,61.3970343104834 5 | 67.58703495645325,51.48044366832809,57.183610446461245 6 | 86.37513776750231,93.57411796405839,70.72355135553761 7 | 81.78553298246165,90.77526565058281,66.84224066682545 8 | 80.47482878959138,88.95973214573756,46.45877845562803 9 | 84.50471764358582,95.95455819558607,92.15725196866173 10 | 80.00020280881589,71.18857770830796,88.59079981842713 11 | 93.64849707536517,78.60443953264598,100.37527434837537 12 | 
95.7572668054481,92.47194073879837,73.31220474046958 13 | 59.94188812645526,75.18363318544127,67.59187592641726 14 | 122.4226257800087,88.92927673427428,73.89171154447175 15 | 90.56755151321093,98.84436518383714,90.49161876435991 16 | 105.2568549573347,76.69883620619615,57.73403027471791 17 | 71.4140094870784,58.98398312401218,39.7163181890628 18 | 99.93076938687648,97.13296655041754,78.29171595966794 19 | 54.16910603955165,52.68129745562818,43.73703183590441 20 | 49.86165109263662,69.96976448833585,47.67993242090685 21 | 69.75376256922964,57.43920440046973,49.16042553268492 22 | 52.83150642588239,60.386125574204804,49.29709933612157 23 | 95.14532910383367,76.58144397069267,75.93804253655976 24 | 63.12156684009654,80.12039879419225,56.557221857887306 25 | 101.2132817461625,84.9372925207632,73.44836923888346 26 | 90.93222419851266,80.78062272756249,55.43933334434876 27 | 86.98338172157777,51.599600465678186,59.64624427747353 28 | 53.486308061333304,77.51886873782573,55.61766726974989 29 | 73.01005310508565,103.39749035355919,58.194025311202594 30 | 97.79982688256945,65.75337225059506,72.17452020947991 31 | 106.52688677131498,101.06575820239205,105.6152513585173 32 | 107.22256656546307,95.91597448173901,89.07065775080028 33 | 89.87467032444056,102.6103620130984,70.177257412415 34 | 58.5083951226381,58.64825851308611,56.64580492038498 35 | 98.63267949618661,91.85656651440398,84.3077657901297 36 | 62.656453030908146,70.95837183801083,52.263595314015056 37 | 150.82456054739794,83.84455544415775,70.52759238965115 38 | 70.53311243297887,66.7624804135608,51.90057915616117 39 | 94.77901018007448,69.59008016610004,58.249956262795294 40 | 73.2855356774585,53.47869420846749,54.75890697840821 41 | 88.65641270994786,68.81167688946385,62.82220866559513 42 | 100.63699539614277,85.87339000813992,75.47250687876664 43 | 93.92423172800183,98.45195338052697,95.59149059278039 44 | 67.54125409384991,65.27772251527884,48.67933806596383 45 | 111.00218601795991,94.47547467548722,102.19784552262631 46 | 81.27896483111597,70.79379943721517,60.842375418198735 47 | 60.82413303829185,66.69436729738278,52.422776283862504 48 | 136.5889810001349,99.42699984090851,99.68073663215715 49 | 107.96699559593311,59.60551016464139,55.41171769422462 50 | 103.09433554161522,90.79302444794281,103.09633160574492 51 | 81.22542606978179,70.87613193806962,74.70853905560443 52 | 71.99668015561602,78.10887545648949,69.51120476600627 53 | 102.70814857747239,51.91456189405935,50.573211943404154 54 | 53.511258862954406,62.04176302437264,43.22362665046503 55 | 82.00437466290701,100.20800623210557,86.29265093179247 56 | 71.35958587257332,70.01930597691549,61.81529205757592 57 | 58.6773763558027,62.97642519759425,63.731853158758305 58 | 73.39343117486042,57.81376480550666,59.318277565278315 59 | 104.6338874580552,98.80005050236,91.87704160521882 60 | 110.71551416857142,61.49432185882497,57.006398021009275 61 | 88.7675338058274,103.7085779770701,73.51704824953123 62 | 89.3750883975488,64.9294504806097,48.90174488156867 63 | 67.85100929310931,76.31549929238281,65.03208315944289 64 | 80.45663271014111,114.87138142052412,90.94270867968869 65 | 76.04878808429119,69.64330511322827,66.39693830857372 66 | 70.91329062675193,58.016216124309786,55.95667406860054 67 | 83.62515815823299,58.82845989023744,51.91653738020832 68 | 83.21993141743042,54.706896132142376,59.215037836013934 69 | 76.32963636513638,60.573256586363065,53.313561057697235 70 | 73.49608957616951,65.05292865391067,55.94232092704943 71 | 128.08895797265586,84.55217046712254,72.10608299010558 
72 | 60.136669002945595,67.43027189934023,58.75570643940742 73 | 51.043351924383025,62.658562273931786,44.02261819681001 74 | 85.81973292341661,91.23928882679313,46.88962992692023 75 | 53.706621067400036,63.620341081245115,41.800939658772634 76 | 76.58894978915971,90.54451446379665,78.93508652682976 77 | 101.91969295725525,76.05334610702033,67.44648734814635 78 | 93.92188583819993,97.96291766875339,82.61091578907802 79 | 106.47061628304043,58.83928390809533,48.355656718253236 80 | 91.55380268409367,71.7438950959756,44.00840910112393 81 | 122.51660541794979,100.0631886927996,78.85956533759311 82 | 92.3262383463215,64.49767300018517,52.14075505811637 83 | 75.37177251865472,96.4559406957907,78.72544320376555 84 | 96.89139134577348,74.08019041436707,72.87278254020815 85 | 96.95807429229157,87.47531892863265,72.2270455052637 86 | 100.64694484982013,98.71363327233293,70.32145246780425 87 | 96.78024967191323,94.4798474963899,87.53468640300532 88 | 89.09803675414615,92.8191530074605,84.86820709499624 89 | 59.77256735660881,63.21892641137063,60.40832407090482 90 | 113.5178750277595,93.19041064659213,89.09870553851951 91 | 106.2114983498339,98.03826394515394,95.99921241374507 92 | 99.08545680593431,89.62615005130404,84.06051106280735 93 | 73.81287215584426,74.35789026416232,63.90214623819371 94 | 126.16259146366114,107.25575884836189,85.79472038785335 95 | 101.99648998129643,65.19292065673817,48.70020928290751 96 | 101.2183027734578,94.8858716339453,92.17085401391661 97 | 74.82477892418768,67.70526574467101,70.17106858471394 98 | 89.82041647827641,106.87478840233118,82.30956155044692 99 | 74.8392555336231,84.42826530072166,64.95627902297136 100 | 47.511043788655485,87.96251291113191,45.18759427472751 101 | 107.30739929097481,66.5885553205286,59.38243970915839 102 | -------------------------------------------------------------------------------- /mil_rental_data/rmse_set_transformer.csv: -------------------------------------------------------------------------------- 1 | SetTransformer,STGCN,B-STGCN 2 | 107.74489745003906,88.96274681991281,79.8866483667164 3 | 65.93918112749216,61.30842038101464,52.951897902075586 4 | 78.9336152012084,78.60399196156536,67.3932932679892 5 | 65.31228837839062,73.71022404352553,55.849128396870825 6 | 83.58091549968836,74.35893459668378,94.72759609524947 7 | 65.2958722942725,66.04151542560494,73.14094964043561 8 | 80.00407146919096,92.49203513696992,47.19827391522084 9 | 96.50278646889622,87.63458221040295,95.49555604667106 10 | 87.62335692192104,78.88095614854993,86.14049734236036 11 | 93.89751122003894,70.92438215835917,68.11697969389508 12 | 79.284793875654,73.61104846545304,63.77912078046855 13 | 74.28821836594489,61.96305739261981,86.14155710836737 14 | 72.83342200758896,69.27800994400981,69.74944633777395 15 | 89.91576455193527,91.56750761925214,92.72481325485684 16 | 51.01989302636401,63.86343905646233,58.12246023875906 17 | 32.69242831369709,53.675774674303106,44.76761312174982 18 | 104.83838114144463,80.96485738893462,64.33384655119055 19 | 70.24515960429065,44.22030571307682,46.55277867154983 20 | 62.99894779749058,69.62327244899888,59.74039368375554 21 | 63.46095947608567,48.62288752035799,71.94532266023799 22 | 46.50524868058543,43.68339018466663,53.962456812681964 23 | 97.40880409295778,75.1316789286224,82.03119806077349 24 | 74.5965845515058,67.87472027139214,69.02396962754557 25 | 64.73642387697562,83.7053351156072,70.4381964969151 26 | 67.68113806227865,62.63935272893113,60.171764749060884 27 | 64.81405945373109,52.132184329147144,54.10827238405278 28 | 
80.11333540231065,56.12916899197281,55.52022538655292 29 | 60.23553076681213,67.16576253518548,52.489376918457715 30 | 81.73355814766207,67.17974012858848,67.02744763778652 31 | 102.74548732348599,86.5495464430814,95.58494679491193 32 | 109.09161310916659,91.75374863374432,99.88064142584638 33 | 73.8223997609169,88.15801401694739,83.5848664719863 34 | 60.99660223918822,63.77064779690774,54.72835382153652 35 | 71.9812209476013,82.84590603989217,79.24094219874303 36 | 77.01076238242135,55.35527743760914,47.50035352767226 37 | 81.24392052995248,65.70261151665787,81.64442462418017 38 | 83.26604152772525,61.19049826250766,52.3030433030523 39 | 73.35987363280384,88.46005761793256,65.90382815661789 40 | 49.45133130160623,54.17735266532456,55.12811196396317 41 | 71.70842894620739,65.32976937301093,75.566064668103 42 | 84.0210939407412,63.011572388661335,60.54608336282437 43 | 90.69440447534971,96.29062748768918,103.44001475071384 44 | 84.4422326051343,77.78283792998661,60.48229676709241 45 | 101.96366810205053,99.9232481349239,106.17157706724015 46 | 62.52737250235659,77.77773459076845,58.22425436487797 47 | 71.5873275296794,63.43524728917799,59.510172379972886 48 | 81.5415089981121,88.97814943848056,90.94290417050553 49 | 63.86669037741584,60.62064767595761,51.26482758636035 50 | 85.89169926648414,93.24575512573435,82.997421711922 51 | 77.84645990187283,87.5041898355806,77.7533239611412 52 | 58.087014666971555,57.40264761772343,57.06429931425577 53 | 58.75980145757038,64.77637602654046,67.69115439439334 54 | 63.75510627697014,52.918175736173254,56.3717679511576 55 | 84.80580436846759,92.5635641979472,93.28101549569539 56 | 87.48394110255366,65.9083193009097,59.029414160943304 57 | 82.68265062986248,65.4161094356103,59.66722867331115 58 | 63.58687099560658,68.08392116787091,55.72452872362071 59 | 91.75179372557606,97.42530763349396,99.24310471621823 60 | 63.07016818875707,73.30753868334168,59.58293097529808 61 | 94.0476070113992,97.1439243251501,70.3739366076266 62 | 52.196907223001794,64.32241033840627,87.11817778427387 63 | 85.96822363228083,68.21509036146539,73.55795727519948 64 | 88.1516554209054,79.19376717794601,81.4557811304391 65 | 68.54096840861767,59.11508029467414,65.05227015852768 66 | 96.77790378211132,54.89851828859243,77.10150613289446 67 | 54.34329350183801,58.84060089886131,39.88894172481456 68 | 63.227790376565096,73.08250817519524,48.63854221918954 69 | 76.95897789394341,70.62588825904793,60.85794266061223 70 | 70.42731074511505,59.30490187781133,63.69636128651419 71 | 79.70742958814438,79.3383326369882,68.29207773182821 72 | 62.790744933076645,56.996921860888435,54.27137860372227 73 | 73.7963120258611,73.04127504633068,53.469994867118764 74 | 86.40935894943706,59.06083673750035,56.84016909276577 75 | 65.89750557204216,49.96008071890807,53.5476664553405 76 | 75.4130313699952,55.969550740035096,75.90202592580732 77 | 66.32868629102586,52.32027221740881,63.53494246126362 78 | 79.26028035612318,81.7462393282798,104.01297775684755 79 | 46.31131150130455,58.379432917475675,46.61339111375597 80 | 63.984417005106,61.8268774607204,52.268539173882665 81 | 85.83573230342519,93.10741965035469,90.41242440556502 82 | 72.27689566355411,61.42287510977057,56.40859739214842 83 | 87.91613014259256,84.56547927615217,69.2186733366082 84 | 63.18582158488985,60.747804163312,62.7225134935095 85 | 61.797522971225554,66.58877138932614,65.94979936554287 86 | 80.61611206070897,57.41205175491176,54.85345765531424 87 | 91.61798540595446,100.38883523766881,95.31633212360386 88 | 
91.58543104045789,95.24696375060202,95.75965385121143 89 | 53.408291791934545,55.990231610657126,56.79126552158841 90 | 98.99199161751119,97.96835025566307,96.41139965652566 91 | 108.31138868123682,101.58080448179372,100.23077576777052 92 | 92.40297363642054,93.35615599228875,87.39002319855473 93 | 71.20203570769782,66.32260549772356,68.4480691146643 94 | 104.42499008679202,88.05585463167073,78.90044864078547 95 | 73.13609838148123,63.81600166641116,47.55924770848842 96 | 90.46568021966434,92.8670379685923,89.23062068391502 97 | 68.83339694811215,71.81984842280734,69.41582582537634 98 | 80.76694865937509,83.1136512915149,80.55300453834072 99 | 72.01500999194097,68.0879853190628,67.80370565993282 100 | 73.08539938148618,55.1529804536614,53.45030173957122 101 | 84.72873983067706,58.94783275638507,67.57699290186265 102 | -------------------------------------------------------------------------------- /mil_rental_data/set_transformer_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np # linear algebra 2 | import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) 3 | import matplotlib.pyplot as plt 4 | from numpy import matlib 5 | import seaborn as sns 6 | from scipy import stats 7 | from sklearn.preprocessing import StandardScaler 8 | from torch import optim 9 | import torch.nn as nn 10 | import matplotlib as mpl 11 | import scipy.sparse as sp 12 | import time 13 | import torch 14 | import sys 15 | import glob 16 | import os 17 | from models_all_set_transformer import * 18 | from scipy.stats import wilcoxon 19 | from utils import * 20 | from sklearn import preprocessing 21 | from subprocess import call 22 | mpl.rcParams['figure.dpi'] = 600 23 | color = sns.color_palette() 24 | pd.options.mode.chained_assignment = None # default='warn' 25 | torch.manual_seed(0) 26 | np.random.seed(0) 27 | 28 | num_trials = 100 29 | lr_ = 5e-4 30 | weight_decay_ = 1e-3 31 | EPOCHS = 500 32 | MC_samples = 20 33 | 34 | adj_orig = np.loadtxt('data/adj_nbhd.txt', dtype='float', delimiter=',') 35 | print(adj_orig) 36 | num_neib = int(np.sum(adj_orig)/adj_orig.shape[0]) - 1 37 | print(num_neib) 38 | 39 | files = glob.glob('*.pkl') 40 | for file in files: 41 | os.remove(file) 42 | log_dir = 'log_st' 43 | 44 | check = os.path.isdir(log_dir) 45 | if not check: 46 | os.makedirs(log_dir) 47 | print("created folder : ", log_dir) 48 | else: 49 | print(log_dir, " folder already exists.") 50 | 51 | for f in os.listdir(log_dir): 52 | os.remove(os.path.join(log_dir, f)) 53 | 54 | train_df = pd.read_csv('data/neighbourhood_data.csv') 55 | districts = pd.read_csv('data/districts_data_cleaned.csv') 56 | names = list(districts['name'].unique()) 57 | train_df = train_df[train_df['neighbourhood'].isin(names)] 58 | counts = train_df['neighbourhood'].value_counts() 59 | count_list = counts[counts > 50].index.tolist() 60 | train_df = train_df[train_df['neighbourhood'].isin(count_list)] 61 | 62 | 63 | train_df = train_df.loc[:, ~train_df.columns.str.contains('^Unnamed')] 64 | train_df = train_df.astype({'bathrooms': 'int64'}) 65 | train_df = train_df.astype({'bedrooms': 'int64'}) 66 | train_df = train_df.astype({'interest_level': 'category'}) 67 | train_df = train_df.astype({'num_photos': 'int64'}) 68 | train_df = train_df.astype({'num_features': 'int64'}) 69 | train_df = train_df.astype({'num_description_words': 'int64'}) 70 | train_df = train_df.astype({'created_month': 'category'}) 71 | train_df = train_df.astype({'created_day': 'category'}) 72 | train_df = 
train_df.astype({'neighbourhood': 'str'}) 73 | train_df = train_df.astype({'price': 'float64'}) 74 | non_standardized = train_df.copy() 75 | train_df['price'] = (train_df['price'] - train_df['price'].mean()) / train_df['price'].std() 76 | 77 | train_df['interest_level'] = pd.Categorical(train_df['interest_level'], categories=train_df['interest_level'].unique()).codes 78 | train_df['neighbourhood'] = pd.Categorical(train_df['neighbourhood'], categories=train_df['neighbourhood'].unique()).codes 79 | 80 | 81 | def train_Set(n=25, seed=None): 82 | train_samples = train_df.groupby('neighbourhood').apply(pd.DataFrame.sample, n, replace = False, random_state=seed) 83 | train_samples = train_samples.astype({'price': 'float64'}) 84 | features = [] 85 | for i, hood in enumerate(train_samples['neighbourhood'].unique()): 86 | get_boro = train_samples[train_samples['neighbourhood'] == hood] 87 | sample = get_boro.to_numpy() 88 | features.append(sample) 89 | return np.array(features) 90 | 91 | 92 | non_standardized['price'] = non_standardized['price'].clip(upper = np.percentile(non_standardized['price'].values, 95)) 93 | df = train_df[['price', 'neighbourhood']] 94 | mean = df.groupby('neighbourhood').mean().values 95 | 96 | 97 | def get_performance_test(trial_i): 98 | torch.manual_seed(trial_i) 99 | np.random.seed(trial_i) 100 | 101 | indices = np.arange(0, mean.shape[0]) 102 | np.random.shuffle(indices) 103 | idx_train = indices[:55] 104 | idx_test = indices[55:] 105 | 106 | X = train_Set(n=25, seed=trial_i) 107 | X = np.array(X, dtype='float') 108 | 109 | features = torch.FloatTensor(X) 110 | labels = torch.FloatTensor(mean) 111 | 112 | mserr = [] 113 | maerr = [] 114 | mperr = [] 115 | 116 | adj_norm = normalize(adj_orig) 117 | adj_norm = torch.FloatTensor(adj_norm) 118 | 119 | models = [ 120 | SetTransformer(in_features=10), 121 | STGCN(in_features=10) 122 | ] 123 | 124 | def weight_reset(m): 125 | reset_parameters = getattr(m, "reset_parameters", None) 126 | if callable(reset_parameters): 127 | m.reset_parameters() 128 | 129 | for model in models: 130 | model.apply(weight_reset) 131 | 132 | mse = nn.MSELoss() 133 | mae = nn.L1Loss() 134 | 135 | def train(epoch, adj_=None): 136 | t = time.time() 137 | model.train() 138 | optimizer.zero_grad() 139 | if adj_ is not None: 140 | output, _ = model(features, adj_) 141 | else: 142 | output, _ = model(features) 143 | loss_train = mse(output[idx_train], labels[idx_train]) 144 | loss_train.backward() 145 | optimizer.step() 146 | 147 | loss_test = mse(output[idx_test], labels[idx_test]) 148 | return loss_test.item() 149 | 150 | for network in (models): 151 | 152 | t_total = time.time() 153 | # Model and optimizer 154 | model = network 155 | 156 | no_decay = list() 157 | decay = list() 158 | for m in model.modules(): 159 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 160 | decay.append(m.weight) 161 | no_decay.append(m.bias) 162 | 163 | optimizer = optim.Adam([{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': weight_decay_}], lr=lr_) 164 | 165 | # Train model 166 | t_total = time.time() 167 | for epoch in range(EPOCHS): 168 | if isinstance(model, SetTransformer): 169 | value = train(epoch) 170 | else: 171 | value = train(epoch, adj_norm) 172 | 173 | model.eval() 174 | with torch.no_grad(): 175 | if isinstance(model, (SetTransformer)): 176 | output, decoding = model(features) 177 | np.savetxt(log_dir + '/decoding_set_transformer' + '_trial_' + str(trial_i), decoding, delimiter=',') 178 | 179 | if 
isinstance(model, (STGCN)): 180 | output, decoding = model(features, adj_norm) 181 | np.savetxt(log_dir + '/decoding_set_transformer_gcn' + '_trial_' + str(trial_i), decoding, delimiter=',') 182 | 183 | maerr.append(non_standardized['price'].std() * (mae(output[idx_test], labels[idx_test]).detach().cpu().numpy())) 184 | mserr.append(non_standardized['price'].std() * np.sqrt(mse(output[idx_test], labels[idx_test]).detach().numpy())) 185 | 186 | target = (non_standardized['price'].std() * labels[idx_test]).detach().cpu().numpy() + non_standardized['price'].mean() 187 | pred = (non_standardized['price'].std() * output[idx_test].T).detach().cpu().numpy() + non_standardized['price'].mean() 188 | 189 | mperr.append(100 * np.mean(np.abs((target - pred) / target))) 190 | model.apply(weight_reset) 191 | return mserr, maerr, mperr 192 | 193 | 194 | mse = [] 195 | mae = [] 196 | mpe = [] 197 | print('SetTransformer', 'STGCN') 198 | for i in range(num_trials): 199 | t = time.time() 200 | mserr, maerr, mperr = get_performance_test(i) 201 | mse.append(mserr) 202 | mae.append(maerr) 203 | mpe.append(mperr) 204 | print('run : ' + str(i) + ', RMSE : ' + str(np.around(np.array(mserr), 2)) + ', MAE : ' + str(np.around(np.array(maerr), 2)) + 205 | ', MAPE : ' + str(np.around(np.array(mperr), 2)) + ', run time: {:.2f}s'.format(time.time() - t)) 206 | 207 | 208 | mse = np.array(mse) 209 | mae = np.array(mae) 210 | mpe = np.array(mpe) 211 | 212 | columns_ = ['SetTransformer', 'STGCN'] 213 | df_mse = pd.DataFrame(mse, columns=columns_) 214 | df_mae = pd.DataFrame(mae, columns=columns_) 215 | df_mpe = pd.DataFrame(mpe, columns=columns_) 216 | 217 | 218 | def get_performance_test_bayesian(trial_i): 219 | torch.manual_seed(trial_i) 220 | np.random.seed(trial_i) 221 | 222 | indices = np.arange(0, mean.shape[0]) 223 | np.random.shuffle(indices) 224 | idx_train = indices[:55] 225 | idx_test = indices[55:] 226 | 227 | X = train_Set(n=25, seed=trial_i) 228 | X = np.array(X, dtype='float') 229 | 230 | features = torch.FloatTensor(X) 231 | labels = torch.FloatTensor(mean) 232 | 233 | mserr = [] 234 | maerr = [] 235 | mperr = [] 236 | 237 | models = [ 238 | STGCN(in_features=10) 239 | ] 240 | 241 | def weight_reset(m): 242 | reset_parameters = getattr(m, "reset_parameters", None) 243 | if callable(reset_parameters): 244 | m.reset_parameters() 245 | 246 | for model in models: 247 | model.apply(weight_reset) 248 | 249 | mse = nn.MSELoss() 250 | mae = nn.L1Loss() 251 | 252 | def train_(epoch, adj_=None): 253 | t = time.time() 254 | model.train() 255 | optimizer.zero_grad() 256 | if adj_ is not None: 257 | output, _ = model(features, adj_) 258 | else: 259 | output, _ = model(features) 260 | loss_train = mse(output[idx_train], labels[idx_train]) 261 | loss_train.backward() 262 | optimizer.step() 263 | 264 | loss_test = mse(output[idx_test], labels[idx_test]) 265 | return loss_test.item(), output 266 | 267 | for network in (models): 268 | 269 | t_total = time.time() 270 | # Model and optimizer 271 | model = network 272 | 273 | if isinstance(model, (STGCN)): 274 | embedding = np.loadtxt(log_dir + '/decoding_set_transformer_gcn' + '_trial_' + str(trial_i), delimiter=",") 275 | 276 | adj_np = MAP_inference(embedding, num_neib, 1) 277 | adj_np_norm = normalize(adj_np) 278 | adj_norm = torch.FloatTensor(adj_np_norm) 279 | 280 | no_decay = list() 281 | decay = list() 282 | for m in model.modules(): 283 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 284 | decay.append(m.weight) 285 | no_decay.append(m.bias) 286 
| 287 | optimizer = optim.Adam([{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': weight_decay_}], lr=lr_) 288 | 289 | # Train model 290 | t_total = time.time() 291 | output_ = 0.0 292 | for epoch in range(EPOCHS): 293 | if isinstance(model, SetTransformer): 294 | value, output = train_(epoch) 295 | else: 296 | value, output = train_(epoch, adj_norm) 297 | 298 | if epoch >= EPOCHS - MC_samples: 299 | output_ += output 300 | 301 | output = output_ / np.float32(MC_samples)  # Monte Carlo estimate: average the predictions of the last MC_samples training epochs 302 | 303 | maerr.append(non_standardized['price'].std() * (mae(output[idx_test], labels[idx_test]).detach().cpu().numpy())) 304 | mserr.append( 305 | non_standardized['price'].std() * np.sqrt(mse(output[idx_test], labels[idx_test]).detach().numpy())) 306 | 307 | target = (non_standardized['price'].std() * labels[idx_test]).detach().cpu().numpy() + non_standardized[ 308 | 'price'].mean() 309 | pred = (non_standardized['price'].std() * output[idx_test].T).detach().cpu().numpy() + non_standardized[ 310 | 'price'].mean() 311 | 312 | mperr.append(100 * np.mean(np.abs((target - pred) / target))) 313 | model.apply(weight_reset) 314 | adj_norm = None 315 | return mserr, maerr, mperr 316 | 317 | 318 | mse_b = [] 319 | mae_b = [] 320 | mpe_b = [] 321 | print('B-STGCN') 322 | for i in range(num_trials): 323 | t = time.time() 324 | mserr_b, maerr_b, mperr_b = get_performance_test_bayesian(i) 325 | mse_b.append(mserr_b) 326 | mae_b.append(maerr_b) 327 | mpe_b.append(mperr_b) 328 | print('run : ' + str(i) + ', RMSE : ' + str(np.around(np.array(mserr_b), 2)) + ', MAE : ' + str(np.around(np.array(maerr_b), 2)) + 329 | ', MAPE : ' + str(np.around(np.array(mperr_b), 2)) + ', run time: {:.2f}s'.format(time.time() - t)) 330 | 331 | 332 | mse_b = np.array(mse_b) 333 | mae_b = np.array(mae_b) 334 | mpe_b = np.array(mpe_b) 335 | 336 | columns_b = ['B-STGCN'] 337 | df_mse_b = pd.DataFrame(mse_b, columns=columns_b) 338 | df_mae_b = pd.DataFrame(mae_b, columns=columns_b) 339 | df_mpe_b = pd.DataFrame(mpe_b, columns=columns_b) 340 | 341 | mse_concat = pd.concat([df_mse, df_mse_b], axis=1) 342 | mae_concat = pd.concat([df_mae, df_mae_b], axis=1) 343 | mpe_concat = pd.concat([df_mpe, df_mpe_b], axis=1) 344 | 345 | mse_concat.to_csv('rmse_set_transformer.csv', index=False) 346 | mae_concat.to_csv('mae_set_transformer.csv', index=False) 347 | mpe_concat.to_csv('mape_set_transformer.csv', index=False) 348 | 349 | -------------------------------------------------------------------------------- /mil_rental_data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import matlib 3 | from scipy.stats import t 4 | import scipy.sparse as sp 5 | from math import sqrt 6 | from statistics import stdev 7 | 8 | 9 | def normalize(mx): 10 | """Symmetrically normalize a dense adjacency matrix: D^{-1/2} A D^{-1/2}""" 11 | rowsum = np.array(mx.sum(1)) 12 | r_inv = np.power(rowsum, -1).flatten() 13 | r_inv[np.isinf(r_inv)] = 0.
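# The lines below build D^{-1/2} from these inverse row sums and apply it on both sides, so each entry a_ij is rescaled to a_ij / sqrt(d_i * d_j). For example, A = [[1., 1.], [1., 0.]] has row sums d = [2., 1.] and is mapped to [[0.5, 0.7071], [0.7071, 0.]].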
14 | r_inv_sqrt = np.sqrt(r_inv) 15 | r_mat_inv_sqrt = np.diag(r_inv_sqrt) 16 | mx = r_mat_inv_sqrt.dot(mx) 17 | mx = mx.dot(r_mat_inv_sqrt) 18 | return mx 19 | 20 | 21 | def compute_distance(embed): 22 | N = embed.shape[0] 23 | p = np.dot(embed, np.transpose(embed)) 24 | q = np.matlib.repmat(np.diag(p), N, 1) 25 | dist = q + np.transpose(q) - 2 * p 26 | dist[dist < 1e-8] = 1e-8 27 | return dist 28 | 29 | 30 | def estimate_graph(gamma, epsilon, dist, max_iter, k, r): 31 | np.random.seed(0) 32 | 33 | N = dist.shape[0] 34 | dist += 1e10 * np.eye(N) 35 | 36 | deg_exp = np.minimum(int(N-1), int(k * r)) 37 | 38 | dist_sort_col_idx = np.argsort(dist, axis=0) 39 | dist_sort_col_idx = np.transpose(dist_sort_col_idx[0:deg_exp, :]) 40 | 41 | dist_sort_row_idx = np.matlib.repmat(np.arange(N).reshape(N, 1), 1, deg_exp) 42 | 43 | dist_sort_col_idx = np.reshape(dist_sort_col_idx, int(N * deg_exp)).astype(int) 44 | dist_sort_row_idx = np.reshape(dist_sort_row_idx, int(N * deg_exp)).astype(int) 45 | 46 | dist_idx = np.zeros((int(N * deg_exp), 2)).astype(int) 47 | dist_idx[:, 0] = dist_sort_col_idx 48 | dist_idx[:, 1] = dist_sort_row_idx 49 | dist_idx = np.sort(dist_idx, axis=1) 50 | dist_idx = np.unique(dist_idx, axis=0) 51 | dist_sort_col_idx = dist_idx[:, 0] 52 | dist_sort_row_idx = dist_idx[:, 1] 53 | 54 | num_edges = len(dist_sort_col_idx) 55 | 56 | w_init = np.random.uniform(0, 1, size=(num_edges, 1)) 57 | d_init = k * np.random.uniform(0, 1, size=(N, 1)) 58 | 59 | w_current = w_init 60 | d_current = d_init 61 | 62 | dist_sorted = np.sort(dist, axis=0) 63 | 64 | B_k = np.sum(dist_sorted[0:k, :], axis=0) 65 | dist_sorted_k = dist_sorted[k-1, :] 66 | dist_sorted_k_plus_1 = dist_sorted[k, :] 67 | 68 | theta_lb = 1 / np.sqrt(k * dist_sorted_k_plus_1 ** 2 - B_k * dist_sorted_k_plus_1) 69 | theta_lb = theta_lb[~np.isnan(theta_lb)] 70 | theta_lb = theta_lb[~np.isinf(theta_lb)] 71 | theta_lb = np.mean(theta_lb) 72 | 73 | theta_ub = 1 / np.sqrt(k * dist_sorted_k ** 2 - B_k * dist_sorted_k) 74 | theta_ub = theta_ub[~np.isnan(theta_ub)] 75 | theta_ub = theta_ub[~np.isinf(theta_ub)] 76 | if len(theta_ub) > 0: 77 | theta_ub = np.mean(theta_ub) 78 | else: 79 | theta_ub = theta_lb 80 | 81 | theta = (theta_lb + theta_ub) / 2 82 | 83 | dist = theta * dist 84 | 85 | z = dist[dist_sort_row_idx, dist_sort_col_idx] 86 | z.shape = (num_edges, 1) 87 | 88 | for iter in range(max_iter): 89 | 90 | # print('Graph inference epoch : ' + str(iter)) 91 | 92 | St_times_d = d_current[dist_sort_row_idx] + d_current[dist_sort_col_idx] 93 | y_current = w_current - gamma * (2 * w_current + St_times_d) 94 | 95 | adj_current = np.zeros((N, N)) 96 | adj_current[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(w_current) 97 | adj_current = adj_current + np.transpose(adj_current) 98 | S_times_w = np.sum(adj_current, axis=1) 99 | S_times_w.shape = (N, 1) 100 | y_bar_current = d_current + gamma * S_times_w 101 | 102 | p_current = np.maximum(0, np.abs(y_current) - 2 * gamma * z) 103 | p_bar_current = (y_bar_current - np.sqrt(y_bar_current * y_bar_current + 4 * gamma)) / 2 104 | 105 | St_times_p_bar = p_bar_current[dist_sort_row_idx] + p_bar_current[dist_sort_col_idx] 106 | q_current = p_current - gamma * (2 * p_current + St_times_p_bar) 107 | 108 | p_matrix_current = np.zeros((N, N)) 109 | p_matrix_current[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(p_current) 110 | p_matrix_current = p_matrix_current + np.transpose(p_matrix_current) 111 | S_times_p = np.sum(p_matrix_current, axis=1) 112 | S_times_p.shape = (N, 1) 113 | q_bar_current = 
p_bar_current + gamma * S_times_p 114 | 115 | w_updated = np.abs(w_current - y_current + q_current) 116 | d_updated = np.abs(d_current - y_bar_current + q_bar_current) 117 | 118 | if (np.linalg.norm(w_updated - w_current) / np.linalg.norm(w_current) < epsilon) and \ 119 | (np.linalg.norm(d_updated - d_current) / np.linalg.norm(d_current) < epsilon): 120 | break 121 | else: 122 | w_current = w_updated 123 | d_current = d_updated 124 | 125 | upper_tri_index = np.triu_indices(N, k=1) 126 | 127 | z = dist[upper_tri_index[0], upper_tri_index[1]] 128 | z.shape = (int(N * (N - 1) / 2), 1) 129 | z = z * np.max(w_current) 130 | 131 | w_current = w_current / np.max(w_current) 132 | 133 | inferred_graph = np.zeros((N, N)) 134 | inferred_graph[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(w_current) 135 | inferred_graph = inferred_graph + np.transpose(inferred_graph) + np.eye(N) 136 | 137 | return inferred_graph 138 | 139 | 140 | def MAP_inference(x, num_neib, r): 141 | N = x.shape[0] 142 | k = int(num_neib) 143 | 144 | dist = compute_distance(x) 145 | 146 | inferred_graph = estimate_graph(0.01, 0.001, dist, 1000, k, r) 147 | 148 | return inferred_graph 149 | -------------------------------------------------------------------------------- /mil_text/evaluate.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | 4 | data_all = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 5 | 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 6 | 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'] 7 | for data_index in range(len(data_all)): 8 | 9 | data_name = data_all[data_index] 10 | 11 | pool_ = 'mean' 12 | 13 | k_gcn_list = [2, 2, 2, 3, 3, 4, 3, 4, 2, 2, 3, 3, 3, 3, 4, 3, 2, 2, 4, 2] 14 | k_bgcn_list = [3, 3, 3, 4, 3, 3, 3, 2, 3, 2, 4, 3, 2, 3, 4, 4, 3, 2, 3, 4] 15 | r_list = [10, 5, 10, 5, 5, 10, 1, 1, 10, 10, 1, 10, 1, 10, 5, 10, 10, 5, 10, 5] 16 | 17 | prnt_str_head = '' 18 | prnt_str = data_name.ljust(30) 19 | 20 | alg_all = ['vanilla'] 21 | k_all = [0] 22 | r_all = [0] 23 | 24 | alg_all.append('GCN') 25 | k_all.append(k_gcn_list[data_index]) 26 | r_all.append(0) 27 | 28 | alg_all.append('BGCN') 29 | k_all.append(k_bgcn_list[data_index]) 30 | r_all.append(r_list[data_index]) 31 | 32 | for idx, alg_ in enumerate(alg_all): 33 | 34 | file_name = 'accuracy_' + data_name + '_res_pool_num_neib_' + str(k_all[idx]) + '_' + pool_ + '_r_' + str(r_all[idx]) + '_' + alg_ + '.csv' 35 | 36 | raw_acc = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1) 37 | # print(raw_acc.shape) 38 | if raw_acc.ndim == 1: 39 | raw_acc = np.expand_dims(raw_acc, axis=1) 40 | 41 | num_alg = raw_acc.shape[1] 42 | 43 | for i in range(num_alg): 44 | # print('----------------------------------------------------------------------') 45 | # print('Algorithm : ' + algorithms[i]) 46 | acc = np.squeeze(raw_acc[:, i]) 47 | acc = np.reshape(acc, [10, 10]) 48 | mu_ = np.mean(acc) * 100 49 | sigma_ = np.std(np.mean(acc, axis=0)) * 100 50 | new_str = '&' + "{:.1f}".format(mu_) + '$\pm$' + "{:.1f}".format(sigma_) + ' ' 51 | # new_str = "{:.1f}".format(mu_) + ' ' 52 | prnt_str = prnt_str + new_str 53 | 54 | new_str_ = alg_ + '_k_' + str(k_all[idx]) + '_r_' + str(r_all[idx]) 55 | prnt_str_head = prnt_str_head + new_str_.rjust(16) + ' ' 
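# each pass through this loop appends one '&mean$\pm$std' LaTeX cell to prnt_str (accuracy in %, with the std taken over the 10 column means of the 10x10 accuracy grid loaded above)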
56 | 57 | # print(prnt_str_head) 58 | print(prnt_str) 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /mil_text/evaluate_all.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | 4 | datasets = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 5 | 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 6 | 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'] 7 | for data_index in range(len(datasets)): 8 | data_name = datasets[data_index] 9 | print(data_name.ljust(25) + str(data_index+1) + ' -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------') 10 | 11 | pool_ = 'mean' 12 | 13 | alg_list = ['vanilla', 'GCN', 'BGCN'] 14 | 15 | k_list = [2, 3, 4] 16 | r_list = [1, 5, 10] 17 | 18 | prnt_str_head = '' 19 | prnt_str = '' 20 | 21 | alg_all = ['vanilla'] 22 | k_all = [0] 23 | r_all = [0] 24 | 25 | for k_ in k_list: 26 | alg_all.append('GCN') 27 | k_all.append(k_) 28 | r_all.append(0) 29 | 30 | for k_ in k_list: 31 | for r_ in r_list: 32 | alg_all.append('BGCN') 33 | k_all.append(k_) 34 | r_all.append(r_) 35 | 36 | for idx, alg_ in enumerate(alg_all): 37 | 38 | file_name = 'accuracy_' + data_name + '_res_pool_num_neib_' + str(k_all[idx]) + '_' + pool_ + '_r_' + str(r_all[idx]) + '_' + alg_ + '.csv' 39 | 40 | raw_acc = np.loadtxt(open(file_name, "rb"), delimiter=",", skiprows=1) 41 | # print(raw_acc.shape) 42 | if raw_acc.ndim == 1: 43 | raw_acc = np.expand_dims(raw_acc, axis=1) 44 | 45 | num_alg = raw_acc.shape[1] 46 | 47 | for i in range(num_alg): 48 | # print('----------------------------------------------------------------------') 49 | # print('Algorithm : ' + algorithms[i]) 50 | acc = np.squeeze(raw_acc[:, i]) 51 | acc = np.reshape(acc, [10, 10]) 52 | mu_ = np.mean(acc) * 100 53 | sigma_ = np.std(np.mean(acc, axis=0)) * 100 54 | new_str = '&' + "{:.1f}".format(mu_) + '$\pm$' + "{:.1f}".format(sigma_) 55 | prnt_str = prnt_str + new_str.rjust(16) + ' ' 56 | 57 | new_str_ = alg_ + '_k_' + str(k_all[idx]) + '_r_' + str(r_all[idx]) 58 | prnt_str_head = prnt_str_head + new_str_.rjust(16) + ' ' 59 | 60 | print(prnt_str_head) 61 | print(prnt_str) 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /mil_text/models_all.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import scipy.sparse as sp 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | from modules import * 9 | 10 | 11 | # GCN model 12 | class GraphConvolution(Module): 13 | """ 14 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 15 | """ 16 | 17 | def __init__(self, in_features, out_features, bias=True): 18 | super(GraphConvolution, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 22 | if bias: 23 | self.bias = Parameter(torch.FloatTensor(out_features)) 24 | else: 25 | 
self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1. / math.sqrt(self.in_features) 30 | self.weight.data.uniform_(-stdv, stdv) 31 | if self.bias is not None: 32 | self.bias.data.uniform_(-stdv, stdv) 33 | 34 | def forward(self, input, adj): 35 | support = torch.mm(input, self.weight) 36 | output = torch.mm(adj, support) 37 | if self.bias is not None: 38 | return output + self.bias 39 | else: 40 | return output 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ + ' (' \ 44 | + str(self.in_features) + ' -> ' \ 45 | + str(self.out_features) + ')' 46 | 47 | 48 | # Deep Set model 49 | class rFF_pool(nn.Module): 50 | def __init__(self, in_features=200, pooling_method='max'): 51 | super(rFF_pool, self).__init__() 52 | self.in_features = in_features 53 | self.pooling_method = pooling_method 54 | 55 | self.ll1 = nn.Linear(in_features, 256) 56 | self.ll2 = nn.Linear(256, 128) 57 | self.ll3 = nn.Linear(128, 64) 58 | self.d3 = nn.Dropout(p=0.5) 59 | 60 | self.fc = nn.Linear(64, 1) 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | for module in self.children(): 65 | reset_op = getattr(module, "reset_parameters", None) 66 | if callable(reset_op): 67 | reset_op() 68 | 69 | def forward(self, input): 70 | 71 | x = input 72 | 73 | x = [(F.relu(self.ll1(x_))) for x_ in x] 74 | x = [(F.relu(self.ll2(x_))) for x_ in x] 75 | x = [self.d3(F.relu(self.ll3(x_))) for x_ in x] 76 | 77 | if self.pooling_method == 'max': 78 | x = [torch.unsqueeze(torch.max(x_, axis=0)[0], 0) for x_ in x] 79 | elif self.pooling_method == 'mean': 80 | x = [torch.unsqueeze(x_.mean(dim=0), 0) for x_ in x] 81 | elif self.pooling_method == 'sum': 82 | x = [torch.unsqueeze(x_.sum(dim=0), 0) for x_ in x] 83 | else: 84 | print('Invalid Pooling method!!!!!!') 85 | exit(0) 86 | 87 | x = torch.cat(x, axis=0) 88 | embedding = x.cpu().detach().numpy() 89 | 90 | x = torch.sigmoid(self.fc(x)) 91 | 92 | return x, embedding 93 | 94 | 95 | # Deep Set GCN model 96 | class rFF_pool_GCN(nn.Module): 97 | def __init__(self, in_features=200, pooling_method='max'): 98 | super(rFF_pool_GCN, self).__init__() 99 | self.in_features = in_features 100 | self.pooling_method = pooling_method 101 | 102 | self.ll1 = nn.Linear(in_features, 256) 103 | self.ll2 = nn.Linear(256, 128) 104 | self.ll3 = nn.Linear(128, 64) 105 | self.d3 = nn.Dropout(p=0.5) 106 | 107 | self.gc = GraphConvolution(64, 1) 108 | self.reset_parameters() 109 | 110 | def reset_parameters(self): 111 | for module in self.children(): 112 | reset_op = getattr(module, "reset_parameters", None) 113 | if callable(reset_op): 114 | reset_op() 115 | 116 | def forward(self, input, adj): 117 | 118 | x = input 119 | 120 | x = [(F.relu(self.ll1(x_))) for x_ in x] 121 | x = [(F.relu(self.ll2(x_))) for x_ in x] 122 | x = [self.d3(F.relu(self.ll3(x_))) for x_ in x] 123 | 124 | if self.pooling_method == 'max': 125 | x = [torch.unsqueeze(torch.max(x_, axis=0)[0], 0) for x_ in x] 126 | elif self.pooling_method == 'mean': 127 | x = [torch.unsqueeze(x_.mean(dim=0), 0) for x_ in x] 128 | elif self.pooling_method == 'sum': 129 | x = [torch.unsqueeze(x_.sum(dim=0), 0) for x_ in x] 130 | else: 131 | print('Invalid Pooling method!!!!!!') 132 | exit(0) 133 | 134 | x = torch.cat(x, axis=0) 135 | embedding = x.cpu().detach().numpy() 136 | 137 | x = torch.sigmoid(self.gc(x, adj)) 138 | return x, embedding 139 | 140 | 141 | # Set Transformer model 142 | class SetTransformer(nn.Module): 143 | def __init__(self, in_features=200, 
num_heads=4, ln=False): 144 | super(SetTransformer, self).__init__() 145 | self.enc = nn.Sequential( 146 | SAB(dim_in=in_features, dim_out=64, num_heads=num_heads, ln=ln), 147 | SAB(dim_in=64, dim_out=64, num_heads=num_heads, ln=ln) 148 | ) 149 | self.dec = nn.Sequential( 150 | PMA(dim=64, num_heads=num_heads, num_seeds=1, ln=ln) 151 | ) 152 | self.fc = nn.Linear(in_features=64, out_features=1) 153 | 154 | self.reset_parameters() 155 | 156 | def reset_parameters(self): 157 | for module in self.children(): 158 | reset_op = getattr(module, "reset_parameters", None) 159 | if callable(reset_op): 160 | reset_op() 161 | 162 | def forward(self, x): 163 | x = [self.enc(torch.unsqueeze(x_, 0)) for x_ in x] 164 | x = [self.dec(x_).squeeze() for x_ in x] 165 | x = [torch.unsqueeze(x_, 0) for x_ in x] 166 | x = torch.cat(x, axis=0) 167 | embedding = x.cpu().detach().numpy() 168 | 169 | x = torch.sigmoid(self.fc(x)) 170 | return x, embedding 171 | 172 | 173 | # Set Transformer GCN model 174 | class STGCN(nn.Module): 175 | def __init__(self, in_features=200, num_heads=4, ln=False): 176 | super(STGCN, self).__init__() 177 | self.enc = nn.Sequential( 178 | SAB(dim_in=in_features, dim_out=64, num_heads=num_heads, ln=ln), 179 | SAB(dim_in=64, dim_out=64, num_heads=num_heads, ln=ln) 180 | ) 181 | self.dec = nn.Sequential( 182 | PMA(dim=64, num_heads=num_heads, num_seeds=1, ln=ln) 183 | ) 184 | self.gc = GraphConvolution(64, 1) 185 | 186 | self.reset_parameters() 187 | 188 | def reset_parameters(self): 189 | for module in self.children(): 190 | reset_op = getattr(module, "reset_parameters", None) 191 | if callable(reset_op): 192 | reset_op() 193 | 194 | def forward(self, x, adj): 195 | x = [self.enc(torch.unsqueeze(x_, 0)) for x_ in x] 196 | x = [self.dec(x_).squeeze() for x_ in x] 197 | x = [torch.unsqueeze(x_, 0) for x_ in x] 198 | x = torch.cat(x, axis=0) 199 | embedding = x.cpu().detach().numpy() 200 | 201 | x = torch.sigmoid(self.gc(x, adj)) 202 | return x, embedding 203 | 204 | 205 | # Deep Set model 206 | class res_pool(nn.Module): 207 | def __init__(self, in_features=200, pooling_method='max'): 208 | super(res_pool, self).__init__() 209 | self.in_features = in_features 210 | self.pooling_method = pooling_method 211 | 212 | self.ll1 = nn.Linear(in_features, 128) 213 | self.ll2 = nn.Linear(128, 128) 214 | self.ll3 = nn.Linear(128, 128) 215 | self.d1 = nn.Dropout(p=0.5) 216 | self.d2 = nn.Dropout(p=0.5) 217 | self.d3 = nn.Dropout(p=0.5) 218 | 219 | self.fc = nn.Linear(128, 1) 220 | self.reset_parameters() 221 | 222 | def reset_parameters(self): 223 | for module in self.children(): 224 | reset_op = getattr(module, "reset_parameters", None) 225 | if callable(reset_op): 226 | reset_op() 227 | 228 | def forward(self, input): 229 | 230 | x = input 231 | 232 | x1 = [(F.relu(self.ll1(x_))) for x_ in x] 233 | x2 = [(F.relu(self.ll2(x_))) for x_ in x1] 234 | x3 = [(F.relu(self.ll3(x_))) for x_ in x2] 235 | 236 | if self.pooling_method == 'max': 237 | x1 = [torch.unsqueeze(torch.max(self.d1(x_), axis=0)[0], 0) for x_ in x1] 238 | x2 = [torch.unsqueeze(torch.max(self.d2(x_), axis=0)[0], 0) for x_ in x2] 239 | x3 = [torch.unsqueeze(torch.max(self.d3(x_), axis=0)[0], 0) for x_ in x3] 240 | elif self.pooling_method == 'mean': 241 | x1 = [torch.unsqueeze(self.d1(x_).mean(dim=0), 0) for x_ in x1] 242 | x2 = [torch.unsqueeze(self.d2(x_).mean(dim=0), 0) for x_ in x2] 243 | x3 = [torch.unsqueeze(self.d3(x_).mean(dim=0), 0) for x_ in x3] 244 | elif self.pooling_method == 'sum': 245 | x1 = 
[torch.unsqueeze(self.d1(x_).sum(dim=0), 0) for x_ in x1] 246 | x2 = [torch.unsqueeze(self.d2(x_).sum(dim=0), 0) for x_ in x2] 247 | x3 = [torch.unsqueeze(self.d3(x_).sum(dim=0), 0) for x_ in x3] 248 | else: 249 | print('Invalid Pooling method!!!!!!') 250 | exit(0) 251 | 252 | x1 = torch.cat(x1, axis=0) 253 | x2 = torch.cat(x2, axis=0) 254 | x3 = torch.cat(x3, axis=0) 255 | 256 | x = x1 + x2 + x3 257 | 258 | embedding = x.cpu().detach().numpy() 259 | 260 | x = torch.sigmoid(self.fc(x)) 261 | 262 | return x, embedding 263 | 264 | 265 | # Deep Set model 266 | class res_pool_GCN(nn.Module): 267 | def __init__(self, in_features=200, pooling_method='max'): 268 | super(res_pool_GCN, self).__init__() 269 | self.in_features = in_features 270 | self.pooling_method = pooling_method 271 | 272 | self.ll1 = nn.Linear(in_features, 128) 273 | self.ll2 = nn.Linear(128, 128) 274 | self.ll3 = nn.Linear(128, 128) 275 | self.d1 = nn.Dropout(p=0.5) 276 | self.d2 = nn.Dropout(p=0.5) 277 | self.d3 = nn.Dropout(p=0.5) 278 | 279 | self.gc = GraphConvolution(128, 1) 280 | self.reset_parameters() 281 | 282 | def reset_parameters(self): 283 | for module in self.children(): 284 | reset_op = getattr(module, "reset_parameters", None) 285 | if callable(reset_op): 286 | reset_op() 287 | 288 | def forward(self, input, adj): 289 | 290 | x = input 291 | 292 | x1 = [(F.relu(self.ll1(x_))) for x_ in x] 293 | x2 = [(F.relu(self.ll2(x_))) for x_ in x1] 294 | x3 = [(F.relu(self.ll3(x_))) for x_ in x2] 295 | 296 | if self.pooling_method == 'max': 297 | x1 = [torch.unsqueeze(torch.max(self.d1(x_), axis=0)[0], 0) for x_ in x1] 298 | x2 = [torch.unsqueeze(torch.max(self.d2(x_), axis=0)[0], 0) for x_ in x2] 299 | x3 = [torch.unsqueeze(torch.max(self.d3(x_), axis=0)[0], 0) for x_ in x3] 300 | elif self.pooling_method == 'mean': 301 | x1 = [torch.unsqueeze(self.d1(x_).mean(dim=0), 0) for x_ in x1] 302 | x2 = [torch.unsqueeze(self.d2(x_).mean(dim=0), 0) for x_ in x2] 303 | x3 = [torch.unsqueeze(self.d3(x_).mean(dim=0), 0) for x_ in x3] 304 | elif self.pooling_method == 'sum': 305 | x1 = [torch.unsqueeze(self.d1(x_).sum(dim=0), 0) for x_ in x1] 306 | x2 = [torch.unsqueeze(self.d2(x_).sum(dim=0), 0) for x_ in x2] 307 | x3 = [torch.unsqueeze(self.d3(x_).sum(dim=0), 0) for x_ in x3] 308 | else: 309 | print('Invalid Pooling method!!!!!!') 310 | exit(0) 311 | 312 | x1 = torch.cat(x1, axis=0) 313 | x2 = torch.cat(x2, axis=0) 314 | x3 = torch.cat(x3, axis=0) 315 | 316 | x = x1 + x2 + x3 317 | 318 | embedding = x.cpu().detach().numpy() 319 | 320 | x = torch.sigmoid(self.gc(x, adj)) 321 | 322 | return x, embedding -------------------------------------------------------------------------------- /mil_text/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import scipy.sparse as sp 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | 9 | 10 | # Set transformer layers 11 | class MAB(nn.Module): 12 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 13 | super(MAB, self).__init__() 14 | self.dim_V = dim_V 15 | self.num_heads = num_heads 16 | self.fc_q = nn.Linear(dim_Q, dim_V) 17 | self.fc_k = nn.Linear(dim_K, dim_V) 18 | self.fc_v = nn.Linear(dim_K, dim_V) 19 | if ln: 20 | self.ln0 = nn.LayerNorm(dim_V) 21 | self.ln1 = nn.LayerNorm(dim_V) 22 | self.fc_o = nn.Linear(dim_V, dim_V) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | for 
module in self.children(): 27 | reset_op = getattr(module, "reset_parameters", None) 28 | if callable(reset_op): 29 | reset_op() 30 | 31 | def forward(self, Q, K): 32 | Q = self.fc_q(Q) 33 | K, V = self.fc_k(K), self.fc_v(K) 34 | 35 | dim_split = self.dim_V // self.num_heads 36 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 37 | K_ = torch.cat(K.split(dim_split, 2), 0) 38 | V_ = torch.cat(V.split(dim_split, 2), 0) 39 | 40 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 41 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 42 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 43 | O = O + F.relu(self.fc_o(O)) 44 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 45 | return O 46 | 47 | 48 | class SAB(nn.Module): 49 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 50 | super(SAB, self).__init__() 51 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 52 | 53 | def forward(self, X): 54 | return self.mab(X, X) 55 | 56 | 57 | class ISAB(nn.Module): 58 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 59 | super(ISAB, self).__init__() 60 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 61 | nn.init.xavier_uniform_(self.I) 62 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 63 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 64 | 65 | def forward(self, X): 66 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 67 | return self.mab1(X, H) 68 | 69 | 70 | class PMA(nn.Module): 71 | def __init__(self, dim, num_heads, num_seeds, ln=False): 72 | super(PMA, self).__init__() 73 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 74 | nn.init.xavier_uniform_(self.S) 75 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 76 | 77 | def forward(self, X): 78 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) 79 | 80 | 81 | -------------------------------------------------------------------------------- /mil_text/rank_box.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/networkslab/BagGraph/5ecbc365aeecd6e495a500cea11deb2412f12311/mil_text/rank_box.pdf -------------------------------------------------------------------------------- /mil_text/rank_box_results.txt: -------------------------------------------------------------------------------- 1 | 60.2 65.5 84.8 83.1 84.7 84.4 83.6 88.3 87.6 88.8 2 | 47.0 77.8 59.4 81.7 82.0 81.9 81.5 80.0 78.7 79.8 3 | 51.0 63.1 61.5 70.4 70.7 70.9 70.7 71.7 71.1 70.3 4 | 46.9 59.5 66.5 79.0 78.6 78.3 78.5 73.1 73.0 75.8 5 | 44.5 61.7 66.0 79.4 79.1 79.7 79.2 79.3 78.2 78.7 6 | 50.8 69.8 76.8 79.9 80.9 80.1 81.2 84.9 85.7 86.1 7 | 51.8 55.2 56.5 67.1 66.7 66.0 67.2 75.8 74.0 74.4 8 | 52.9 72.0 66.7 76.5 76.9 76.4 76.1 78.3 78.8 78.5 9 | 50.6 64.0 80.2 83.4 84.2 83.5 83.3 85.0 84.8 85.8 10 | 51.7 64.7 77.9 86.0 86.7 85.7 87.1 80.0 81.4 83.4 11 | 51.3 85.0 82.3 89.0 90.2 91.1 89.8 89.9 89.4 90.0 12 | 56.3 69.6 76.0 79.5 77.9 77.8 78.6 80.1 81.8 81.9 13 | 50.6 87.1 55.5 92.1 93.2 92.7 93.1 90.4 90.7 91.4 14 | 50.6 62.1 78.3 85.5 84.2 84.7 83.8 78.4 78.5 80.2 15 | 54.7 75.7 81.8 79.8 79.5 80.1 80.3 88.1 88.3 88.9 16 | 49.2 59.0 81.4 79.9 80.7 80.1 80.5 78.7 78.1 79.4 17 | 47.7 58.5 74.7 76.1 78.2 77.0 77.3 76.0 73.6 77.7 18 | 55.9 73.6 79.3 83.9 84.0 83.8 83.3 82.1 81.4 81.6 19 | 51.5 70.4 69.7 76.5 75.8 76.8 75.6 76.5 76.8 77.9 20 | 55.4 63.3 73.9 74.4 76.2 76.2 74.3 79.0 78.8 80.0 21 | 22 | 23 | -------------------------------------------------------------------------------- 
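Note on rank_box_results.txt above: each of the 20 rows is one newsgroup dataset, in the order of the `datasets` list in rank_plot_all.py below, and each of the 10 columns is the mean test accuracy (in %) of one algorithm, in the order of the `algorithms` list there (MI-Kernel through B-Res+pool-GCN). rank_plot_all.py converts these accuracies into per-dataset ranks to produce the box plot saved as rank_box.pdf.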
/mil_text/rank_plot_all.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import seaborn as sns 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | datasets = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 8 | 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 9 | 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'] 10 | 11 | algorithms = ['MI-Kernel', 'mi-Graph', 'miFV', 'mi-Net', 'MI-Net', 'MI-Net \nwith DS', 'MI-Net \nwith RC', 12 | 'Res+pool', 'Res+pool\n-GCN', 'B-Res+pool\n-GCN (ours)'] 13 | my_pal = {'MI-Kernel': 'k', 'mi-Graph': 'gray', 'miFV': 'c', 'mi-Net': 'b', 'MI-Net': 'gold', 'MI-Net \nwith DS': 'teal', 'MI-Net \nwith RC': 'brown', 14 | 'Res+pool': 'darkgreen', 'Res+pool\n-GCN': 'm', 'B-Res+pool\n-GCN (ours)': 'r'} 15 | 16 | num_data_set = len(datasets) 17 | num_alg = len(algorithms) 18 | 19 | acc_matrix = np.loadtxt('rank_box_results.txt', delimiter=' ', usecols=range(num_alg)) 20 | print(acc_matrix) 21 | 22 | 23 | rank = num_alg - np.argsort(np.argsort(acc_matrix, axis=1), axis=1) 24 | print(rank) 25 | for data_id_, data in enumerate(datasets): 26 | print('----------------------------------------------------------------') 27 | print(data + ', first: ' + algorithms[int(np.where(rank[data_id_]==1)[0])].strip() + ', second: ' + algorithms[int(np.where(rank[data_id_]==2)[0])].strip()) 28 | 29 | rank = rank.transpose() 30 | # print(rank.shape) 31 | rank_mean = np.mean(rank, axis=1) 32 | print('Average rank') 33 | print(rank_mean) 34 | # rank_std = np.std(rank, axis=1) 35 | 36 | rank_median = np.median(rank, axis=1) 37 | print('Median rank') 38 | print(rank_median) 39 | order = np.argsort(rank_mean) 40 | 41 | rank = rank[order][0: num_alg] 42 | algorithms = [algorithms[idx] for idx in order] 43 | algorithms = [algorithms[idx_new] for idx_new in np.arange(num_alg)] 44 | 45 | print(algorithms) 46 | 47 | rank_df = pd.concat([pd.DataFrame({algorithms[i]: rank[i, :]}) for i in range(num_alg)], axis=1) 48 | 49 | # print(rank_df.head) 50 | 51 | data_df = rank_df.melt(var_name='algorithm', value_name='Rank') 52 | fig, ax = plt.subplots(1, 1, figsize=(12, 9), dpi=75) 53 | # plt.figure(figsize=(6, 9)) 54 | b = sns.boxplot(y="algorithm", x="Rank", data=data_df, showmeans=True, order=algorithms, whis=[0, 100], 55 | meanprops={"markerfacecolor":"black", "markeredgecolor":"black", "markersize":"50"}, palette=my_pal, linewidth=6) 56 | # plt.ylabel("algorithm", size=18) 57 | plt.xticks(ticks=np.arange(1, num_alg + 1, 1)) 58 | plt.xlabel("Rank", size=40) 59 | # plt.plot(rank.mean(axis=1), np.arange(num_alg), '--r*', lw=2) 60 | b.tick_params(labelsize=30) 61 | ax.set_ylabel('') 62 | plt.tight_layout() 63 | plt.show() -------------------------------------------------------------------------------- /mil_text/run_trials.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import multiprocessing as mp 4 | import sys 5 | import time 6 | from res_pool_main import run_res_pool_one_dataset 7 | indices = np.arange(20) 8 | 9 | 10 | def run_all_methods(blank, index): 11 | pool_ = 'mean' 12 | k_gcn_list = [2, 2, 2, 3, 3, 4, 3, 4, 2, 2, 3, 3, 3, 3, 4, 3, 2, 2, 4, 2] 13 | k_bgcn_list = [3, 3, 3, 4, 3, 3, 3, 2, 
3, 2, 4, 3, 2, 3, 4, 4, 3, 2, 3, 4] 14 | r_list = [10, 5, 10, 5, 5, 10, 1, 1, 10, 10, 1, 10, 1, 10, 5, 10, 10, 5, 10, 5] 15 | 16 | run_res_pool_one_dataset(blank, index, num_neib=0, pooling=pool_, r=0, alg_name='vanilla') 17 | run_res_pool_one_dataset(blank, index, num_neib=k_gcn_list[index], pooling=pool_, r=0, alg_name='GCN') 18 | run_res_pool_one_dataset(blank, index, num_neib=k_bgcn_list[index], pooling=pool_, r=r_list[index], alg_name='BGCN') 19 | 20 | 21 | if __name__ == '__main__': 22 | 23 | pool = mp.Pool(processes=7) 24 | pool_results = [pool.apply_async(run_all_methods, (1, index)) for index in indices] 25 | pool.close() 26 | pool.join() 27 | for pr in pool_results: 28 | dict_results = pr.get() 29 | -------------------------------------------------------------------------------- /mil_text/set_transformer_main.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch import optim 5 | import seaborn as sns 6 | from scipy import stats 7 | import matplotlib as mpl 8 | import scipy.sparse as sp 9 | import networkx as nx 10 | import time 11 | import os 12 | import glob 13 | import csv 14 | import torch 15 | import torchvision 16 | from models_all import * 17 | from utils import * 18 | from sklearn import preprocessing 19 | from scipy.io import loadmat 20 | from sklearn.neighbors import kneighbors_graph 21 | mpl.rcParams['figure.dpi'] = 600 22 | color = sns.color_palette() 23 | #%matplotlib inline 24 | pd.options.mode.chained_assignment = None # default='warn' 25 | 26 | 27 | def run_set_transformer_one_dataset(blank, data_index): 28 | datasets = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 29 | 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 30 | 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'] 31 | 32 | data_name = datasets[data_index] 33 | print(data_name) 34 | 35 | torch.manual_seed(0) 36 | np.random.seed(0) 37 | 38 | num_neib = 3 39 | EPOCHS = 200 40 | MC_samples = 20 41 | 42 | num_trial = 10 43 | num_fold = 10 44 | 45 | lr_ = 1e-3 46 | 47 | log_dir = 'log_st/' + data_name 48 | 49 | check = os.path.isdir(log_dir) 50 | if not check: 51 | os.makedirs(log_dir) 52 | print("created folder : ", log_dir) 53 | else: 54 | print(log_dir, " folder already exists.") 55 | 56 | for f in os.listdir(log_dir): 57 | os.remove(os.path.join(log_dir, f)) 58 | 59 | mat = loadmat('data/' + data_name + '_py.mat') # load mat-file 60 | 61 | bag_ids = np.array(mat['bag_ids']).flatten() # bags 62 | bag_features = np.array(mat['bag_features']) # features 63 | labels = np.array(mat['labels']).flatten() # labels 64 | 65 | df_features = pd.DataFrame(bag_features) 66 | df_bag_ids = pd.DataFrame(bag_ids, columns=['bag_ids']) 67 | 68 | df = pd.concat([df_bag_ids, df_features], axis=1) 69 | 70 | print(df_bag_ids.shape) 71 | print(df_features.shape) 72 | print(labels.shape) 73 | print(df.shape) 74 | 75 | # print(df.isna().sum().max()) 76 | # print(df['bag_ids'].value_counts()) 77 | 78 | x = df.iloc[:, :] 79 | print(x.shape) 80 | 81 | groups = x.groupby('bag_ids').mean() 82 | print(groups.shape) 83 | 84 | grouped_data = groups.values[:, :] 85 | y = labels 86 | print(y.shape) 87 | print(grouped_data.shape) 88 | 89 | scaled_features = x.copy() 90 | 
col_names = list(x) 91 | features = scaled_features[col_names[1:]] 92 | scaled_features[col_names[1:]] = features 93 | scaled_features.head() 94 | print(scaled_features.shape) 95 | 96 | groups = scaled_features.groupby('bag_ids') 97 | 98 | # Iterate over each group 99 | set_list = [] 100 | for group_name, df_group in groups: 101 | single_set = [] 102 | for row_index, row in df_group.iterrows(): 103 | single_set.append(row[1:].values) 104 | set_list.append(single_set) 105 | 106 | print(len(set_list)) 107 | 108 | target = set_list # target = Set of sets, row = set, 109 | max_cols = max([len(row) for batch in target for row in batch]) 110 | max_rows = max([len(batch) for batch in target]) 111 | print(max_cols) 112 | print(max_rows) 113 | 114 | for i in range(len(set_list)): 115 | set_list[i] = np.array(set_list[i], dtype=float) 116 | 117 | y_ = y.copy().reshape(-1, 1) 118 | labels_copy = y.copy() 119 | print(labels_copy.shape) 120 | print(y_.shape) 121 | print(y.shape) 122 | 123 | def get_performance_test(trial_i, fold_j): 124 | torch.manual_seed(0) 125 | np.random.seed(0) 126 | 127 | y = y_ 128 | 129 | models = [ 130 | SetTransformer(in_features=200, num_heads=4, ln=False), 131 | STGCN(in_features=200, num_heads=4, ln=False) 132 | ] 133 | 134 | def weights_init(m): 135 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 136 | torch.nn.init.xavier_uniform_(m.weight) 137 | if m.bias is not None: 138 | torch.nn.init.zeros_(m.bias) 139 | 140 | for model in models: 141 | model.apply(weights_init) 142 | 143 | # produce a split for training, validation and testing 144 | mat_fold = loadmat('data/fold/' + str(trial_i) + '/index' + str(fold_j) + '.mat') # load mat-file 145 | 146 | idx_train = np.array(mat_fold['trainIndex']).flatten() - 1 # matlab index from 1, python from 0 147 | idx_test = np.array(mat_fold['testIndex']).flatten() - 1 148 | 149 | bce = nn.BCELoss() 150 | 151 | features = [torch.FloatTensor(set_) for set_ in set_list] 152 | labels = torch.FloatTensor(y) 153 | idx_train = torch.LongTensor(idx_train) 154 | idx_test = torch.LongTensor(idx_test) 155 | 156 | acc = [] 157 | 158 | def train(epoch, adj_=None): 159 | t = time.time() 160 | model.train() 161 | optimizer.zero_grad() 162 | if adj_ is not None: 163 | output, _ = model(features, adj_) 164 | else: 165 | output, _ = model(features) 166 | loss_train = bce(output[idx_train], labels[idx_train]) 167 | loss_train.backward() 168 | optimizer.step() 169 | 170 | acc_train = accuracy(labels[idx_train], output[idx_train]) 171 | acc_test = accuracy(labels[idx_test], output[idx_test]) 172 | # if (epoch+1)%10 == 0: 173 | # print('Epoch: {:04d}'.format(epoch+1), 174 | # 'loss_train: {:.4f}'.format(loss_train.item()), 175 | # 'acc_train: {:.4f}'.format(acc_train.item()), 176 | # 'acc_test: {:.4f}'.format(acc_test.item()), 177 | # 'time: {:.4f}s'.format(time.time() - t)) 178 | 179 | return acc_test.item() 180 | 181 | for network in (models): 182 | 183 | # Model and optimizer 184 | model = network 185 | 186 | adj = None 187 | adj_norm = None 188 | 189 | if isinstance(model, (STGCN)): 190 | embedding = np.loadtxt(log_dir + '/decoding_set_transformer_' + 191 | data_name + '_trial_' + str(trial_i) + '_fold_' + str(fold_j), delimiter=",") 192 | 193 | A = kneighbors_graph(embedding, num_neib, mode='connectivity', include_self=True) 194 | G = nx.from_scipy_sparse_matrix(A) 195 | 196 | adj = nx.to_numpy_array(G) 197 | adj_norm = normalize(adj) 198 | adj_norm = torch.FloatTensor(adj_norm) 199 | 200 | optimizer = 
optim.Adam(model.parameters(), lr=lr_) 201 | 202 | # Train model 203 | t_total = time.time() 204 | for epoch in range(EPOCHS): 205 | if isinstance(model, SetTransformer): 206 | value = train(epoch) 207 | else: 208 | value = train(epoch, adj_norm) 209 | 210 | model.eval() 211 | with torch.no_grad(): 212 | if isinstance(model, SetTransformer): 213 | output, decoding = model(features) 214 | np.savetxt(log_dir + '/decoding_set_transformer_' + data_name + 215 | '_trial_' + str(trial_i) + '_fold_' + str(fold_j), decoding, delimiter=',') 216 | else: 217 | output, _ = model(features, adj_norm) 218 | 219 | acc.append(accuracy(labels[idx_test], output[idx_test])) 220 | model.apply(weights_init) 221 | 222 | adj = None 223 | adj_norm = None 224 | return acc 225 | 226 | tests = [] 227 | print('SetTransformer', 'STGCN') 228 | for i in range(num_trial): 229 | for j in range(num_fold): 230 | t = time.time() 231 | acc_all = get_performance_test(i+1, j+1) 232 | tests.append(acc_all) 233 | print('run : ' + str(num_fold*i+j+1) + ', accuracy : ' + str(np.array(acc_all)) + ', run time: {:.4f}s'.format(time.time() - t)) 234 | 235 | tests = np.array(tests) 236 | df_test = pd.DataFrame(tests, columns=['SetTransformer', 'STGCN']) 237 | 238 | print(df_test.mean(axis=0)) 239 | 240 | def get_performance_test_bayesian(trial_i, fold_j): 241 | torch.manual_seed(0) 242 | np.random.seed(0) 243 | 244 | y = y_ 245 | 246 | models = [ 247 | STGCN(in_features=200, num_heads=4, ln=False) 248 | ] 249 | 250 | def weights_init(m): 251 | if isinstance(m, torch.nn.Linear) or isinstance(m, GraphConvolution): 252 | torch.nn.init.xavier_uniform_(m.weight) 253 | if m.bias is not None: 254 | torch.nn.init.zeros_(m.bias) 255 | 256 | for model in models: 257 | model.apply(weights_init) 258 | 259 | # produce a split for training, validation and testing 260 | mat_fold = loadmat('data/fold/' + str(trial_i) + '/index' + str(fold_j) + '.mat') # load mat-file 261 | 262 | idx_train = np.array(mat_fold['trainIndex']).flatten() - 1 # matlab index from 1, python from 0 263 | idx_test = np.array(mat_fold['testIndex']).flatten() - 1 264 | 265 | bce = nn.BCELoss() 266 | 267 | features = [torch.FloatTensor(set_) for set_ in set_list] 268 | labels = torch.FloatTensor(y) 269 | idx_train = torch.LongTensor(idx_train) 270 | idx_test = torch.LongTensor(idx_test) 271 | 272 | acc = [] 273 | 274 | def train_(epoch, adj_=None): 275 | t = time.time() 276 | model.train() 277 | optimizer.zero_grad() 278 | 279 | output, _ = model(features, adj_) 280 | 281 | loss_train = bce(output[idx_train], labels[idx_train]) 282 | loss_train.backward() 283 | optimizer.step() 284 | 285 | acc_train = accuracy(labels[idx_train], output[idx_train]) 286 | acc_test = accuracy(labels[idx_test], output[idx_test]) 287 | # if (epoch+1)%10 == 0: 288 | # print('Epoch: {:04d}'.format(epoch+1), 289 | # 'loss_train: {:.4f}'.format(loss_train.item()), 290 | # 'acc_train: {:.4f}'.format(acc_train.item()), 291 | # 'acc_test: {:.4f}'.format(acc_test.item()), 292 | # 'time: {:.4f}s'.format(time.time() - t)) 293 | 294 | return acc_test.item(), output 295 | 296 | for network in (models): 297 | 298 | # Model and optimizer 299 | model = network 300 | 301 | adj = None 302 | adj_norm = None 303 | 304 | if isinstance(model, (STGCN)): 305 | embedding = np.loadtxt(log_dir + '/decoding_set_transformer_' + 306 | data_name + '_trial_' + str(trial_i) + '_fold_' + str(fold_j), delimiter=",") 307 | 308 | adj_np = MAP_inference(embedding, num_neib) 309 | 310 | adj_np_norm = normalize(adj_np) 311 | adj_norm = 
adj_np_norm
312 |                 adj_norm = torch.FloatTensor(adj_norm)
313 | 
314 |             optimizer = optim.Adam(model.parameters(), lr=lr_)
315 | 
316 |             # Train model
317 |             output_ = 0.0
318 |             t_total = time.time()
319 |             for epoch in range(EPOCHS):
320 |                 value, output = train_(epoch, adj_norm)
321 | 
322 |                 if epoch >= EPOCHS - MC_samples:  # Monte Carlo average over the last MC_samples epochs
323 |                     output_ += output.detach()  # detach so autograd graphs are not retained across epochs
324 | 
325 |             output = output_ / np.float32(MC_samples)
326 | 
327 |             acc.append(accuracy(labels[idx_test], output[idx_test]))
328 |             model.apply(weights_init)
329 | 
330 |             adj = None
331 |             adj_norm = None
332 |         return acc
333 | 
334 |     tests_bayesian = []
335 |     print('B-STGCN')
336 |     for i in range(num_trial):
337 |         for j in range(num_fold):
338 |             t = time.time()
339 |             acc_bayes = get_performance_test_bayesian(i+1, j+1)
340 |             tests_bayesian.append(acc_bayes)
341 |             print('run : ' + str(num_fold*i+j+1) + ', accuracy : ' + str(np.array(acc_bayes)) + ', run time: {:.4f}s'.format(time.time() - t))
342 | 
343 |     tests_bayesian = np.array(tests_bayesian)
344 |     df_test_bayesian = pd.DataFrame(tests_bayesian, columns=['B-STGCN'])
345 | 
346 |     df_concat = pd.concat([df_test, df_test_bayesian], axis=1)
347 |     print(df_concat.mean(axis=0))
348 |     df_concat.to_csv('accuracy_' + data_name + '_set_transformer.csv', index=False)
349 | 
350 | 
--------------------------------------------------------------------------------
/mil_text/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy import matlib
3 | from scipy.stats import t
4 | import scipy.sparse as sp
5 | from math import sqrt
6 | from statistics import stdev
7 | 
8 | 
9 | def normalize(mx):
10 |     """Symmetrically normalize an adjacency matrix: D^{-1/2} A D^{-1/2}."""
11 |     rowsum = np.array(mx.sum(1))
12 |     r_inv = np.power(rowsum, -1).flatten()
13 |     r_inv[np.isinf(r_inv)] = 0.
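    # r_inv holds 1/deg(i) per node, with 0 where deg(i) = 0. The square root
    # below, applied on both sides of mx, yields the symmetric GCN scaling
    # D^{-1/2} A D^{-1/2} (Kipf & Welling, 2017) rather than row scaling D^{-1} A.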
14 | r_inv_sqrt = np.sqrt(r_inv) 15 | r_mat_inv_sqrt = np.diag(r_inv_sqrt) 16 | mx = r_mat_inv_sqrt.dot(mx) 17 | mx = mx.dot(r_mat_inv_sqrt) 18 | return mx 19 | 20 | 21 | def accuracy(labels, output): 22 | preds = (output > 0.5).type_as(labels) 23 | correct = preds.eq(labels).double() 24 | correct = correct.sum() 25 | return correct / len(labels) 26 | 27 | 28 | def compute_distance(embed): 29 | N = embed.shape[0] 30 | p = np.dot(embed, np.transpose(embed)) 31 | q = np.matlib.repmat(np.diag(p), N, 1) 32 | dist = q + np.transpose(q) - 2 * p 33 | dist[dist < 1e-8] = 1e-8 34 | return dist 35 | 36 | 37 | def estimate_graph(gamma, epsilon, dist, max_iter, k, r): 38 | np.random.seed(0) 39 | 40 | N = dist.shape[0] 41 | dist += 1e10 * np.eye(N) 42 | 43 | deg_exp = np.minimum(int(N-1), int(k * r)) 44 | 45 | dist_sort_col_idx = np.argsort(dist, axis=0) 46 | dist_sort_col_idx = np.transpose(dist_sort_col_idx[0:deg_exp, :]) 47 | 48 | dist_sort_row_idx = np.matlib.repmat(np.arange(N).reshape(N, 1), 1, deg_exp) 49 | 50 | dist_sort_col_idx = np.reshape(dist_sort_col_idx, int(N * deg_exp)).astype(int) 51 | dist_sort_row_idx = np.reshape(dist_sort_row_idx, int(N * deg_exp)).astype(int) 52 | 53 | dist_idx = np.zeros((int(N * deg_exp), 2)).astype(int) 54 | dist_idx[:, 0] = dist_sort_col_idx 55 | dist_idx[:, 1] = dist_sort_row_idx 56 | dist_idx = np.sort(dist_idx, axis=1) 57 | dist_idx = np.unique(dist_idx, axis=0) 58 | dist_sort_col_idx = dist_idx[:, 0] 59 | dist_sort_row_idx = dist_idx[:, 1] 60 | 61 | num_edges = len(dist_sort_col_idx) 62 | 63 | w_init = np.random.uniform(0, 1, size=(num_edges, 1)) 64 | d_init = k * np.random.uniform(0, 1, size=(N, 1)) 65 | 66 | w_current = w_init 67 | d_current = d_init 68 | 69 | dist_sorted = np.sort(dist, axis=0) 70 | 71 | B_k = np.sum(dist_sorted[0:k, :], axis=0) 72 | dist_sorted_k = dist_sorted[k-1, :] 73 | dist_sorted_k_plus_1 = dist_sorted[k, :] 74 | 75 | theta_lb = 1 / np.sqrt(k * dist_sorted_k_plus_1 ** 2 - B_k * dist_sorted_k_plus_1) 76 | theta_lb = theta_lb[~np.isnan(theta_lb)] 77 | theta_lb = theta_lb[~np.isinf(theta_lb)] 78 | theta_lb = np.mean(theta_lb) 79 | 80 | theta_ub = 1 / np.sqrt(k * dist_sorted_k ** 2 - B_k * dist_sorted_k) 81 | theta_ub = theta_ub[~np.isnan(theta_ub)] 82 | theta_ub = theta_ub[~np.isinf(theta_ub)] 83 | theta_ub = np.mean(theta_ub) 84 | 85 | theta = (theta_lb + theta_ub) / 2 86 | 87 | dist = theta * dist 88 | 89 | z = dist[dist_sort_row_idx, dist_sort_col_idx] 90 | z.shape = (num_edges, 1) 91 | 92 | for iter in range(max_iter): 93 | 94 | # print('Graph inference epoch : ' + str(iter)) 95 | 96 | St_times_d = d_current[dist_sort_row_idx] + d_current[dist_sort_col_idx] 97 | y_current = w_current - gamma * (2 * w_current + St_times_d) 98 | 99 | adj_current = np.zeros((N, N)) 100 | adj_current[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(w_current) 101 | adj_current = adj_current + np.transpose(adj_current) 102 | S_times_w = np.sum(adj_current, axis=1) 103 | S_times_w.shape = (N, 1) 104 | y_bar_current = d_current + gamma * S_times_w 105 | 106 | p_current = np.maximum(0, np.abs(y_current) - 2 * gamma * z) 107 | p_bar_current = (y_bar_current - np.sqrt(y_bar_current * y_bar_current + 4 * gamma)) / 2 108 | 109 | St_times_p_bar = p_bar_current[dist_sort_row_idx] + p_bar_current[dist_sort_col_idx] 110 | q_current = p_current - gamma * (2 * p_current + St_times_p_bar) 111 | 112 | p_matrix_current = np.zeros((N, N)) 113 | p_matrix_current[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(p_current) 114 | p_matrix_current = p_matrix_current + 
np.transpose(p_matrix_current)
115 |         S_times_p = np.sum(p_matrix_current, axis=1)
116 |         S_times_p.shape = (N, 1)
117 |         q_bar_current = p_bar_current + gamma * S_times_p
118 | 
119 |         w_updated = np.abs(w_current - y_current + q_current)
120 |         d_updated = np.abs(d_current - y_bar_current + q_bar_current)
121 | 
122 |         if (np.linalg.norm(w_updated - w_current) / np.linalg.norm(w_current) < epsilon) and \
123 |                 (np.linalg.norm(d_updated - d_current) / np.linalg.norm(d_current) < epsilon):
124 |             break
125 |         else:
126 |             w_current = w_updated
127 |             d_current = d_updated
128 | 
129 |     upper_tri_index = np.triu_indices(N, k=1)
130 | 
131 |     z = dist[upper_tri_index[0], upper_tri_index[1]]
132 |     z.shape = (int(N * (N - 1) / 2), 1)
133 |     z = z * np.max(w_current)
134 | 
135 |     w_current = w_current / np.max(w_current)
136 | 
137 |     inferred_graph = np.zeros((N, N))
138 |     inferred_graph[dist_sort_row_idx, dist_sort_col_idx] = np.squeeze(w_current)
139 |     inferred_graph = inferred_graph + np.transpose(inferred_graph) + np.eye(N)
140 | 
141 |     return inferred_graph
142 | 
143 | 
144 | def MAP_inference(x, num_neib, r=1):  # r scales the candidate edge budget (k * r); default 1 lets callers that omit r, e.g. set_transformer_main.py, run
145 |     N = x.shape[0]
146 |     k = int(1 * num_neib)
147 | 
148 |     dist = compute_distance(x)
149 | 
150 |     inferred_graph = estimate_graph(0.01, 0.001, dist, 1000, k, r)
151 | 
152 |     return inferred_graph
153 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.8.1
2 | apipkg==1.5
3 | appdirs==1.4.3
4 | astor==0.7.1
5 | atomicwrites==1.4.0
6 | attrs==20.2.0
7 | certifi==2020.12.5
8 | chardet==4.0.0
9 | colorama==0.4.3
10 | cycler==0.10.0
11 | Cython==0.29.21
12 | debtcollector==1.22.0
13 | decorator==4.3.0
14 | distlib==0.3.1
15 | ecos==2.0.7.post1
16 | execnet==1.7.1
17 | filelock==3.0.12
18 | future==0.18.2
19 | gast==0.2.0
20 | google-pasta==0.1.8
21 | grpcio==1.16.1
22 | h5py==2.8.0
23 | hdf5able==0.3.0
24 | idna==2.10
25 | importlib-metadata==2.0.0
26 | importlib-resources==3.2.1
27 | iniconfig==1.0.1
28 | joblib==0.13.2
29 | Keras==2.3.1
30 | Keras-Applications==1.0.6
31 | Keras-Preprocessing==1.0.5
32 | kiwisolver==1.0.1
33 | Markdown==3.0.1
34 | matplotlib==3.0.2
35 | mock==3.0.5
36 | mxboard==0.1.0
37 | networkx==2.2
38 | nibabel==2.5.0
39 | nilearn==0.5.2
40 | nose==1.3.7
41 | numexpr==2.7.1
42 | numpy==1.18.1
43 | osqp==0.6.1
44 | packaging==20.4
45 | pandas==0.24.2
46 | pathlib2==2.3.5
47 | patsy==0.5.1
48 | pbr==5.5.1
49 | Pillow==7.2.0
50 | pipenv==2020.11.15
51 | pluggy==0.13.1
52 | ply==3.11
53 | protobuf==3.6.1
54 | Psikit==0.2.0
55 | pummeler==0.3.1
56 | py==1.9.0
57 | pyerf==1.0.1
58 | Pyomo==5.6.6
59 | pyparsing==2.3.0
60 | pytest==6.1.1
61 | pytest-forked==1.3.0
62 | pytest-xdist==2.1.0
63 | python-dateutil==2.7.5
64 | python-utils==2.5.6
65 | pytz==2019.1
66 | PyUtilib==5.7.1
67 | PyYAML==5.3
68 | rbfopt==4.1.1
69 | requests==2.25.1
70 | scikit-learn==0.20.1
71 | scikits.bootstrap==1.0.1
72 | scipy==1.1.0
73 | six==1.11.0
74 | sklearn==0.0
75 | statsmodels==0.11.1
76 | tables==3.5.1
77 | tensorboard==1.14.0
78 | tensorflow==1.14.0
79 | tensorflow-estimator==1.14.0
80 | termcolor==1.1.0
81 | toml==0.10.1
82 | torch==1.3.1+cu92
83 | torchdiffeq==0.0.1
84 | torchvision==0.4.2+cu92
85 | tqdm==4.60.0
86 | urllib3==1.26.3
87 | virtualenv==20.4.2
88 | virtualenv-clone==0.5.4
89 | Werkzeug==0.14.1
90 | wrapt==1.11.2
91 | zipp==1.2.0
92 | 
--------------------------------------------------------------------------------
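
A minimal usage sketch (not part of the repository) of how the pieces in `mil_text` fit together: infer a bag graph from embeddings with `MAP_inference`, symmetrically normalize it, and push it through a `GraphConvolution` readout. The embeddings here are synthetic stand-ins for the pooled bag representations the backbones save during training, and the snippet assumes it is run from inside the `mil_text` directory so the local imports resolve:

```python
import numpy as np
import torch

from models_all import GraphConvolution   # mil_text/models_all.py
from utils import MAP_inference, normalize  # mil_text/utils.py

# Synthetic stand-in for the 64-dim bag embeddings produced by the backbones.
rng = np.random.RandomState(0)
embedding = rng.randn(50, 64).astype(np.float32)  # 50 bags

# MAP-infer a bag graph over the embeddings (k nearest neighbours, candidate
# edge budget k * r), then apply the symmetric D^{-1/2} A D^{-1/2} scaling.
adj = MAP_inference(embedding, num_neib=3, r=1)
adj_norm = torch.FloatTensor(normalize(adj))

# One graph convolution maps each bag embedding to a scalar logit; the sigmoid
# mirrors the bag-level readout used by STGCN and res_pool_GCN.
gc = GraphConvolution(in_features=64, out_features=1)
probs = torch.sigmoid(gc(torch.from_numpy(embedding), adj_norm))
print(probs.shape)  # torch.Size([50, 1])
```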