├── __init__.py
├── .gitignore
├── models
│   ├── decoder.py
│   ├── RNNs.py
│   ├── GNNs.py
│   └── models.py
├── README.md
├── reqs
│   └── DYNAMIC_GRAPHS_3.8.10.txt
├── utils
│   └── utils.py
├── make_data.py
└── train.py

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

!*/reqs/*.txt
# ignore text files so there are no data leaks
*.csv
*.edg
*.tsv
logs/
notebooks/
models_final/

execution_scripts/
**/tgn
*/methods/GNN_LSTM/sbatch_maker.ipynb
*/home/emiliano/projects/def-cbravo/emiliano/RappiAnomalyDetection/methods/GNN_LSTM/gru_safe.py
.ipynb_checkpoints/
*/data
*.out
*.sh
*.pt
methods/GNN_LSTM/model_files/*
*.json
.vscode/
.env
*.cpython-38.pyc
*.npy
methods/GNN_LSTM/wandb/*
--------------------------------------------------------------------------------
/models/decoder.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F

class Decoder(nn.Module):
    """Abstract base class for decoders."""
    def __init__(self, hidden_dim=None, target_size=1) -> None:
        super().__init__()

    def forward(self):
        pass

class LinFFN(Decoder):
    """Two-layer feed-forward decoder with a ReLU bottleneck."""
    def __init__(self, hidden_dim, target_size=1):
        super().__init__()
        self.fc = Linear(hidden_dim, int(hidden_dim / 2))
        self.fc2 = Linear(int(hidden_dim / 2), target_size)

    def forward(self, input_):
        h = self.fc(input_)
        h = F.relu(h)
        h = self.fc2(h)
        return h

def get_decoder(decoder):
    if decoder == 'LIN':
        return LinFFN
    raise ValueError('INCORRECT DECODER NAME')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Influencer Detection with Dynamic Graph Neural Networks

This repository contains the code used in the paper: E. Tiukhova, E. Penaloza, M. Óskarsdóttir, H. Garcia, A. Correa Bahnsen, B. Baesens, M. Snoeck, C. Bravo. Influencer Detection with Dynamic Graph Neural Networks. Accepted at the Temporal Graph Learning workshop, NeurIPS, 2022.

Link to the paper: https://arxiv.org/abs/2211.09664

Link to the poster: https://neurips.cc/media/PosterPDFs/NeurIPS%202022/56519.png?t=1668072924.9758837


## Project structure:

The project repo has the following structure:
```
|-models
| |-GNNs.py
| |-RNNs.py
| |-decoder.py
| |-models.py
|-reqs
| |-DYNAMIC_GRAPHS_3.8.10.txt
|-utils
| |-utils.py
|-make_data.py
|-train.py
```
### models

This folder contains the .py files used to combine encoders and decoders into dynamic GNN models, as well as to create the baseline models.

### reqs

This folder contains the file that lists all of the project's dependencies.

### utils

This folder contains a .py file that provides helper functions shared by the other scripts.

### make_data.py

The script that generates the network data and preprocesses it.

### train.py

The script that runs the experiments; a sketch of a typical invocation is given below.
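A typical end-to-end run might look like the following sketch. The flags are taken from the argparse definitions in make_data.py and train.py, but the dataset name, model choices, and hyperparameter values here are illustrative placeholders, not the settings used in the paper:

```
# build the monthly graph snapshots (expects the raw CSV under the hard-coded data path)
python make_data.py --data TGN_paper --GNN GCN --val_start 13 --test_start 15

# train a GCN encoder + GRU combination with the linear decoder
python train.py --data_name TGN_paper --GNN GCN --RNN GRU --DECODER LIN --epochs 5 --lr 0.0001
```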
--------------------------------------------------------------------------------
/reqs/DYNAMIC_GRAPHS_3.8.10.txt:
--------------------------------------------------------------------------------
anyio==3.6.1
argon2-cffi==21.3.0+computecanada
argon2-cffi-bindings==21.2.0+computecanada
asttokens==2.0.5+computecanada
attrs==21.4.0+computecanada
Babel==2.10.3
backcall==0.2.0+computecanada
beautifulsoup4==4.11.1+computecanada
bleach==5.0.1
certifi==2022.6.15+computecanada
cffi==1.15.0+computecanada
charset-normalizer==2.1.0+computecanada
debugpy==1.6.2
decorator==5.1.1+computecanada
defusedxml==0.7.1+computecanada
entrypoints==0.4+computecanada
executing==0.8.3+computecanada
fastjsonschema==2.16.1
future==0.18.2+computecanada
idna==3.3+computecanada
importlib-metadata==4.12.0
importlib-resources==5.8.0+computecanada
ipykernel==6.15.1
ipython==8.4.0
ipython-genutils==0.2.0+computecanada
jedi==0.18.1+computecanada
jinja2==3.1.2+computecanada
joblib==1.1.0+computecanada
json5==0.9.6+computecanada
jsonschema==4.7.2
jupyter-client==7.3.4
jupyter-core==4.11.1
jupyter-server==1.18.1
jupyterlab==3.4.3
jupyterlab-pygments==0.2.2
jupyterlab-server==2.15.0
markupsafe==2.0.1+computecanada
matplotlib-inline==0.1.3+computecanada
mistune==0.8.4+computecanada
nbclassic==0.4.3
nbclient==0.6.6
nbconvert==6.5.0+computecanada
nbformat==5.4.0+computecanada
nest-asyncio==1.5.5+computecanada
notebook==6.4.12+computecanada
notebook-shim==0.1.0+computecanada
numpy==1.23.0+computecanada
packaging==21.3+computecanada
pandas==1.1.0+computecanada
pandocfilters==1.5.0+computecanada
parso==0.8.3+computecanada
pexpect==4.8.0+computecanada
pickleshare==0.7.5+computecanada
prometheus-client==0.14.1+computecanada
prompt-toolkit==3.0.30
psutil==5.9.1+computecanada
ptyprocess==0.7.0+computecanada
pure-eval==0.2.2+computecanada
pycparser==2.21+computecanada
Pygments==2.12.0
pyparsing==3.0.9+computecanada
pyrsistent==0.18.1+computecanada
python-dateutil==2.8.2+computecanada
pytz==2022.1+computecanada
pyzmq==23.1.0+computecanada
requests==2.28.1+computecanada
scikit-learn==0.23.1
scipy==1.8.0+computecanada
send2trash==1.8.0+computecanada
six==1.16.0+computecanada
sniffio==1.2.0+computecanada
soupsieve==2.3.2.post1+computecanada
stack-data==0.3.0
terminado==0.15.0
threadpoolctl==3.1.0+computecanada
tinycss2==1.1.1+computecanada
torch==1.6.0+computecanada
tornado==6.1+computecanada
traitlets==5.3.0
urllib3==1.26.10+computecanada
wcwidth==0.2.5+computecanada
webencodings==0.5.1+computecanada
websocket-client==1.3.3+computecanada
zipp==3.8.1
--------------------------------------------------------------------------------
/models/RNNs.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import GRUCell, LSTMCell

class RNN(nn.Module):
    """Abstract base class for the recurrent wrappers below."""
    def __init__(self) -> None:
        super().__init__()

    def forward(self):
        pass

    def init__hidd(self):
        pass

class LSTMs(RNN):
    def __init__(self, input_dim, hidden_dim, n_layers, n_nodes) -> None:
        super().__init__()
        c_0 = LSTMCell(input_dim, hidden_dim)
        self.cells = nn.ModuleList([LSTMCell(hidden_dim, hidden_dim) for _ in range(n_layers - 1)])
        self.cells.insert(0, c_0)
        self.n_layers = n_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_nodes = n_nodes

    def forward(self, inps, h0_list):
        prev_h, c = self.cells[0](inps, h0_list[0])
        h_list = []
        h_list.append((prev_h, c))
        # the first cell consumed the raw inputs above, so skip it in the loop
        for i, (l, h_c) in enumerate(zip(self.cells, h0_list)):
            if i == 0: continue
            (prev_h, c) = l(prev_h, h_c)
            h_list.append((prev_h, c))

        return h_list

    def init__hidd(self):
        return [(torch.ones(self.n_nodes, self.hidden_dim), torch.ones(self.n_nodes, self.hidden_dim)) for _ in range(self.n_layers)]

class GRUs(RNN):
    def __init__(self, input_dim, hidden_dim, n_layers, n_nodes) -> None:
        super().__init__()
        c_0 = GRUCell(input_dim, hidden_dim)
        self.cells = nn.ModuleList([GRUCell(hidden_dim, hidden_dim) for _ in range(n_layers - 1)])
        self.cells.insert(0, c_0)
        self.n_layers = n_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_nodes = n_nodes

    def forward(self, inps, h0_list):
        prev_h = self.cells[0](inps, h0_list[0])
        h_list = []
        h_list.append(prev_h)
        # as above, the first cell has already been applied
        for i, (l, h_c) in enumerate(zip(self.cells, h0_list)):
            if i == 0: continue
            prev_h = l(prev_h, h_c)
            h_list.append(prev_h)

        return h_list

    def init__hidd(self):
        h0s = [torch.ones(self.n_nodes, self.hidden_dim) for _ in range(self.n_layers)]
        return h0s


class LSTMClassification(torch.nn.Module):
    """Single-cell LSTM baseline."""
    def __init__(self, input_dim, hidden_dim, n_nodes, **kwargs):
        super(LSTMClassification, self).__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.n_nodes = n_nodes
        self.lstm = LSTMCell(input_dim, hidden_dim)

    def forward(self, input_, h0):
        (h, c) = self.lstm(input_, h0[0])
        return [(h, c)]

    def init__hidd(self):
        h0 = torch.randn(self.n_nodes, self.hidden_dim)
        c0 = torch.randn(self.n_nodes, self.hidden_dim)
        return [(h0, c0)]

class GRUClassification(torch.nn.Module):
    """Single-cell GRU baseline."""
    def __init__(self, input_dim, hidden_dim, n_nodes, **kwargs):
        super(GRUClassification, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_nodes = n_nodes
        self.input_dim = input_dim
        self.gru = GRUCell(input_size=input_dim, hidden_size=hidden_dim)

    def forward(self, input_, h0):
        h = self.gru(input_, h0[0])
        return [h]

    def init__hidd(self):
        h0 = torch.randn(self.n_nodes, self.hidden_dim)
        return [h0]

def get_RNN(rnn):
    if rnn == 'LSTM':
        return LSTMs
    if rnn == 'single_gru':
        return GRUClassification
    if rnn == 'single_lstm':
        return LSTMClassification
    # any other non-empty value falls back to the stacked GRU
    return GRUs
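# --- Illustrative usage (added commentary, not part of the original module) ---
# A minimal smoke test of the stacked GRU wrapper; the dimensions are invented:
#
#   rnn = get_RNN('GRU')(input_dim=8, hidden_dim=16, n_layers=2, n_nodes=4)
#   states = rnn.init__hidd()               # one (n_nodes, hidden_dim) state per layer
#   states = rnn(torch.ones(4, 8), states)  # one time step; returns per-layer states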
--------------------------------------------------------------------------------
/models/GNNs.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, GATv2Conv, GINEConv, GraphSAGE

class GNN(nn.Module):
    """Abstract base class for the GNN encoders below."""
    def __init__(self, input_dim=None, embedding_dim=None, output_dim=None, edge_dim=None, heads=8, n_layers=5) -> None:
        super().__init__()

    def forward(self):
        pass

class GATs(GNN):
    """Graph attention network (GATv2) with edge features."""
    def __init__(self, input_dim, embedding_dim, output_dim, edge_dim, heads, n_layers, dropout_rate, **kwargs) -> None:
        super().__init__()
        self.dropout_rate = dropout_rate
        # hidden width is embedding_dim * heads because the head outputs are concatenated
        self.gat1 = GATv2Conv(input_dim, embedding_dim, heads=heads, edge_dim=edge_dim)
        self.GAT_list = torch.nn.ModuleList([GATv2Conv(embedding_dim * heads, embedding_dim, heads=heads, edge_dim=edge_dim) for _ in range(n_layers - 1)])
        self.gat2 = GATv2Conv(embedding_dim * heads, output_dim, heads=1, edge_dim=edge_dim)

    def forward(self, inp, edge_index, edge_feats):
        h = self.gat1(inp, edge_index, edge_attr=edge_feats)
        h = F.elu(h)
        h = F.dropout(h, self.dropout_rate, training=self.training)
        for l in self.GAT_list:
            h = l(h, edge_index, edge_attr=edge_feats)
            h = F.elu(h)
            h = F.dropout(h, self.dropout_rate, training=self.training)
        h = self.gat2(h, edge_index, edge_attr=edge_feats)
        h = F.relu(h)
        h = F.dropout(h, self.dropout_rate, training=self.training)
        return h

class GCNs(GNN):
    """Graph Convolutional Network."""
    def __init__(self, input_dim, embedding_dim, output_dim, n_layers, dropout_rate, **kwargs):
        super().__init__()
        self.gcn1 = GCNConv(input_dim, embedding_dim)
        self.dropout_rate = dropout_rate
        self.GCN_list = torch.nn.ModuleList([GCNConv(embedding_dim, embedding_dim) for _ in range(n_layers - 1)])
        self.lin = Linear(embedding_dim, output_dim)

    def forward(self, x, edge_index, edge_feats):
        h = self.gcn1(x, edge_index, edge_weight=edge_feats)
        h = F.elu(h)
        h = F.dropout(h, self.dropout_rate, training=self.training)
        for l in self.GCN_list:
            h = l(h, edge_index, edge_weight=edge_feats)
            h = F.elu(h)
            h = F.dropout(h, self.dropout_rate, training=self.training)
        h = self.lin(h)
        h = F.relu(h)
        return h


class GINs(GNN):
    """Graph isomorphism network (GINE variant) with edge features."""
    def __init__(self, input_dim, embedding_dim, output_dim, edge_dim, dropout_rate, n_layers, train_eps=False, eps=0, **kwargs) -> None:
        super().__init__()
        self.dropout_rate = dropout_rate
        h_theta_0 = [nn.Sequential(nn.Linear(input_dim, embedding_dim))]
        lin_list = [nn.Sequential(nn.Linear(embedding_dim, embedding_dim)) for _ in range(n_layers - 2)]
        h_theta_n = [nn.Sequential(nn.Linear(embedding_dim, output_dim))]
        self.lin_list = nn.ModuleList(h_theta_0 + lin_list + h_theta_n)
        self.GIN_list = nn.ModuleList()
        for h_theta_i in self.lin_list:
            self.GIN_list.append(GINEConv(h_theta_i, eps, train_eps=train_eps, edge_dim=edge_dim))

    def forward(self, x, edge_index, edge_feats):
        h = x
        for l in self.GIN_list:
            h = l(h, edge_index, edge_feats)
            h = F.elu(h)
            h = F.dropout(h, self.dropout_rate, training=self.training)
        return h

class SAGEs(GNN):
    """GraphSAGE encoder."""
    def __init__(self, input_dim, embedding_dim, output_dim, search_depth, n_layers, dropout_rate, **kwargs) -> None:
        super().__init__()
        self.dropout_rate = dropout_rate
        l1 = GraphSAGE(input_dim, embedding_dim, search_depth, output_dim, dropout=dropout_rate)
        self.SAGE_list = nn.ModuleList([GraphSAGE(output_dim, embedding_dim, search_depth, output_dim, dropout=dropout_rate) for _ in range(n_layers - 1)])
        self.SAGE_list.insert(0, l1)

    def forward(self, x, edge_index, edge_feats):
        h = x
        for l in self.SAGE_list:
            h = l(h, edge_index)
            h = F.elu(h)
            h = F.dropout(h, self.dropout_rate, training=self.training)
        return h

def get_GNN(gnn):
    if gnn == 'GAT':
        return GATs
    elif gnn == 'GIN':
        return GINs
    elif gnn == 'SAGE':
        return SAGEs
    else:
        return GCNs
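# --- Illustrative usage (added commentary, not part of the original module) ---
# get_GNN returns a class that is then instantiated with keyword arguments;
# the dimensions here are invented:
#
#   enc = get_GNN('GCN')(input_dim=25, embedding_dim=64, output_dim=32,
#                        n_layers=2, dropout_rate=0.2)
#   emb = enc(x, edge_index, edge_weights)  # -> (n_nodes, 32) node embeddings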
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
from pprint import pprint
from utils.utils import upsample_embeddings
import torch
import torch.nn as nn
from models.GNNs import get_GNN
from models.RNNs import get_RNN
from models.decoder import get_decoder

def get_hidden(h):
    """Detach a hidden state (a tensor or an (h, c) tuple) from the graph."""
    if isinstance(h, tuple):
        return (h[0].clone().detach(), h[1].clone().detach())
    else:
        return h.clone().detach()

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self):
        pass

class RNN_GNN(Model):
    def __init__(self, upsample, n_nodes, RNN, GNN, DECODER, gnn_input_dim, rnn_input_dim, gnn_embedding_dim, rnn_hidden_dim, gnn_output_dim, rnn_layers, gnn_layers, heads, dropout_rate, edge_dim, train_eps, eps, search_depth, **kwargs) -> None:
        super().__init__()
        rnn_kw = {
            'hidden_dim': rnn_hidden_dim,
            'input_dim': rnn_input_dim,
            'n_layers': rnn_layers,
            'n_nodes': n_nodes
        }
        gnn_kw = {
            'embedding_dim': gnn_embedding_dim,
            'input_dim': gnn_input_dim,
            'n_layers': gnn_layers,
            'heads': heads,
            'dropout_rate': dropout_rate,
            'edge_dim': edge_dim,
            'output_dim': gnn_output_dim,
            'train_eps': train_eps,
            'eps': eps,
            'search_depth': search_depth
        }
        self.upsample = upsample
        self.RNN = get_RNN(RNN)(**rnn_kw)
        self.GNN = get_GNN(GNN)(**gnn_kw)
        self.decoder = get_decoder(DECODER)(rnn_kw['hidden_dim'])

    def forward(self, month_list, data_dict, h0=None, train=True):
        h0s = self.RNN.init__hidd() if h0 is None else h0
        hidden_states = h0s
        for i, m in enumerate(month_list):
            data = data_dict[m]
            labs = data.y
            emb = self.GNN(torch.Tensor(data.x).float(), torch.Tensor(data.edge_index).type(torch.int64), torch.tensor(data.edge_attr).float())
            # pass the embeddings through directly; re-wrapping them in
            # torch.Tensor would detach the GNN from the loss
            hidden_states = self.RNN(emb, hidden_states)
            if i == 0:
                # keep the state after the first month so the caller can
                # warm-start the next sliding window
                h0 = [get_hidden(hidden_states[0])]
        last_h = hidden_states[-1]
        last_h = last_h[0] if type(last_h) == tuple else last_h
        if self.upsample > 0 and train:
            last_h, labs, synth_index = upsample_embeddings(last_h, data.y, data.edge_index, self.upsample)
        else:
            synth_index = []
        scores = self.decoder(last_h)

        return scores, torch.Tensor(labs), h0, synth_index
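# Design note (added commentary): RNN_GNN unrolls over the months in
# month_list -- each snapshot graph is embedded by the GNN, the embeddings are
# pushed through the recurrent cells, and the hidden state is carried across
# months. Only the last month's top-layer state is decoded into scores, and
# the hidden state recorded after the first month of the window is returned so
# the caller can warm-start the next, shifted-by-one window.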
class RNN_only(Model):
    def __init__(self, upsample, n_nodes, RNN, DECODER, rnn_input_dim, rnn_hidden_dim, rnn_layers, **kwargs) -> None:
        super().__init__()
        rnn_kw = {
            'hidden_dim': rnn_hidden_dim,
            'input_dim': rnn_input_dim,
            'n_layers': rnn_layers,
            'n_nodes': n_nodes
        }

        self.upsample = upsample
        self.RNN = get_RNN(RNN)(**rnn_kw)
        self.decoder = get_decoder(DECODER)(rnn_kw['hidden_dim'])

    # `train` is accepted so evaluate() can call every model the same way
    def forward(self, month_list, data_dict, h0, train=True):
        h0s = self.RNN.init__hidd() if h0 is None else h0
        hidden_states = h0s
        for i, m in enumerate(month_list):
            data = data_dict[m]
            labs = data.y
            hidden_states = self.RNN(torch.Tensor(data.x), hidden_states)
            if i == 0:
                h0 = [get_hidden(hidden_states[0])]
        last_h = hidden_states[-1]
        last_h = last_h[0] if type(last_h) == tuple else last_h
        if self.upsample > 0 and train:
            last_h, labs, synth_index = upsample_embeddings(last_h, data.y, data.edge_index, self.upsample)
        else:
            synth_index = []
        scores = self.decoder(last_h)

        return scores, torch.Tensor(labs), h0, synth_index



class GNN_only(Model):
    def __init__(self, GNN, DECODER, gnn_input_dim, gnn_embedding_dim, gnn_output_dim, gnn_layers, heads, dropout_rate, edge_dim, eps, train_eps, search_depth, **kwargs) -> None:
        super().__init__()
        gnn_kw = {
            'embedding_dim': gnn_embedding_dim,
            'input_dim': gnn_input_dim,
            'n_layers': gnn_layers,
            'heads': heads,
            'dropout_rate': dropout_rate,
            'edge_dim': edge_dim,
            'output_dim': gnn_output_dim,
            'train_eps': train_eps,
            'eps': eps,
            'search_depth': search_depth
        }
        self.GNN = get_GNN(GNN)(**gnn_kw)
        self.decoder = get_decoder(DECODER)(gnn_output_dim)

    def forward_call(self, data):
        labs = data.y
        emb = self.GNN(torch.Tensor(data.x).float(), torch.Tensor(data.edge_index).type(torch.int64), torch.tensor(data.edge_attr).float())
        scores = self.decoder(emb)
        h0 = None
        synth_index = []
        return scores, torch.Tensor(labs), h0, synth_index

    # `train` is accepted so evaluate() can call every model the same way
    def forward(self, month, data_dict, h0=None, train=True):
        assert type(month) == int, 'CANNOT USE WINDOWS WITH ONLY GNN'
        return self.forward_call(data_dict[month])

def get_model(gnn_kw, rnn_kw, decoder_kw):
    if gnn_kw['GNN'] and rnn_kw['RNN']:
        kw = dict(rnn_kw, **gnn_kw)
        kw = dict(kw, **decoder_kw)
        pprint(kw)
        return RNN_GNN(**kw)
    elif gnn_kw['GNN'] and not rnn_kw['RNN']:
        kw = dict(gnn_kw, **decoder_kw)
        pprint(kw)
        return GNN_only(**kw)
    elif not gnn_kw['GNN'] and rnn_kw['RNN']:
        kw = dict(rnn_kw, **decoder_kw)
        return RNN_only(**kw)
    else:
        raise ValueError('SPECIFY A MODEL, RNN AND GNN ARE EMPTY')
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import random
import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform
import torch
from sklearn.metrics import auc, roc_curve, precision_recall_curve
from torch.nn import BCEWithLogitsLoss

def upsample_embeddings(embed, labels, edges, sample_rate, end_index=None):
    """SMOTE-style oversampling of the positive class in embedding space."""
    n_nodes = len(set(edges.flatten()))
    avg_number = n_nodes * sample_rate
    chosen = np.argwhere(labels.flatten() == 1).flatten()
    original_idx = embed.shape[0]

    c_portion = int(avg_number / chosen.shape[0])

    for j in range(c_portion):
        chosen_embed = embed[chosen, :]
        distance = squareform(pdist(chosen_embed.cpu().detach()))
        # mask the diagonal so a point is never its own nearest neighbour
        np.fill_diagonal(distance, distance.max() + 100)
        idx_neighbor = distance.argmin(axis=-1)
        interp_place = random.random()

        new_embed = embed[chosen, :] + (chosen_embed[idx_neighbor, :] - embed[chosen, :]) * interp_place
        new_labels = torch.zeros((chosen.shape[0], 1)).reshape(-1).fill_(1)
        embed = torch.cat((embed, new_embed), 0)
        labels = torch.cat((torch.as_tensor(labels), new_labels), 0)

    if end_index is not None:
        embed = embed[:end_index]
    synthetic_index = [x for x in range(original_idx, embed.shape[0], 1)]

    return embed, labels, synthetic_index
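# Worked micro-example (added commentary, values invented): with two positive
# embeddings p0 = [0, 0] and p1 = [1, 1], each pass pairs every positive with
# its nearest other positive and appends p0 + r * (p1 - p0) for one random
# r in [0, 1) -- a point on the segment between them, labelled positive.
# This is SMOTE-style interpolation applied to embeddings rather than raw
# features; the returned synthetic_index marks the appended rows.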
def auprc_wrap(labels, scores):
    pres, recall, _ = precision_recall_curve(labels, scores)
    return recall, pres, _


def get_metrics(scores, labels):
    fpr, tpr, thresholds = roc_curve(labels, scores)
    precision, recall, thresholds = precision_recall_curve(labels, scores)

    return auc(fpr, tpr), auc(recall, precision)

def get_loss(loss):
    if loss == 'bce':
        return BCEWithLogitsLoss
    raise ValueError(f'UNKNOWN LOSS: {loss}')

def get_window(dates):
    # sliding three-step windows, e.g. [1, 2, 3, 4] -> [[1, 2, 3], [2, 3, 4]]
    return [[int(dates[i - 1]), int(dates[i]), int(dates[i + 1])] for i in range(1, len(dates) - 1)]

def evaluate(model, windows, data_dict, loss_function, train_nodes, unseen_nodes_set, h0, boot=False, boot_size=10000):

    losses = []
    auprcs = []
    auprcs_seen = []
    auprcs_unseen = []
    aucs = []
    aucs_seen = []
    aucs_unseen = []

    for month_window in windows:
        m = month_window if type(month_window) == int else month_window[-1]
        nodes_to_eval = list(set(data_dict[m].edge_index.flatten()))
        unseen_nodes = list(set(nodes_to_eval) & unseen_nodes_set)
        seen_nodes = list(set(nodes_to_eval) & train_nodes)
        scores_for_loss, labels, h0, synth_index = model(month_window, data_dict, h0, train=False)
        scores_for_loss = torch.Tensor(scores_for_loss.detach().flatten()).float()
        scores_seen = scores_for_loss[seen_nodes]
        labels_seen = labels[seen_nodes]

        scores_unseen = scores_for_loss[unseen_nodes]
        labels_unseen = labels[unseen_nodes]
        loss = loss_function(scores_for_loss, labels)

        auc, auprc = get_metrics(torch.sigmoid(scores_for_loss), labels)
        seen_auc, seen_auprc = get_metrics(torch.sigmoid(scores_seen), labels_seen)
        unseen_auc, unseen_auprc = get_metrics(torch.sigmoid(scores_unseen), labels_unseen)

        losses.append(loss.item())
        auprcs.append(auprc)
        auprcs_seen.append(seen_auprc)
        auprcs_unseen.append(unseen_auprc)
        aucs.append(auc)
        aucs_seen.append(seen_auc)
        aucs_unseen.append(unseen_auc)

        if boot:
            seen_boot = bootstrap_preds(scores_seen, labels_seen, num_boot=boot_size)
            unseen_boot = bootstrap_preds(scores_unseen, labels_unseen, num_boot=boot_size)
            d_seen_auc = get_stats(seen_boot, metric='seen_auc')
            d_unseen_auc = get_stats(unseen_boot, metric='unseen_auc')

            auprc_seen_boot = bootstrap_preds(scores_seen, labels_seen, auprc_wrap, boot_size)
            auprc_unseen_boot = bootstrap_preds(scores_unseen, labels_unseen, auprc_wrap, boot_size)
            d_seen_auprc = get_stats(auprc_seen_boot, metric='seen_auprc')
            d_unseen_auprc = get_stats(auprc_unseen_boot, metric='unseen_auprc')
            boot_dict = dict(d_seen_auc, **d_unseen_auc)
            boot_dict = dict(boot_dict, **d_seen_auprc)
            boot_dict = dict(boot_dict, **d_unseen_auprc)
        else:
            boot_dict = None

    return np.mean(aucs), np.mean(aucs_seen), np.mean(aucs_unseen), np.mean(auprcs), np.mean(auprcs_seen), np.mean(auprcs_unseen), np.mean(losses), boot_dict, h0



def print_log(names, metrics, e):
    for n, m in zip(names, metrics):
        print(f"{n} = {m} at epoch {e}")

def bootstrap_preds(preds, labs, curve_fun=roc_curve, num_boot=10000):

    n = len(preds)
    boot_means = np.zeros(num_boot)
    data = pd.DataFrame({'preds': preds.flatten(), 'labs': labs.flatten()})

    np.random.seed(0)
    for i in range(num_boot):
        d = data.sample(n, replace=True)

        fpr, tpr, thresholds = curve_fun(d.labs, d.preds)

        boot_means[i] = auc(fpr, tpr)

    return boot_means

def get_stats(arr, metric='auc'):
    boot_dict = {
        f'{metric}_boot_mean': np.mean(arr),
        f'{metric}_boot_std': np.std(arr),
        f'{metric}_boot_min': np.quantile(arr, 0),
        f'{metric}_boot_q1': np.quantile(arr, .25),
        f'{metric}_boot_q2': np.quantile(arr, .5),
        f'{metric}_boot_q3': np.quantile(arr, .75),
        f'{metric}_boot_max': np.quantile(arr, 1),
    }
    return boot_dict

def write_log(metrics, log_file, run_name, model):
    # cluster-specific output location
    p = f'/home/etiukhov/projects/def-cbravo/emiliano/RappiAnomalyDetection/methods/GNN_LSTM'
    log_file = p + '/logs/' + log_file
    metrics.insert(0, run_name)
    df = pd.DataFrame({c: [r] for c, r in zip(range(len(metrics)), metrics)})
    with open(log_file, 'a') as f:
        df.to_csv(f, header=False, index=False)

    torch.save(model.state_dict(), p + '/model_files/' + run_name)
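# Added commentary: bootstrap_preds resamples (score, label) pairs with
# replacement and recomputes the curve-based AUC num_boot times; get_stats then
# summarises that distribution. A hypothetical call:
#
#   boots = bootstrap_preds(scores, labels, num_boot=1000)
#   stats = get_stats(boots, metric='auc')  # mean/std/quartiles of the AUCs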
--------------------------------------------------------------------------------
/make_data.py:
--------------------------------------------------------------------------------
import os
import argparse
from pathlib import Path
import networkx as nx
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data


def preprocess(data_name):
    """Read the raw edge CSV: user id, item id, timestamp, label, then edge features."""
    u_list, i_list, ts_list, label_list = [], [], [], []
    feat_l = []
    idx_list = []

    with open(data_name) as f:
        next(f)  # skip the header row
        for idx, line in enumerate(f):
            e = line.strip().split(',')
            # user id
            u = int(e[0])
            # item id
            i = int(e[1])
            # timestamp
            ts = float(e[2])
            # state label
            label = float(e[3])
            # everything from index 4 onward is an edge feature; features must
            # be numeric, so categorical variables would be one-hot encoded
            feat = np.array([float(x) for x in e[4:]])

            u_list.append(u)
            i_list.append(i)
            ts_list.append(ts)
            label_list.append(label)
            idx_list.append(idx)

            feat_l.append(feat)
    # returns a pandas DataFrame of users, items, timestamps and labels plus a
    # numpy array of edge features
    return pd.DataFrame({'u': u_list,
                         'i': i_list,
                         'ts': ts_list,
                         'label': label_list,
                         'idx': idx_list}), np.array(feat_l)


def reindex(df, bipartite=True):
    new_df = df.copy()
    if bipartite:
        assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
        assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))

        upper_u = df.u.max() + 1
        new_i = df.i + upper_u

        new_df.i = new_i
        new_df.u += 1
        new_df.i += 1
        new_df.idx += 1
    else:
        new_df.u += 1
        new_df.i += 1
        new_df.idx += 1
    return new_df
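# Worked example (added commentary): with bipartite=True and raw ids
# u in {0..9} and i in {0..4}, reindex maps items to i + 10 so the two node
# sets cannot collide, then shifts every id (and the edge index) by +1 to
# reserve row 0 -- run() below prepends an all-zero row to the edge features
# for the same reason.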

def make_labels(df, label_arr, ts):
    mask = df[df.label == 1].u.values
    label_arr[mask, 0, int(ts)] = 1

    return label_arr

def make_page_rank(edges, data):
    edge_df = pd.DataFrame({'u': edges[0, :], 'v': edges[1, :]})

    data_graph = nx.from_pandas_edgelist(edge_df, source='u', target='v', create_using=nx.Graph)

    h = pd.DataFrame(data)
    h['PageRank'] = 0
    total_nodes = data.shape[0]

    for i in nx.connected_components(data_graph):
        subgraph = data_graph.subgraph(list(i))
        if subgraph.number_of_nodes() != 1:
            pagerank_dict = nx.pagerank(subgraph)
            for k in subgraph:
                h.loc[k, 'PageRank'] = pagerank_dict[k] * subgraph.number_of_nodes() / total_nodes
    return h.to_numpy()
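# Added commentary: PageRank is computed per connected component (nx.pagerank
# normalises scores to sum to 1 within the graph it is given), so each score
# is rescaled by |component| / |V| to keep values comparable across components.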

def run(data_name, page_rank, GCN, val_start, test_start, bipartite=False):
    base = '/home/' + os.environ.get('USER') + '/projects/def-cbravo/rappi_data/QTR/'
    Path("data/").mkdir(parents=True, exist_ok=True)
    PATH = base + '/{}.csv'.format(data_name)


    df, feat = preprocess(PATH)
    new_df = reindex(df, bipartite)

    # reserve row 0 of the edge features to match the +1 node reindexing
    empty = np.zeros(feat.shape[1])[np.newaxis, :]
    feat = np.vstack([empty, feat])
    sub_set = pd.read_csv(base + 'TGN_fin.csv')
    numerical = sub_set.drop(columns=['APPLICATION_USER_ID', 'START_DATE']).columns
    keys = sub_set[['START_DATE', 'APPLICATION_USER_ID']]

    # min-max scale with statistics from the training window only, then clip
    # the validation and test windows to [0, 1] to avoid leakage
    train = sub_set[(val_start > sub_set.START_DATE)]
    sub_train = train.drop(columns=['START_DATE', 'APPLICATION_USER_ID'])

    min_ = sub_train.min()
    max_ = sub_train.max()

    train_feats = (sub_train - min_) / (max_ - min_)

    val = sub_set[(val_start <= sub_set.START_DATE) & (test_start > sub_set.START_DATE)]
    sub_val = val.drop(columns=['START_DATE', 'APPLICATION_USER_ID'])

    val_feats = (sub_val - min_) / (max_ - min_)
    val_feats.iloc[:, :] = np.where(val_feats > 1, 1, val_feats)
    val_feats.iloc[:, :] = np.where(val_feats < 0, 0, val_feats)

    test = sub_set[(test_start <= sub_set.START_DATE)]
    sub_test = test.drop(columns=['START_DATE', 'APPLICATION_USER_ID'])

    test_feats = (sub_test - min_) / (max_ - min_)

    test_feats.iloc[:, :] = np.where(test_feats > 1, 1, test_feats)
    test_feats.iloc[:, :] = np.where(test_feats < 0, 0, test_feats)


    sub_set[numerical] = pd.concat([train_feats, val_feats, test_feats])

    sub_set = sub_set.drop(columns=['START_DATE', 'APPLICATION_USER_ID'])
    sub_set = pd.concat([keys, sub_set], axis=1)
    s = sub_set[sub_set.START_DATE == max(sub_set.START_DATE)].shape
    # node feature tensor of shape (n_nodes, n_features, n_timesteps)
    out = np.zeros((sub_set.APPLICATION_USER_ID.max() + 2, s[1] - 2, max(sub_set.START_DATE) + 1))
    for ts in sub_set.START_DATE.unique():
        timed = sub_set[sub_set.START_DATE == ts]
        for user in timed.APPLICATION_USER_ID:
            out[user + 1, :, ts] = timed[user == timed.APPLICATION_USER_ID].drop(columns=['APPLICATION_USER_ID', 'START_DATE'])


    graph_df = new_df
    edge_features = feat
    inds = np.argwhere(np.sum(edge_features, axis=1) == 0).flatten()
    # equals 1 due to the edge added to match the node indexing
    assert len(inds.flatten()) == 1, 'EDGES FROM REFERRALS IN DATASET'

    node_features = out


    n_nodes = node_features.shape[0]
    label_arr = np.zeros((n_nodes, 1, len(graph_df.ts.unique())))
    labels = graph_df.label.values
    datas = {}
    ts_iter = sorted(graph_df.ts.unique())[4:]
    sum_rows = 0

    assert not all([np.isnan(x).any() for x in [node_features, edge_features, graph_df.to_numpy()]])

    for ts in ts_iter:
        sub_graph = graph_df[graph_df.ts == ts]
        sources = sub_graph.u.values
        destinations = sub_graph.i.values
        # make the snapshot undirected by stacking both edge directions
        edges = np.array([np.concatenate((sources, destinations), axis=None), np.concatenate((destinations, sources), axis=None)])
        edges_pagerank = np.array([sources, destinations])
        labels = make_labels(sub_graph, label_arr, ts)
        nodes = list(set(edges.flatten()))
        if page_rank:
            n_feats = make_page_rank(edges_pagerank, node_features[:, :, int(ts)])
        else:
            n_feats = node_features[:, :, int(ts)]

        print(len(labels), 'labels')
        # assert (edges.shape[1] == len(sub_graph.index)), 'MISMATCH IN EDGE FEATS AND EDGE SIZE'
        if GCN == 'GAT':
            datas[ts] = Data(n_feats, edges, y=labels[:, :, int(ts)].flatten(), edge_attr=np.concatenate((edge_features[sub_graph.index, :], edge_features[sub_graph.index, :])), nodes=nodes)
        elif GCN == 'GCN':
            datas[ts] = Data(n_feats, edges, y=labels[:, :, int(ts)].flatten(), edge_attr=np.sum(np.concatenate((edge_features[sub_graph.index, :], edge_features[sub_graph.index, :])), axis=1), nodes=nodes)
        sum_rows += edges.shape[1]
        print(f'THE MONTH {ts} has {len(edges.flatten())} edges \n it has {len(set(edges.flatten()))} nodes with {len(np.argwhere(labels[:, :, int(ts)].flatten() == 1).flatten())} positives')
        print(f'NODE FEATURE MATRIX IS DIMENSIONS {n_feats.shape}')
    # assert sum_rows == len(graph_df[graph_df.ts.isin(ts_iter)]), 'NOT SAME NODES AS IN EDGES'
    if page_rank: GCN = 'PR'
    torch.save(datas, './data/' + f'{data_name}_{GCN}.pt')
    print('SAVING COMPLETE')
    print('SAVED', './data/' + f'{data_name}_{GCN}.pt')

parser = argparse.ArgumentParser('Interface for TGN data preprocessing')
parser.add_argument('--data', type=str, help='Dataset name (eg. wikipedia or reddit)',
                    default='wikipedia')
parser.add_argument('--bipartite', action='store_true', help='Whether the graph is bipartite')
parser.add_argument('--page_rank', action='store_true', help='Whether to add PageRank node features')
parser.add_argument('--GNN', type=str, help='GNN type the data is formatted for (GAT keeps edge feature vectors, GCN sums them into scalar edge weights)', default='GCN')
parser.add_argument('--val_start', type=int, help='Timestep at which the validation split starts', default=13)
parser.add_argument('--test_start', type=int, help='Timestep at which the test split starts', default=15)

args = parser.parse_args()

run(args.data, args.page_rank, args.GNN, args.val_start, args.test_start)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
from copy import deepcopy
from statistics import mean
from pprint import pprint
import numpy as np
import pandas as pd
import torch
from torch.optim import Adam
from tqdm import tqdm
from utils.utils import get_loss, get_metrics, get_window, evaluate, print_log, write_log
import wandb
from models.models import get_model

torch.manual_seed(1)
np.random.seed(0)


def get_var_name(variable):
    # recover the (module-level) variable names of the metrics for logging
    names = []
    for v in variable:
        for name, value in globals().items():
            if value is v:
                names.append(name)
    return names

parser = argparse.ArgumentParser('Training execution')
parser.add_argument('--GNN', type=str, help='GNN encoder: GAT, SAGE, GCN or GIN', default='')
parser.add_argument('--RNN', type=str, help='RNN cell: GRU or LSTM (single_gru/single_lstm for the baselines)', default='')
parser.add_argument('--DECODER', type=str, help='decoder', default='LIN')
parser.add_argument('--gnn_input_dim', type=int, help='input dim for GNN, defaults to number of features', default=25)
parser.add_argument('--gnn_output_dim', type=int, help='output dim of GNN, matches input dim of RNN in combined models', default=200)
parser.add_argument('--embedding_dim', type=int, help='hidden dimension for GNN layers', default=200)
parser.add_argument('--heads', type=int, help='attention heads', default=2)
parser.add_argument('--dropout_rate', default=0.2, help='Dropout rate', type=float)
parser.add_argument('--RNN_hidden_dim', type=int, help='hidden dim for RNNs', default=200)
parser.add_argument('--RNN_layers', type=int, help='layers for RNN', default=1)
parser.add_argument('--GNN_layers', type=int, help='layers for GNN', default=2)
parser.add_argument('--upsample_rate', type=float, help='upsample rate for embedding upsampling', default=0)
parser.add_argument('--page_rank', action='store_true', help='Whether to use the PageRank-augmented data')
parser.add_argument('--epochs', type=int, help='epochs', default=5)
parser.add_argument('--lr', type=float, help='learning rate', default=.0001)
parser.add_argument('--eps', type=float, help='epsilon for GIN', default=0)
parser.add_argument('--boot_sample', type=int, help='sample size for bootstrap', default=10000)
parser.add_argument('--data_name', type=str, help='data name', default='TGN_paper')
parser.add_argument('--loss', type=str, help='loss name', default='bce')
parser.add_argument('--full_windows', action='store_true', help='Use full windows or not')
parser.add_argument('--train_eps', action='store_true', help='Train eps for GIN')
parser.add_argument('--search_depth_SAGE', type=int, help='Depth for SAGE search', default=2)
parser.add_argument('--run_name', type=str, help='Name with the combination of hps being run', default='')
parser.add_argument('--log_file', type=str, help='Log file', default='logs.csv')
args = parser.parse_args()

RNN = args.RNN
GNN = args.GNN
page_rank = args.page_rank
boots = args.boot_sample
# pick the preprocessed data file; GIN and SAGE reuse the GAT-formatted data
# (vector edge features), and the local GNN flag is coerced to 'GAT' so that
# edge_dim is populated for them below (model selection still uses args.GNN)
if page_rank:
    GNN_type = 'PR'  # matches the '{data_name}_PR.pt' file written by make_data.py
elif not GNN:
    GNN_type = 'GCN'
else:
    if GNN == 'GIN' or GNN == 'GAT' or GNN == 'SAGE':
        GNN_type = 'GAT'
        GNN = 'GAT'
    else:
        GNN_type = 'GCN'

data_path = f'./data/{args.data_name}_{GNN_type}.pt'

# wandb.init(project="rappi", entity="emilianouw", settings=wandb.Settings(start_method="fork"))
# wandb.init(project="Influencer_detection_final", entity="elena-tiukhova", name=f'{run}')

data_dict = torch.load(data_path)
ts_list = list(data_dict.keys())
ts_list = [int(x) for x in ts_list]
n_nodes = data_dict[ts_list[0]].x.shape[0]
n_feats = data_dict[ts_list[0]].x.shape[1]
rnn_input_dim = args.gnn_output_dim if GNN else n_feats

rnn_kw = {
    'RNN': args.RNN,
    'rnn_input_dim': rnn_input_dim,
    'rnn_hidden_dim': args.RNN_hidden_dim,
    'rnn_layers': args.RNN_layers,
    'upsample': args.upsample_rate,
    'n_nodes': n_nodes
}
gnn_kw = {
    'GNN': args.GNN,
    'gnn_input_dim': args.gnn_input_dim,
    'gnn_embedding_dim': args.embedding_dim,
    'heads': args.heads,
    'dropout_rate': args.dropout_rate,
    'edge_dim': data_dict[ts_list[0]].edge_attr.shape[1] if GNN == 'GAT' else None,
    'gnn_output_dim': args.gnn_output_dim,
    'gnn_layers': args.GNN_layers,
    'eps': args.eps,
    'train_eps': args.train_eps,
    'search_depth': args.search_depth_SAGE
}
decoder_kw = {
    "DECODER": args.DECODER
}
epochs = args.epochs
# other_kw = {'lr': args.lr, 'page_rank': args.page_rank, 'upsample_rate': args.upsample_rate, 'epochs': args.epochs}
# kw = {**decoder_kw, **rnn_kw, **gnn_kw, **other_kw}
# wandb.config = kw



if args.full_windows:
    # single-month steps (as used by the GNN-only baseline, which asserts
    # integer months); the last three timesteps are held out
    windows = ts_list[:-3]
    val_window = ts_list[-3:-1]
    test_window = ts_list[-1:]
    print(windows, val_window, test_window)
    train_nodes = set(list(data_dict[windows[-1]].edge_index.flatten()))
    val_nodes = set(list(data_dict[val_window[-1]].edge_index.flatten())).difference(train_nodes)
    test_nodes = set(list(data_dict[test_window[-1]].edge_index.flatten())).difference(train_nodes)

else:
    # sliding three-month windows from get_window
    full_windows = get_window(ts_list)
    windows = full_windows[:-4]
    val_window = full_windows[-4:-2]
    test_window = full_windows[-2:]
    print(windows, val_window, test_window)
    train_nodes = set(list(data_dict[windows[-1][-1]].edge_index.flatten()))
    val_nodes = set(list(data_dict[val_window[-1][-1]].edge_index.flatten())).difference(train_nodes)
    test_nodes = set(list(data_dict[test_window[-1][-1]].edge_index.flatten())).difference(train_nodes)


model = get_model(gnn_kw=gnn_kw, rnn_kw=rnn_kw, decoder_kw=decoder_kw)
optimizer = Adam(model.parameters(), lr=args.lr)
loss_function = get_loss(args.loss)()

train_losses = []
auc_list = []
auprc_list = []
val_auc_seen = []
val_auc_unseen = []
val_auprc_seen = []
val_auprc_unseen = []
min_val_loss = np.inf
best_e = 0
for e in tqdm(range(epochs)):
    model.train()
    loss = 0
    running_auprc = []
    running_auc = []
    running_loss = []
    h0 = None
    for month_window in windows:

        optimizer.zero_grad()
        scores, labels, h0, synth_index = model(month_window, data_dict, h0)
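        # Added commentary: only nodes present in this month's edge_index
        # (plus any synthetic upsampled rows appended at the end of `scores`)
        # contribute to the loss; the remaining rows belong to nodes that are
        # absent from the current snapshot.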
        m = month_window if type(month_window) == int else month_window[-1]
        nodes_to_backprop = list(set(data_dict[m].edge_index.flatten())) + synth_index
        labels = labels[nodes_to_backprop].flatten()
        # keep the computation graph intact; wrapping in torch.Tensor here
        # would detach the scores from the model parameters
        scores_for_loss = scores[nodes_to_backprop].float().flatten()
        loss = loss_function(scores_for_loss, labels)

        loss.backward()
        optimizer.step()
        auc, auprc = get_metrics(torch.sigmoid(scores_for_loss.detach()), labels)
        running_auc.append(auc)
        running_auprc.append(auprc)
        running_loss.append(loss.item())

    model.eval()

    val_auc, val_seen_auc, val_unseen_auc, val_auprc, val_seen_auprc, val_unseen_auprc, val_losses, _, h0 \
        = evaluate(model, val_window, data_dict, loss_function, train_nodes, val_nodes, h0)
    train_auc = mean(running_auc)
    train_auprc = mean(running_auprc)
    train_loss = mean(running_loss)
    auc_list.append(train_auc)
    auprc_list.append(train_auprc)
    train_losses.append(train_loss)

    metrics = [val_auc, val_seen_auc, val_unseen_auc, val_auprc, val_seen_auprc, val_unseen_auprc, train_auc, train_auprc, train_loss, val_losses]
    names = get_var_name(metrics)
    print_log(names, metrics, e)

    logs = {n: m for n, m in zip(names, metrics)}
    # wandb.log(logs)

    # keep the weights from the epoch with the lowest validation loss
    if min_val_loss > val_losses:
        min_val_loss = val_losses
        best_model = deepcopy(model.state_dict())
        best_e = e



model.load_state_dict(best_model)



val_final_auc, val_final_seen_auc, val_final_unseen_auc, val_final_auprc, val_final_seen_auprc, val_final_unseen_auprc, val_final_losses, val_boot, val_h0 = evaluate(model, val_window, data_dict, loss_function, train_nodes, val_nodes, h0, boot=True, boot_size=boots)

test_final_auc, test_final_seen_auc, test_final_unseen_auc, test_final_auprc, test_final_seen_auprc, test_final_unseen_auprc, test_final_losses, test_boot, _ = evaluate(model, test_window, data_dict, loss_function, train_nodes, test_nodes, val_h0, boot=True, boot_size=boots)

metrics = [val_final_auc, val_final_seen_auc, val_final_unseen_auc, val_final_auprc, val_final_seen_auprc, val_final_unseen_auprc, val_final_losses, test_final_auc, test_final_seen_auc, test_final_unseen_auc, test_final_auprc, test_final_seen_auprc, test_final_unseen_auprc, test_final_losses]
names = get_var_name(metrics)

print_log(names, metrics, best_e)

print('val boot')
pprint(val_boot)
print('test boot')
pprint(test_boot)
pprint(logs)



write_log(metrics, args.log_file, args.run_name, model)
--------------------------------------------------------------------------------