├── Cross_Entropy.py ├── LICENSE ├── Readme.md ├── auto_syst_dl.py ├── bitcoin_dl.py ├── data └── sbm_50t_1000n_adj.csv.tar.gz ├── docker-set-up ├── Dockerfile └── README.md ├── edge_cls_tasker.py ├── egcn_h.py ├── egcn_o.py ├── elliptic_construction.md ├── elliptic_temporal_dl.py ├── experiments ├── parameters_auto_syst_egcn_h.yaml ├── parameters_auto_syst_egcn_o.yaml ├── parameters_bitcoin_alpha_edgecls_egcn_h.yaml ├── parameters_bitcoin_alpha_edgecls_egcn_o.yaml ├── parameters_bitcoin_alpha_linkpred_egcn_h.yaml ├── parameters_bitcoin_alpha_linkpred_egcn_o.yaml ├── parameters_bitcoin_otc_edgecls_egcn_h.yaml ├── parameters_bitcoin_otc_edgecls_egcn_o.yaml ├── parameters_bitcoin_otc_linkpred_egcn_h.yaml ├── parameters_bitcoin_otc_linkpred_egcn_o.yaml ├── parameters_elliptic_egcn_h.yaml ├── parameters_elliptic_egcn_o.yaml ├── parameters_example.yaml ├── parameters_reddit_egcn_h.yaml ├── parameters_reddit_egcn_o.yaml ├── parameters_sbm_egcn_h.yaml ├── parameters_sbm_egcn_o.yaml ├── parameters_uc_irv_mess_egcn_h.yaml └── parameters_uc_irv_mess_egcn_o.yaml ├── link_pred_tasker.py ├── log_analyzer.py ├── logger.py ├── models.py ├── node_cls_tasker.py ├── reddit_dl.py ├── run_exp.py ├── sbm_dl.py ├── splitter.py ├── taskers_utils.py ├── trainer.py ├── uc_irv_mess_dl.py └── utils.py /Cross_Entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils as u 3 | 4 | class Cross_Entropy(torch.nn.Module): 5 | """docstring for Cross_Entropy""" 6 | def __init__(self, args, dataset): 7 | super().__init__() 8 | weights = torch.tensor(args.class_weights).to(args.device) 9 | 10 | self.weights = self.dyn_scale(args.task, dataset, weights) 11 | 12 | 13 | def dyn_scale(self,task,dataset,weights): 14 | # if task == 'link_pred': commented to have a 1:1 ratio 15 | 16 | # ''' 17 | # when doing link prediction there is an extra weighting factor on the non-existing 18 | # edges 19 | # ''' 20 | # tot_neg = dataset.num_non_existing 21 | # def scale(labels): 22 | # cur_neg = (labels == 0).sum(dtype = torch.float) 23 | # out = weights.clone() 24 | # out[0] *= tot_neg/cur_neg 25 | # return out 26 | # else: 27 | # def scale(labels): 28 | # return weights 29 | def scale(labels): 30 | return weights 31 | return scale 32 | 33 | 34 | def logsumexp(self,logits): 35 | m,_ = torch.max(logits,dim=1) 36 | m = m.view(-1,1) 37 | sum_exp = torch.sum(torch.exp(logits-m),dim=1, keepdim=True) 38 | return m + torch.log(sum_exp) 39 | 40 | def forward(self,logits,labels): 41 | ''' 42 | logits is a matrix M by C where m is the number of classifications and C are the number of classes 43 | labels is a integer tensor of size M where each element corresponds to the class that prediction i 44 | should be matching to 45 | ''' 46 | labels = labels.view(-1,1) 47 | alpha = self.weights(labels)[labels].view(-1,1) 48 | loss = alpha * (- logits.gather(-1,labels) + self.logsumexp(logits)) 49 | return loss.mean() 50 | 51 | if __name__ == '__main__': 52 | dataset = u.Namespace({'num_non_existing': torch.tensor(10)}) 53 | args = u.Namespace({'class_weights': [1.0,1.0], 54 | 'task': 'no_link_pred'}) 55 | labels = torch.tensor([1,0]) 56 | ce_ref = torch.nn.CrossEntropyLoss(reduction='sum') 57 | ce = Cross_Entropy(args,dataset) 58 | # print(ce.weights(labels)) 59 | # print(ce.weights(labels)) 60 | logits = torch.tensor([[1.0,-1.0], 61 | [1.0,-1.0]]) 62 | logits = torch.rand((5,2)) 63 | labels = torch.randint(0,2,(5,)) 64 | print(ce(logits,labels)- ce_ref(logits,labels)) 65 | exit() 66 | 
ce.logsumexp(logits) 67 | # print(labels) 68 | # print(ce.weights(labels)) 69 | # print(ce.weights(labels)[labels]) 70 | x = torch.tensor([0,1]) 71 | y = torch.tensor([1,0]).view(-1,1) 72 | # idx = torch.stack([x,y]) 73 | # print(idx) 74 | # print(idx) 75 | print(logits.gather(-1,y)) 76 | # print(logits.index_select(0,torch.tensor([0,1]))) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | EvolveGCN 2 | ===== 3 | 4 | This repository contains the code for [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/abs/1902.10191), published in AAAI 2020. 5 | 6 | ## Data 7 | 8 | 7 datasets were used in the paper: 9 | 10 | - stochastic block model: See the 'data' folder. Untar the file for use. 
11 | - bitcoin OTC: Downloadable from http://snap.stanford.edu/data/soc-sign-bitcoin-otc.html 12 | - bitcoin Alpha: Downloadable from http://snap.stanford.edu/data/soc-sign-bitcoin-alpha.html 13 | - uc_irvine: Downloadable from http://konect.uni-koblenz.de/networks/opsahl-ucsocial 14 | - autonomous systems: Downloadable from http://snap.stanford.edu/data/as-733.html 15 | - reddit hyperlink network: Downloadable from http://snap.stanford.edu/data/soc-RedditHyperlinks.html 16 | - elliptic: A preprocessed version of https://www.kaggle.com/ellipticco/elliptic-data-set is provided in the following link: ~~https://ibm.box.com/s/j04m8lwoqktjixke2gj7lgllrvvdidme.~~ Untar the file in the 'data' folder for use. 17 | 18 | Update on elliptic: The box link is no longer valid. Please see the [instruction](elliptic_construction.md) to manually prepare the preprocessed version. 19 | 20 | For downloaded data sets please place them in the 'data' folder. 21 | 22 | ## Requirements 23 | * PyTorch 1.0 or higher 24 | * Python 3.6 25 | 26 | ## Set up with Docker 27 | 28 | This docker file describes a container that allows you to run the experiments on any Unix-based machine. GPU availability is recommended to train the models. Otherwise, set the use_cuda flag in parameters.yaml to false. 29 | 30 | ### Requirements 31 | 32 | - [install docker](https://docs.docker.com/install/) 33 | - [install nvidia drivers](https://www.nvidia.com/Download/index.aspx?lang=en-us) 34 | 35 | ### Installation 36 | 37 | #### 1. Build the image 38 | 39 | From this folder you can create the image 40 | 41 | ```sh 42 | sudo docker build -t gcn_env:latest docker-set-up/ 43 | ``` 44 | 45 | #### 2. Start the container 46 | 47 | Start the container 48 | 49 | ```sh 50 | sudo docker run -ti --gpus all -v $(pwd):/evolveGCN gcn_env:latest 51 | ``` 52 | 53 | This will start a bash session in the container. 54 | 55 | ## Usage 56 | 57 | Set --config_file with a yaml configuration file to run the experiments. For example: 58 | 59 | ```sh 60 | python run_exp.py --config_file ./experiments/parameters_example.yaml 61 | ``` 62 | 63 | Most of the parameters in the yaml configuration file are self-explanatory. For hyperparameters tuning, it is possible to set a certain parameter to 'None' and then set a min and max value. Then, each run will pick a random value within the boundaries (for example: 'learning_rate', 'learning_rate_min' and 'learning_rate_max'). 64 | The 'experiments' folder contains one file for each result reported in the [EvolveGCN paper](https://arxiv.org/abs/1902.10191). 65 | 66 | Setting 'use_logfile' to True in the configuration yaml will output a file, in the 'log' directory, containing information about the experiment and validation metrics for the various epochs. The file could be manually analyzed, alternatively 'log_analyzer.py' can be used to automatically parse a log file and to retrieve the evaluation metrics at the best validation epoch. For example: 67 | ```sh 68 | python log_analyzer.py log/filename.log 69 | ``` 70 | 71 | 72 | ## Reference 73 | 74 | [1] Aldo Pareja, Giacomo Domeniconi, Jie Chen, Tengfei Ma, Toyotaro Suzumura, Hiroki Kanezashi, Tim Kaler, Tao B. Schardl, and Charles E. Leiserson. [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/abs/1902.10191). AAAI 2020. 
75 | 76 | ## BibTeX entry 77 | 78 | Please cite the paper if you use this code in your work: 79 | 80 | 81 | ``` 82 | @INPROCEEDINGS{egcn, 83 | AUTHOR = {Aldo Pareja and Giacomo Domeniconi and Jie Chen and Tengfei Ma and Toyotaro Suzumura and Hiroki Kanezashi and Tim Kaler and Tao B. Schardl and Charles E. Leiserson}, 84 | TITLE = {{EvolveGCN}: Evolving Graph Convolutional Networks for Dynamic Graphs}, 85 | BOOKTITLE = {Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence}, 86 | YEAR = {2020}, 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /auto_syst_dl.py: -------------------------------------------------------------------------------- 1 | import utils as u 2 | import os 3 | 4 | import tarfile 5 | 6 | import torch 7 | 8 | from datetime import datetime 9 | 10 | 11 | class Autonomous_Systems_Dataset(): 12 | def __init__(self,args): 13 | args.aut_sys_args = u.Namespace(args.aut_sys_args) 14 | 15 | tar_file = os.path.join(args.aut_sys_args.folder, args.aut_sys_args.tar_file) 16 | tar_archive = tarfile.open(tar_file, 'r:gz') 17 | 18 | self.edges = self.load_edges(args,tar_archive) 19 | 20 | def load_edges(self,args,tar_archive): 21 | files = tar_archive.getnames() 22 | 23 | cont_files2times = self.times_from_names(files) 24 | 25 | edges = [] 26 | cols = u.Namespace({'source': 0, 27 | 'target': 1, 28 | 'time': 2}) 29 | for file in files: 30 | data = u.load_data_from_tar(file, 31 | tar_archive, 32 | starting_line=4, 33 | sep='\t', 34 | type_fn = int, 35 | tensor_const = torch.LongTensor) 36 | 37 | time_col = torch.zeros(data.size(0),1,dtype=torch.long) + cont_files2times[file] 38 | 39 | data = torch.cat([data,time_col],dim = 1) 40 | 41 | data = torch.cat([data,data[:,[cols.target, 42 | cols.source, 43 | cols.time]]]) 44 | 45 | edges.append(data) 46 | 47 | edges = torch.cat(edges) 48 | 49 | 50 | _,edges[:,[cols.source,cols.target]] = edges[:,[cols.source,cols.target]].unique(return_inverse = True) 51 | 52 | 53 | #use only first X time steps 54 | indices = edges[:,cols.time] < args.aut_sys_args.steps_accounted 55 | edges = edges[indices,:] 56 | 57 | #time aggregation 58 | edges[:,cols.time] = u.aggregate_by_time(edges[:,cols.time],args.aut_sys_args.aggr_time) 59 | 60 | self.num_nodes = int(edges[:,[cols.source,cols.target]].max()+1) 61 | 62 | 63 | ids = edges[:,cols.source] * self.num_nodes + edges[:,cols.target] 64 | self.num_non_existing = float(self.num_nodes**2 - ids.unique().size(0)) 65 | 66 | 67 | self.max_time = edges[:,cols.time].max() 68 | self.min_time = edges[:,cols.time].min() 69 | 70 | return {'idx': edges, 'vals': torch.ones(edges.size(0))} 71 | 72 | def times_from_names(self,files): 73 | files2times = {} 74 | times2files = {} 75 | 76 | base = datetime.strptime("19800101", '%Y%m%d') 77 | for file in files: 78 | delta = (datetime.strptime(file[2:-4], '%Y%m%d') - base).days 79 | 80 | files2times[file] = delta 81 | times2files[delta] = file 82 | 83 | 84 | cont_files2times = {} 85 | 86 | sorted_times = sorted(files2times.values()) 87 | new_t = 0 88 | 89 | for t in sorted_times: 90 | 91 | file = times2files[t] 92 | 93 | cont_files2times[file] = new_t 94 | 95 | new_t += 1 96 | return cont_files2times -------------------------------------------------------------------------------- /bitcoin_dl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils as u 3 | import os 4 | 5 | class bitcoin_dataset(): 6 | def __init__(self,args): 7 | assert args.task in 
['link_pred', 'edge_cls'], 'bitcoin only implements link_pred or edge_cls' 8 | self.ecols = u.Namespace({'FromNodeId': 0, 9 | 'ToNodeId': 1, 10 | 'Weight': 2, 11 | 'TimeStep': 3 12 | }) 13 | args.bitcoin_args = u.Namespace(args.bitcoin_args) 14 | 15 | #build edge data structure 16 | edges = self.load_edges(args.bitcoin_args) 17 | 18 | edges = self.make_contigous_node_ids(edges) 19 | num_nodes = edges[:,[self.ecols.FromNodeId, 20 | self.ecols.ToNodeId]].unique().size(0) 21 | 22 | timesteps = u.aggregate_by_time(edges[:,self.ecols.TimeStep],args.bitcoin_args.aggr_time) 23 | self.max_time = timesteps.max() 24 | self.min_time = timesteps.min() 25 | edges[:,self.ecols.TimeStep] = timesteps 26 | 27 | edges[:,self.ecols.Weight] = self.cluster_negs_and_positives(edges[:,self.ecols.Weight]) 28 | 29 | 30 | #add the reversed link to make the graph undirected 31 | edges = torch.cat([edges,edges[:,[self.ecols.ToNodeId, 32 | self.ecols.FromNodeId, 33 | self.ecols.Weight, 34 | self.ecols.TimeStep]]]) 35 | 36 | #separate classes 37 | sp_indices = edges[:,[self.ecols.FromNodeId, 38 | self.ecols.ToNodeId, 39 | self.ecols.TimeStep]].t() 40 | sp_values = edges[:,self.ecols.Weight] 41 | 42 | 43 | neg_mask = sp_values == -1 44 | 45 | neg_sp_indices = sp_indices[:,neg_mask] 46 | neg_sp_values = sp_values[neg_mask] 47 | neg_sp_edges = torch.sparse.LongTensor(neg_sp_indices 48 | ,neg_sp_values, 49 | torch.Size([num_nodes, 50 | num_nodes, 51 | self.max_time+1])).coalesce() 52 | 53 | pos_mask = sp_values == 1 54 | 55 | pos_sp_indices = sp_indices[:,pos_mask] 56 | pos_sp_values = sp_values[pos_mask] 57 | 58 | pos_sp_edges = torch.sparse.LongTensor(pos_sp_indices 59 | ,pos_sp_values, 60 | torch.Size([num_nodes, 61 | num_nodes, 62 | self.max_time+1])).coalesce() 63 | 64 | #scale positive class to separate after adding 65 | pos_sp_edges *= 1000 66 | 67 | #we substract the neg_sp_edges to make the values positive 68 | sp_edges = (pos_sp_edges - neg_sp_edges).coalesce() 69 | 70 | #separating negs and positive edges per edge/timestamp 71 | vals = sp_edges._values() 72 | neg_vals = vals%1000 73 | pos_vals = vals//1000 74 | #We add the negative and positive scores and do majority voting 75 | vals = pos_vals - neg_vals 76 | #creating labels new_vals -> the label of the edges 77 | new_vals = torch.zeros(vals.size(0),dtype=torch.long) 78 | new_vals[vals>0] = 1 79 | new_vals[vals<=0] = 0 80 | indices_labels = torch.cat([sp_edges._indices().t(),new_vals.view(-1,1)],dim=1) 81 | 82 | #the weight of the edges (vals), is simply the number of edges between two entities at each time_step 83 | vals = pos_vals + neg_vals 84 | 85 | 86 | self.edges = {'idx': indices_labels, 'vals': vals} 87 | self.num_nodes = num_nodes 88 | self.num_classes = 2 89 | 90 | 91 | def cluster_negs_and_positives(self,ratings): 92 | pos_indices = ratings > 0 93 | neg_indices = ratings <= 0 94 | ratings[pos_indices] = 1 95 | ratings[neg_indices] = -1 96 | return ratings 97 | 98 | def prepare_node_feats(self,node_feats): 99 | node_feats = node_feats[0] 100 | return node_feats 101 | 102 | def edges_to_sp_dict(self,edges): 103 | idx = edges[:,[self.ecols.FromNodeId, 104 | self.ecols.ToNodeId, 105 | self.ecols.TimeStep]] 106 | 107 | vals = edges[:,self.ecols.Weight] 108 | return {'idx': idx, 109 | 'vals': vals} 110 | 111 | def get_num_nodes(self,edges): 112 | all_ids = edges[:,[self.ecols.FromNodeId,self.ecols.ToNodeId]] 113 | num_nodes = all_ids.max() + 1 114 | return num_nodes 115 | 116 | def load_edges(self,bitcoin_args): 117 | file = 
os.path.join(bitcoin_args.folder,bitcoin_args.edges_file) 118 | with open(file) as f: 119 | lines = f.read().splitlines() 120 | edges = [[float(r) for r in row.split(',')] for row in lines] 121 | edges = torch.tensor(edges,dtype = torch.long) 122 | return edges 123 | 124 | def make_contigous_node_ids(self,edges): 125 | new_edges = edges[:,[self.ecols.FromNodeId,self.ecols.ToNodeId]] 126 | _, new_edges = new_edges.unique(return_inverse=True) 127 | edges[:,[self.ecols.FromNodeId,self.ecols.ToNodeId]] = new_edges 128 | return edges 129 | -------------------------------------------------------------------------------- /data/sbm_50t_1000n_adj.csv.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/EvolveGCN/90869062bbc98d56935e3d92e1d9b1b4c25be593/data/sbm_50t_1000n_adj.csv.tar.gz -------------------------------------------------------------------------------- /docker-set-up/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.1-base-ubuntu16.04 2 | 3 | ENV DEBIAN_FRONTEND noninteractive 4 | 5 | RUN apt-get update; 6 | 7 | RUN apt-get install -y wget; \ 8 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh; \ 9 | bash Miniconda3-latest-Linux-x86_64.sh -b -p $HOME/miniconda;\ 10 | eval "$(/root/miniconda/bin/conda shell.bash hook)";\ 11 | conda init; \ 12 | pip install pyyaml; \ 13 | conda install -y pytorch cudatoolkit=10.0 -c pytorch; \ 14 | conda install -y matplotlib pandas scikit-learn; 15 | 16 | WORKDIR /evolveGCN 17 | ENTRYPOINT /bin/bash 18 | 19 | -------------------------------------------------------------------------------- /docker-set-up/README.md: -------------------------------------------------------------------------------- 1 | # Set Up With Docker 2 | 3 | This docker file describes a container that allows you to run the experiments on any Unix-based machine. GPU availability is recommended to train the models. Otherwise, set the use_cuda flag in parameters.yaml to false. 4 | 5 | # Requirements 6 | 7 | - [install docker](https://docs.docker.com/install/) 8 | - [install nvidia drivers](https://www.nvidia.com/Download/index.aspx?lang=en-us) 9 | 10 | # Installation 11 | 12 | ## 1. Build the image 13 | 14 | From this folder you can create the image 15 | 16 | ```sh 17 | sudo docker build -t gcn_env:latest docker-set-up/ 18 | ``` 19 | 20 | ## 2. Start the container 21 | 22 | Start the container 23 | 24 | ```sh 25 | sudo docker run -ti --gpus all -v $(pwd):/evolveGCN gcn_env:latest 26 | ``` 27 | 28 | This will start a bash session in the container. 29 | 30 | ## 3. 
Run an experiment 31 | 32 | Run the following command for example: 33 | 34 | ```sh 35 | python run_exp.py --config_file ./experiments/parameters_uc_irv_mess.yaml 36 | ``` 37 | 38 | -------------------------------------------------------------------------------- /edge_cls_tasker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import taskers_utils as tu 3 | import utils as u 4 | 5 | 6 | class Edge_Cls_Tasker(): 7 | def __init__(self,args,dataset): 8 | self.data = dataset 9 | #max_time for link pred should be one before 10 | self.max_time = dataset.max_time 11 | self.args = args 12 | self.num_classes = dataset.num_classes 13 | 14 | if not args.use_1_hot_node_feats: 15 | self.feats_per_node = dataset.feats_per_node 16 | 17 | self.get_node_feats = self.build_get_node_feats(args,dataset) 18 | self.prepare_node_feats = self.build_prepare_node_feats(args,dataset) 19 | 20 | self.is_static = False 21 | 22 | def build_prepare_node_feats(self,args,dataset): 23 | if args.use_2_hot_node_feats or args.use_1_hot_node_feats: 24 | def prepare_node_feats(node_feats): 25 | return u.sparse_prepare_tensor(node_feats, 26 | torch_size= [dataset.num_nodes, 27 | self.feats_per_node]) 28 | else: 29 | prepare_node_feats = self.data.prepare_node_feats 30 | 31 | return prepare_node_feats 32 | 33 | 34 | def build_get_node_feats(self,args,dataset): 35 | if args.use_2_hot_node_feats: 36 | max_deg_out, max_deg_in = tu.get_max_degs(args,dataset) 37 | self.feats_per_node = max_deg_out + max_deg_in 38 | def get_node_feats(adj): 39 | return tu.get_2_hot_deg_feats(adj, 40 | max_deg_out, 41 | max_deg_in, 42 | dataset.num_nodes) 43 | elif args.use_1_hot_node_feats: 44 | max_deg,_ = tu.get_max_degs(args,dataset) 45 | self.feats_per_node = max_deg 46 | def get_node_feats(adj): 47 | return tu.get_1_hot_deg_feats(adj, 48 | max_deg, 49 | dataset.num_nodes) 50 | else: 51 | def get_node_feats(adj): 52 | return dataset.nodes_feats 53 | 54 | return get_node_feats 55 | 56 | 57 | def get_sample(self,idx,test): 58 | hist_adj_list = [] 59 | hist_ndFeats_list = [] 60 | hist_mask_list = [] 61 | 62 | for i in range(idx - self.args.num_hist_steps, idx+1): 63 | cur_adj = tu.get_sp_adj(edges = self.data.edges, 64 | time = i, 65 | weighted = True, 66 | time_window = self.args.adj_mat_time_window) 67 | node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes) 68 | node_feats = self.get_node_feats(cur_adj) 69 | cur_adj = tu.normalize_adj(adj = cur_adj, num_nodes = self.data.num_nodes) 70 | 71 | hist_adj_list.append(cur_adj) 72 | hist_ndFeats_list.append(node_feats) 73 | hist_mask_list.append(node_mask) 74 | 75 | label_adj = tu.get_edge_labels(edges = self.data.edges, 76 | time = idx) 77 | 78 | 79 | return {'idx': idx, 80 | 'hist_adj_list': hist_adj_list, 81 | 'hist_ndFeats_list': hist_ndFeats_list, 82 | 'label_sp': label_adj, 83 | 'node_mask_list': hist_mask_list} 84 | 85 | -------------------------------------------------------------------------------- /egcn_h.py: -------------------------------------------------------------------------------- 1 | import utils as u 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | import torch.nn as nn 5 | import math 6 | 7 | 8 | class EGCN(torch.nn.Module): 9 | def __init__(self, args, activation, device='cpu', skipfeats=False): 10 | super().__init__() 11 | GRCU_args = u.Namespace({}) 12 | 13 | feats = [args.feats_per_node, 14 | args.layer_1_feats, 15 | args.layer_2_feats] 16 | self.device = device 17 | self.skipfeats = skipfeats 18 | 
self.GRCU_layers = [] 19 | self._parameters = nn.ParameterList() 20 | for i in range(1,len(feats)): 21 | GRCU_args = u.Namespace({'in_feats' : feats[i-1], 22 | 'out_feats': feats[i], 23 | 'activation': activation}) 24 | 25 | grcu_i = GRCU(GRCU_args) 26 | #print (i,'grcu_i', grcu_i) 27 | self.GRCU_layers.append(grcu_i.to(self.device)) 28 | self._parameters.extend(list(self.GRCU_layers[-1].parameters())) 29 | 30 | def parameters(self): 31 | return self._parameters 32 | 33 | def forward(self,A_list, Nodes_list,nodes_mask_list): 34 | node_feats= Nodes_list[-1] 35 | 36 | for unit in self.GRCU_layers: 37 | Nodes_list = unit(A_list,Nodes_list,nodes_mask_list) 38 | 39 | out = Nodes_list[-1] 40 | if self.skipfeats: 41 | out = torch.cat((out,node_feats), dim=1) # use node_feats.to_dense() if 2hot encoded input 42 | return out 43 | 44 | 45 | class GRCU(torch.nn.Module): 46 | def __init__(self,args): 47 | super().__init__() 48 | self.args = args 49 | cell_args = u.Namespace({}) 50 | cell_args.rows = args.in_feats 51 | cell_args.cols = args.out_feats 52 | 53 | self.evolve_weights = mat_GRU_cell(cell_args) 54 | 55 | self.activation = self.args.activation 56 | self.GCN_init_weights = Parameter(torch.Tensor(self.args.in_feats,self.args.out_feats)) 57 | self.reset_param(self.GCN_init_weights) 58 | 59 | def reset_param(self,t): 60 | #Initialize based on the number of columns 61 | stdv = 1. / math.sqrt(t.size(1)) 62 | t.data.uniform_(-stdv,stdv) 63 | 64 | def forward(self,A_list,node_embs_list,mask_list): 65 | GCN_weights = self.GCN_init_weights 66 | out_seq = [] 67 | for t,Ahat in enumerate(A_list): 68 | node_embs = node_embs_list[t] 69 | #first evolve the weights from the initial and use the new weights with the node_embs 70 | GCN_weights = self.evolve_weights(GCN_weights,node_embs,mask_list[t]) 71 | node_embs = self.activation(Ahat.matmul(node_embs.matmul(GCN_weights))) 72 | 73 | out_seq.append(node_embs) 74 | 75 | return out_seq 76 | 77 | class mat_GRU_cell(torch.nn.Module): 78 | def __init__(self,args): 79 | super().__init__() 80 | self.args = args 81 | self.update = mat_GRU_gate(args.rows, 82 | args.cols, 83 | torch.nn.Sigmoid()) 84 | 85 | self.reset = mat_GRU_gate(args.rows, 86 | args.cols, 87 | torch.nn.Sigmoid()) 88 | 89 | self.htilda = mat_GRU_gate(args.rows, 90 | args.cols, 91 | torch.nn.Tanh()) 92 | 93 | self.choose_topk = TopK(feats = args.rows, 94 | k = args.cols) 95 | 96 | def forward(self,prev_Q,prev_Z,mask): 97 | z_topk = self.choose_topk(prev_Z,mask) 98 | 99 | update = self.update(z_topk,prev_Q) 100 | reset = self.reset(z_topk,prev_Q) 101 | 102 | h_cap = reset * prev_Q 103 | h_cap = self.htilda(z_topk, h_cap) 104 | 105 | new_Q = (1 - update) * prev_Q + update * h_cap 106 | 107 | return new_Q 108 | 109 | 110 | 111 | class mat_GRU_gate(torch.nn.Module): 112 | def __init__(self,rows,cols,activation): 113 | super().__init__() 114 | self.activation = activation 115 | #the k here should be in_feats which is actually the rows 116 | self.W = Parameter(torch.Tensor(rows,rows)) 117 | self.reset_param(self.W) 118 | 119 | self.U = Parameter(torch.Tensor(rows,rows)) 120 | self.reset_param(self.U) 121 | 122 | self.bias = Parameter(torch.zeros(rows,cols)) 123 | 124 | def reset_param(self,t): 125 | #Initialize based on the number of columns 126 | stdv = 1. 
/ math.sqrt(t.size(1)) 127 | t.data.uniform_(-stdv,stdv) 128 | 129 | def forward(self,x,hidden): 130 | out = self.activation(self.W.matmul(x) + \ 131 | self.U.matmul(hidden) + \ 132 | self.bias) 133 | 134 | return out 135 | 136 | class TopK(torch.nn.Module): 137 | def __init__(self,feats,k): 138 | super().__init__() 139 | self.scorer = Parameter(torch.Tensor(feats,1)) 140 | self.reset_param(self.scorer) 141 | 142 | self.k = k 143 | 144 | def reset_param(self,t): 145 | #Initialize based on the number of rows 146 | stdv = 1. / math.sqrt(t.size(0)) 147 | t.data.uniform_(-stdv,stdv) 148 | 149 | def forward(self,node_embs,mask): 150 | scores = node_embs.matmul(self.scorer) / self.scorer.norm() 151 | scores = scores + mask 152 | 153 | vals, topk_indices = scores.view(-1).topk(self.k) 154 | topk_indices = topk_indices[vals > -float("Inf")] 155 | 156 | if topk_indices.size(0) < self.k: 157 | topk_indices = u.pad_with_last_val(topk_indices,self.k) 158 | 159 | tanh = torch.nn.Tanh() 160 | 161 | if isinstance(node_embs, torch.sparse.FloatTensor) or \ 162 | isinstance(node_embs, torch.cuda.sparse.FloatTensor): 163 | node_embs = node_embs.to_dense() 164 | 165 | out = node_embs[topk_indices] * tanh(scores[topk_indices].view(-1,1)) 166 | 167 | #we need to transpose the output 168 | return out.t() 169 | -------------------------------------------------------------------------------- /egcn_o.py: -------------------------------------------------------------------------------- 1 | import utils as u 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | import torch.nn as nn 5 | import math 6 | 7 | 8 | class EGCN(torch.nn.Module): 9 | def __init__(self, args, activation, device='cpu', skipfeats=False): 10 | super().__init__() 11 | GRCU_args = u.Namespace({}) 12 | 13 | feats = [args.feats_per_node, 14 | args.layer_1_feats, 15 | args.layer_2_feats] 16 | self.device = device 17 | self.skipfeats = skipfeats 18 | self.GRCU_layers = [] 19 | self._parameters = nn.ParameterList() 20 | for i in range(1,len(feats)): 21 | GRCU_args = u.Namespace({'in_feats' : feats[i-1], 22 | 'out_feats': feats[i], 23 | 'activation': activation}) 24 | 25 | grcu_i = GRCU(GRCU_args) 26 | #print (i,'grcu_i', grcu_i) 27 | self.GRCU_layers.append(grcu_i.to(self.device)) 28 | self._parameters.extend(list(self.GRCU_layers[-1].parameters())) 29 | 30 | def parameters(self): 31 | return self._parameters 32 | 33 | def forward(self,A_list, Nodes_list,nodes_mask_list): 34 | node_feats= Nodes_list[-1] 35 | 36 | for unit in self.GRCU_layers: 37 | Nodes_list = unit(A_list,Nodes_list)#,nodes_mask_list) 38 | 39 | out = Nodes_list[-1] 40 | if self.skipfeats: 41 | out = torch.cat((out,node_feats), dim=1) # use node_feats.to_dense() if 2hot encoded input 42 | return out 43 | 44 | 45 | class GRCU(torch.nn.Module): 46 | def __init__(self,args): 47 | super().__init__() 48 | self.args = args 49 | cell_args = u.Namespace({}) 50 | cell_args.rows = args.in_feats 51 | cell_args.cols = args.out_feats 52 | 53 | self.evolve_weights = mat_GRU_cell(cell_args) 54 | 55 | self.activation = self.args.activation 56 | self.GCN_init_weights = Parameter(torch.Tensor(self.args.in_feats,self.args.out_feats)) 57 | self.reset_param(self.GCN_init_weights) 58 | 59 | def reset_param(self,t): 60 | #Initialize based on the number of columns 61 | stdv = 1. 
/ math.sqrt(t.size(1)) 62 | t.data.uniform_(-stdv,stdv) 63 | 64 | def forward(self,A_list,node_embs_list):#,mask_list): 65 | GCN_weights = self.GCN_init_weights 66 | out_seq = [] 67 | for t,Ahat in enumerate(A_list): 68 | node_embs = node_embs_list[t] 69 | #first evolve the weights from the initial and use the new weights with the node_embs 70 | GCN_weights = self.evolve_weights(GCN_weights)#,node_embs,mask_list[t]) 71 | node_embs = self.activation(Ahat.matmul(node_embs.matmul(GCN_weights))) 72 | 73 | out_seq.append(node_embs) 74 | 75 | return out_seq 76 | 77 | class mat_GRU_cell(torch.nn.Module): 78 | def __init__(self,args): 79 | super().__init__() 80 | self.args = args 81 | self.update = mat_GRU_gate(args.rows, 82 | args.cols, 83 | torch.nn.Sigmoid()) 84 | 85 | self.reset = mat_GRU_gate(args.rows, 86 | args.cols, 87 | torch.nn.Sigmoid()) 88 | 89 | self.htilda = mat_GRU_gate(args.rows, 90 | args.cols, 91 | torch.nn.Tanh()) 92 | 93 | self.choose_topk = TopK(feats = args.rows, 94 | k = args.cols) 95 | 96 | def forward(self,prev_Q):#,prev_Z,mask): 97 | # z_topk = self.choose_topk(prev_Z,mask) 98 | z_topk = prev_Q 99 | 100 | update = self.update(z_topk,prev_Q) 101 | reset = self.reset(z_topk,prev_Q) 102 | 103 | h_cap = reset * prev_Q 104 | h_cap = self.htilda(z_topk, h_cap) 105 | 106 | new_Q = (1 - update) * prev_Q + update * h_cap 107 | 108 | return new_Q 109 | 110 | 111 | 112 | class mat_GRU_gate(torch.nn.Module): 113 | def __init__(self,rows,cols,activation): 114 | super().__init__() 115 | self.activation = activation 116 | #the k here should be in_feats which is actually the rows 117 | self.W = Parameter(torch.Tensor(rows,rows)) 118 | self.reset_param(self.W) 119 | 120 | self.U = Parameter(torch.Tensor(rows,rows)) 121 | self.reset_param(self.U) 122 | 123 | self.bias = Parameter(torch.zeros(rows,cols)) 124 | 125 | def reset_param(self,t): 126 | #Initialize based on the number of columns 127 | stdv = 1. / math.sqrt(t.size(1)) 128 | t.data.uniform_(-stdv,stdv) 129 | 130 | def forward(self,x,hidden): 131 | out = self.activation(self.W.matmul(x) + \ 132 | self.U.matmul(hidden) + \ 133 | self.bias) 134 | 135 | return out 136 | 137 | class TopK(torch.nn.Module): 138 | def __init__(self,feats,k): 139 | super().__init__() 140 | self.scorer = Parameter(torch.Tensor(feats,1)) 141 | self.reset_param(self.scorer) 142 | 143 | self.k = k 144 | 145 | def reset_param(self,t): 146 | #Initialize based on the number of rows 147 | stdv = 1. 
/ math.sqrt(t.size(0)) 148 | t.data.uniform_(-stdv,stdv) 149 | 150 | def forward(self,node_embs,mask): 151 | scores = node_embs.matmul(self.scorer) / self.scorer.norm() 152 | scores = scores + mask 153 | 154 | vals, topk_indices = scores.view(-1).topk(self.k) 155 | topk_indices = topk_indices[vals > -float("Inf")] 156 | 157 | if topk_indices.size(0) < self.k: 158 | topk_indices = u.pad_with_last_val(topk_indices,self.k) 159 | 160 | tanh = torch.nn.Tanh() 161 | 162 | if isinstance(node_embs, torch.sparse.FloatTensor) or \ 163 | isinstance(node_embs, torch.cuda.sparse.FloatTensor): 164 | node_embs = node_embs.to_dense() 165 | 166 | out = node_embs[topk_indices] * tanh(scores[topk_indices].view(-1,1)) 167 | 168 | #we need to transpose the output 169 | return out.t() 170 | -------------------------------------------------------------------------------- /elliptic_construction.md: -------------------------------------------------------------------------------- 1 | # Instructions for processing the Elliptic data set 2 | 3 | The Elliptic data set, downloadable from [https://www.kaggle.com/ellipticco/elliptic-data-set](), consists of a dynamic graph that can be used for node classification. It was used, for example, for experimentation by the EvolveGCN paper [https://arxiv.org/abs/1902.10191](). A preprocessed version of the data set can be read by the EvolveGCN code, specifically, the dataloader `elliptic_temporal_dl.py` in [https://github.com/IBM/EvolveGCN](). Here are the preprocessing instructions. 4 | 5 | 6 | ## Step 0: Download data 7 | 8 | Download data from [https://www.kaggle.com/ellipticco/elliptic-data-set](). You will see three files `elliptic_txs_features.csv`, `elliptic_txs_classes.csv`, and `elliptic_txs_edgelist.csv`. 9 | 10 | 11 | ## Step 1: Create a file named `elliptic_txs_orig2contiguos.csv` and modify `elliptic_txs_features.csv`. 12 | 13 | The file `elliptic_txs_features.csv` contains node features, one node each line. Each line contains 167 numbers, where the first number denotes node id. For example, the first three numbers in the first line are 14 | 15 | ``` 16 | 230425980,1,-0.1714692896288031 17 | ``` 18 | 19 | Here, `230425980` is the node id. You will replace this id by the line number (starting from 0). Moreover, make the first number and second number floating point numbers. That is, in the modified `elliptic_txs_features.csv`, the three numbers in the first line should be 20 | 21 | ``` 22 | 0.0,1.0,-0.1714692896288031 23 | ``` 24 | 25 | In the newly created `elliptic_txs_orig2contiguos.csv`, the first line is the header 26 | 27 | ``` 28 | originalId,contiguosId 29 | ``` 30 | 31 | and the lines that follow contain the id conversion information. For example, the line after the header line should be 32 | 33 | ``` 34 | 230425980,0 35 | ``` 36 | 37 | 38 | ## Step 2: Modify `elliptic_txs_classes.csv` 39 | 40 | The file `elliptic_txs_classes.csv` contains node labels. Because we have converted the node ids, we need to modify this file accordingly. We also use numeric values to denote the labels. 41 | 42 | Specifically, the classes `unknown`, `1`, and `2` are changed to `-1.0`, `1.0`, and `0`, respectively. For example, the line after the header line is changed from 43 | 44 | ``` 45 | 230425980,unknown 46 | ``` 47 | 48 | to 49 | 50 | ``` 51 | 0.0,-1.0 52 | ``` 53 | 54 | Additionally, the header line is never changed. 
It is always 55 | 56 | ``` 57 | txId,class 58 | ``` 59 | 60 | 61 | ## Step 3: Create a file named `elliptic_txs_nodetime.csv` 62 | 63 | This file will add a time stamp to each node. 64 | 65 | The header line is 66 | 67 | ``` 68 | txId,timestep 69 | ``` 70 | 71 | Each line afterward will contain two numbers. The first number is the new node id and the second number is the time stamp. 72 | 73 | The time stamp appears in the second column of `elliptic_txs_features.csv`. Recall that the first three numbers in the first line of the original `elliptic_txs_features.csv` are 74 | 75 | ``` 76 | 230425980,1,-0.1714692896288031 77 | ``` 78 | 79 | The second number `1` indicates time stamp. We will use a zero based indexing and hence shift this number down by 1. 80 | 81 | Therefore, the line after the header line of the new file `elliptic_txs_nodetime.csv` is 82 | 83 | ``` 84 | 0,0 85 | ``` 86 | 87 | indicating that the time stamp of node 0 is 0. 88 | 89 | 90 | ## Step 4: Modify `elliptic_txs_edgelist.csv` and rename it to `elliptic_txs_edgelist_timed.csv` 91 | 92 | The header line is changed from 93 | 94 | ``` 95 | txId1,txId2 96 | ``` 97 | 98 | to 99 | 100 | ``` 101 | txId1,txId2,timestep 102 | ``` 103 | 104 | For each line that follows, the two numbers indicating old node ids are changed to new node ids, followed by time stamp (as floating point number). These two nodes always have the same time stamp in `elliptic_txs_nodetime.csv` (I recommend doing a sanity check by yourself). Then the edge with these two nodes has a time stamp the same as the node time stamps. 105 | 106 | Therefore, the line after the header line in `elliptic_txs_edgelist.csv`: 107 | 108 | ``` 109 | 230425980,5530458 110 | ``` 111 | 112 | will be changed to, in `elliptic_txs_edgelist_timed.csv`: 113 | 114 | ``` 115 | 0,1,0.0 116 | ``` 117 | 118 | because the new node id for `230425980` is `0` and that for `5530458 ` is `1`. The time stamp for these two nodes are `0`. 119 | 120 | With all above preprocessing steps done, you should be able to use the dataloader `elliptic_temporal_dl.py` to load in the data set. 
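
The four steps above are plain CSV rewrites, so they can also be scripted instead of edited by hand. Below is a minimal sketch of one way to do it (it is not part of this repository): it assumes pandas is installed and that the three downloaded Kaggle files sit in the working directory; all file names, column headers, and value conventions are taken from the steps above, and `elliptic_txs_features.csv` is overwritten in place.

```python
import pandas as pd

# Step 1: map original node ids to contiguous ids (the row number) and rewrite the feature file.
feats = pd.read_csv('elliptic_txs_features.csv', header=None)
orig2contig = {tx: i for i, tx in enumerate(feats[0])}
pd.DataFrame({'originalId': list(orig2contig.keys()),
              'contiguosId': list(orig2contig.values())}
             ).to_csv('elliptic_txs_orig2contiguos.csv', index=False)
feats[0] = feats[0].map(orig2contig).astype(float)   # new id, written as a float
feats[1] = feats[1].astype(float)                    # time stamp, written as a float
feats.to_csv('elliptic_txs_features.csv', index=False, header=False)

# Step 2: relabel classes (unknown -> -1.0, '1' -> 1.0, '2' -> 0); the header stays 'txId,class'.
classes = pd.read_csv('elliptic_txs_classes.csv')
classes['txId'] = classes['txId'].map(orig2contig).astype(float)
classes['class'] = classes['class'].map({'unknown': -1.0, '1': 1.0, '2': 0})
classes.to_csv('elliptic_txs_classes.csv', index=False)

# Step 3: one zero-based time stamp per node (second feature column shifted down by 1).
nodetime = pd.DataFrame({'txId': feats[0].astype(int),
                         'timestep': feats[1].astype(int) - 1})
nodetime.to_csv('elliptic_txs_nodetime.csv', index=False)

# Step 4: remap edge endpoints and append the time stamp of the endpoints as a float.
edges = pd.read_csv('elliptic_txs_edgelist.csv')
edges['txId1'] = edges['txId1'].map(orig2contig)
edges['txId2'] = edges['txId2'].map(orig2contig)
time_of = dict(zip(nodetime['txId'], nodetime['timestep']))
edges['timestep'] = edges['txId1'].map(time_of).astype(float)
edges.to_csv('elliptic_txs_edgelist_timed.csv', index=False)
```

Note that the dataloader opens the data through a tar archive (`elliptic_args.tar_file` in the experiment yaml), so you may still need to pack the resulting CSV files into a `.tar.gz` placed in the 'data' folder.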
121 | -------------------------------------------------------------------------------- /elliptic_temporal_dl.py: -------------------------------------------------------------------------------- 1 | import utils as u 2 | import os 3 | import torch 4 | #erase 5 | import time 6 | import tarfile 7 | import itertools 8 | import numpy as np 9 | 10 | 11 | class Elliptic_Temporal_Dataset(): 12 | def __init__(self,args): 13 | args.elliptic_args = u.Namespace(args.elliptic_args) 14 | 15 | tar_file = os.path.join(args.elliptic_args.folder, args.elliptic_args.tar_file) 16 | tar_archive = tarfile.open(tar_file, 'r:gz') 17 | 18 | self.nodes_labels_times = self.load_node_labels(args.elliptic_args, tar_archive) 19 | 20 | self.edges = self.load_transactions(args.elliptic_args, tar_archive) 21 | 22 | self.nodes, self.nodes_feats = self.load_node_feats(args.elliptic_args, tar_archive) 23 | 24 | def load_node_feats(self, elliptic_args, tar_archive): 25 | data = u.load_data_from_tar(elliptic_args.feats_file, tar_archive, starting_line=0) 26 | nodes = data 27 | 28 | nodes_feats = nodes[:,1:] 29 | 30 | 31 | self.num_nodes = len(nodes) 32 | self.feats_per_node = data.size(1) - 1 33 | 34 | return nodes, nodes_feats.float() 35 | 36 | 37 | def load_node_labels(self, elliptic_args, tar_archive): 38 | labels = u.load_data_from_tar(elliptic_args.classes_file, tar_archive, replace_unknow=True).long() 39 | times = u.load_data_from_tar(elliptic_args.times_file, tar_archive, replace_unknow=True).long() 40 | lcols = u.Namespace({'nid': 0, 41 | 'label': 1}) 42 | tcols = u.Namespace({'nid':0, 'time':1}) 43 | 44 | 45 | nodes_labels_times =[] 46 | for i in range(len(labels)): 47 | label = labels[i,[lcols.label]].long() 48 | if label>=0: 49 | nid=labels[i,[lcols.nid]].long() 50 | time=times[nid,[tcols.time]].long() 51 | nodes_labels_times.append([nid , label, time]) 52 | nodes_labels_times = torch.tensor(nodes_labels_times) 53 | 54 | return nodes_labels_times 55 | 56 | 57 | def load_transactions(self, elliptic_args, tar_archive): 58 | data = u.load_data_from_tar(elliptic_args.edges_file, tar_archive, type_fn=float, tensor_const=torch.LongTensor) 59 | tcols = u.Namespace({'source': 0, 60 | 'target': 1, 61 | 'time': 2}) 62 | 63 | data = torch.cat([data,data[:,[1,0,2]]]) 64 | 65 | self.max_time = data[:,tcols.time].max() 66 | self.min_time = data[:,tcols.time].min() 67 | 68 | return {'idx': data, 'vals': torch.ones(data.size(0))} 69 | -------------------------------------------------------------------------------- /experiments/parameters_auto_syst_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: autonomous_syst 2 | 3 | aut_sys_args: 4 | folder: ./data 5 | tar_file: as-733.tar.gz 6 | aggr_time: 1 #number of days per time step (window size) 7 | steps_accounted: 100 #only first 100 steps 8 | 9 | 10 | use_cuda: True 11 | use_logfile: True 12 | 13 | model: egcn_h 14 | 15 | task: link_pred 16 | 17 | class_weights: [ 0.1, 0.9] 18 | use_2_hot_node_feats: False 19 | use_1_hot_node_feats: True 20 | save_node_embeddings: False 21 | 22 | train_proportion: 0.7 23 | dev_proportion: 0.1 24 | 25 | num_epochs: 1000 #number of passes though the data 26 | steps_accum_gradients: 1 27 | learning_rate: 0.005 28 | learning_rate_min: 0.0005 29 | learning_rate_max: 0.05 30 | negative_mult_training: 100 31 | negative_mult_test: 100 32 | smart_neg_sampling: True 33 | seed: 1234 34 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP 35 | target_class: 1 # Target class 
to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 50 37 | 38 | 39 | eval_after_epochs: 5 40 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. Use None to use all the history (from 0 to t) 41 | # adj_mat_time_window_min: 1 42 | # adj_mat_time_window_max: 10 43 | num_hist_steps: 10 # number of previous steps used for prediction 44 | num_hist_steps_min: 2 # only used if num_hist_steps: None 45 | num_hist_steps_max: 10 # only used if num_hist_steps: None 46 | 47 | data_loading_params: 48 | batch_size: 1 49 | num_workers: 8 50 | 51 | gcn_parameters: 52 | feats_per_node: 100 53 | feats_per_node_min: 50 54 | feats_per_node_max: 256 55 | layer_1_feats: 171 56 | layer_1_feats_min: 20 57 | layer_1_feats_max: 200 58 | layer_2_feats: 50 59 | layer_2_feats_same_as_l1: True 60 | k_top_grcu: 200 61 | num_layers: 2 62 | lstm_l1_layers: 1 63 | lstm_l1_feats: 37 # only used with sp_lstm_B_trainer 64 | lstm_l1_feats_min: 10 65 | lstm_l1_feats_max: 150 66 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 67 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 68 | lstm_l2_feats_same_as_l1: True 69 | cls_feats: 113 # Hidden size of the classifier 70 | cls_feats_min: 100 71 | cls_feats_max: 512 72 | comments: 73 | - added a mask parameter to exclude non-available nodes 74 | - elliptic, skipgcn 75 | -------------------------------------------------------------------------------- /experiments/parameters_auto_syst_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: autonomous_syst 2 | 3 | aut_sys_args: 4 | folder: ./data 5 | tar_file: as-733.tar.gz 6 | aggr_time: 1 #number of days per time step (window size) 7 | steps_accounted: 100 #only first 100 steps 8 | 9 | 10 | use_cuda: True 11 | use_logfile: True 12 | 13 | model: egcn_o 14 | 15 | task: link_pred 16 | 17 | class_weights: [ 0.1, 0.9] 18 | use_2_hot_node_feats: False 19 | use_1_hot_node_feats: True 20 | save_node_embeddings: False 21 | 22 | train_proportion: 0.7 23 | dev_proportion: 0.1 24 | 25 | num_epochs: 1000 #number of passes though the data 26 | steps_accum_gradients: 1 27 | learning_rate: 0.011 28 | learning_rate_min: 0.0005 29 | learning_rate_max: 0.05 30 | negative_mult_training: 100 31 | negative_mult_test: 100 32 | smart_neg_sampling: True 33 | seed: 1234 34 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP 35 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 50 37 | 38 | 39 | eval_after_epochs: 5 40 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 41 | # adj_mat_time_window_min: 1 42 | # adj_mat_time_window_max: 10 43 | num_hist_steps: 5 # number of previous steps used for prediction 44 | num_hist_steps_min: 2 # only used if num_hist_steps: None 45 | num_hist_steps_max: 10 # only used if num_hist_steps: None 46 | 47 | data_loading_params: 48 | batch_size: 1 49 | num_workers: 8 50 | 51 | gcn_parameters: 52 | feats_per_node: 30 53 | feats_per_node_min: 50 54 | feats_per_node_max: 256 55 | layer_1_feats: 30 56 | layer_1_feats_min: 20 57 | layer_1_feats_max: 200 58 | layer_2_feats: 30 59 | layer_2_feats_same_as_l1: True 60 | k_top_grcu: 200 61 | num_layers: 2 62 | lstm_l1_layers: 1 63 | lstm_l1_feats: 30 # only used with sp_lstm_B_trainer 64 | lstm_l1_feats_min: 10 65 | lstm_l1_feats_max: 150 66 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 67 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 68 | lstm_l2_feats_same_as_l1: True 69 | cls_feats: 100 # Hidden size of the classifier 70 | cls_feats_min: 100 71 | cls_feats_max: 512 72 | comments: 73 | - added a mask parameter to exclude non-available nodes 74 | - elliptic, skipgcn 75 | -------------------------------------------------------------------------------- /experiments/parameters_bitcoin_alpha_edgecls_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: bitcoinalpha # bitcoinalpha bitcoinotc 2 | 3 | bitcoinalpha_args: 4 | folder: ./data/ 5 | edges_file: soc-sign-bitcoinalpha.csv 6 | aggr_time: 1200000 #three weeks in seconds: 1200000 7 | feats_per_node: 3 8 | 9 | use_cuda: True 10 | use_logfile: True 11 | 12 | model: egcn_h 13 | 14 | task: edge_cls # link_pred edge_cls 15 | 16 | class_weights: [ 0.8, 0.2] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.7 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 1000 #number of passes though the data 25 | steps_accum_gradients: 1 26 | learning_rate: 0.001 27 | learning_rate_min: 0.0001 28 | learning_rate_max: 0.1 29 | negative_mult_training: 20 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: F1 # measure to define the best epoch F1, Precision, Recall, MRR, MAP, Loss 34 | target_class: 0 # Target class to get the measure to define the best epoch (AVG, 0, 1) 35 | early_stop_patience: 100 36 | 37 | eval_after_epochs: 5 38 | adj_mat_time_window: 1 39 | num_hist_steps: 10 # number of previous steps used for prediction 40 | num_hist_steps_min: 3 # only used if num_hist_steps: None 41 | num_hist_steps_max: 10 # only used if num_hist_steps: None 42 | 43 | data_loading_params: 44 | batch_size: 1 45 | num_workers: 8 46 | 47 | gcn_parameters: 48 | feats_per_node: 100 49 | feats_per_node_min: 50 50 | feats_per_node_max: 256 51 | layer_1_feats: 100 52 | layer_1_feats_min: 20 53 | layer_1_feats_max: 200 54 | layer_2_feats: 100 55 | layer_2_feats_same_as_l1: True 56 | k_top_grcu: 200 57 | num_layers: 2 58 | lstm_l1_layers: 1 59 | lstm_l1_feats: 100 # only used with sp_lstm_B_trainer 60 | lstm_l1_feats_min: 20 61 | lstm_l1_feats_max: 200 62 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 63 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 64 | lstm_l2_feats_same_as_l1: True 65 | cls_feats: 200 # Hidden size of the classifier 66 | cls_feats_min: 50 67 | cls_feats_max: 500 68 | comments: 69 | - comments 70 
| -------------------------------------------------------------------------------- /experiments/parameters_bitcoin_alpha_edgecls_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: bitcoinalpha # bitcoinalpha bitcoinotc 2 | 3 | bitcoinalpha_args: 4 | folder: ./data/ 5 | edges_file: soc-sign-bitcoinalpha.csv 6 | aggr_time: 1200000 #three weeks in seconds: 1200000 7 | feats_per_node: 3 8 | 9 | use_cuda: True 10 | use_logfile: True 11 | 12 | model: egcn_o 13 | 14 | task: edge_cls # link_pred edge_cls 15 | 16 | class_weights: [ 0.8, 0.2] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.7 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 1000 #number of passes though the data 25 | steps_accum_gradients: 1 26 | learning_rate: 0.001 27 | learning_rate_min: 0.0001 28 | learning_rate_max: 0.1 29 | negative_mult_training: 20 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: F1 # measure to define the best epoch F1, Precision, Recall, MRR, MAP, Loss 34 | target_class: 0 # Target class to get the measure to define the best epoch (AVG, 0, 1) 35 | early_stop_patience: 100 36 | 37 | eval_after_epochs: 5 38 | adj_mat_time_window: 1 39 | num_hist_steps: 10 # number of previous steps used for prediction 40 | num_hist_steps_min: 3 # only used if num_hist_steps: None 41 | num_hist_steps_max: 10 # only used if num_hist_steps: None 42 | 43 | data_loading_params: 44 | batch_size: 1 45 | num_workers: 8 46 | 47 | gcn_parameters: 48 | feats_per_node: 100 49 | feats_per_node_min: 50 50 | feats_per_node_max: 256 51 | layer_1_feats: 122 52 | layer_1_feats_min: 20 53 | layer_1_feats_max: 200 54 | layer_2_feats: 100 55 | layer_2_feats_same_as_l1: True 56 | k_top_grcu: 200 57 | num_layers: 2 58 | lstm_l1_layers: 1 59 | lstm_l1_feats: 68 # only used with sp_lstm_B_trainer 60 | lstm_l1_feats_min: 20 61 | lstm_l1_feats_max: 200 62 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 63 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 64 | lstm_l2_feats_same_as_l1: True 65 | cls_feats: 77 # Hidden size of the classifier 66 | cls_feats_min: 50 67 | cls_feats_max: 500 68 | comments: 69 | - comments 70 | -------------------------------------------------------------------------------- /experiments/parameters_bitcoin_alpha_linkpred_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: bitcoinalpha # bitcoinalpha bitcoinotc 2 | 3 | bitcoinalpha_args: 4 | folder: ./data/ 5 | edges_file: soc-sign-bitcoinalpha.csv 6 | aggr_time: 1200000 #three weeks in seconds: 1200000 7 | feats_per_node: 3 8 | 9 | use_cuda: True 10 | use_logfile: True 11 | 12 | model: egcn_h 13 | 14 | task: link_pred 15 | 16 | class_weights: [ 0.1, 0.9] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.7 # with train_proportion: 0.715 we have the 70/30% of actual splits on the 50 timesteps 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 500 #number of passes though the data 25 | steps_accum_gradients: 1 26 | learning_rate: 0.005 27 | learning_rate_min: 0.005 28 | learning_rate_max: 0.05 29 | negative_mult_training: 100 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP, Loss 34 | target_class: 
1 # Target class to get the measure to define the best epoch (all, 0, 1) 35 | early_stop_patience: 50 36 | 37 | eval_after_epochs: 5 38 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. Use None to use all the history (from 0 to t) 39 | # adj_mat_time_window_min: 1 40 | # adj_mat_time_window_max: 10 41 | num_hist_steps: 10 # number of previous steps used for prediction 42 | num_hist_steps_min: 3 # only used if num_hist_steps: None 43 | num_hist_steps_max: 10 # only used if num_hist_steps: None 44 | 45 | data_loading_params: 46 | batch_size: 1 47 | num_workers: 8 48 | 49 | gcn_parameters: 50 | feats_per_node: 100 51 | feats_per_node_min: 50 52 | feats_per_node_max: 256 53 | layer_1_feats: 50 54 | layer_1_feats_min: 21 55 | layer_1_feats_max: 200 56 | layer_2_feats: 30 57 | layer_2_feats_same_as_l1: True 58 | k_top_grcu: 200 59 | num_layers: 2 60 | lstm_l1_layers: 1 61 | lstm_l1_feats: 50 # only used with sp_lstm_B_trainer 62 | lstm_l1_feats_min: 20 63 | lstm_l1_feats_max: 200 64 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 66 | lstm_l2_feats_same_as_l1: True 67 | cls_feats: 100 # Hidden size of the classifier 68 | cls_feats_min: 51 69 | cls_feats_max: 500 70 | comments: 71 | - comments 72 | -------------------------------------------------------------------------------- /experiments/parameters_bitcoin_alpha_linkpred_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: bitcoinalpha # bitcoinalpha bitcoinotc 2 | 3 | bitcoinalpha_args: 4 | folder: ./data/ 5 | edges_file: soc-sign-bitcoinalpha.csv 6 | aggr_time: 1200000 #three weeks in seconds: 1200000 7 | feats_per_node: 3 8 | 9 | use_cuda: True 10 | use_logfile: True 11 | 12 | model: egcn_o 13 | 14 | task: link_pred 15 | 16 | class_weights: [ 0.05, 0.95] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.7 # with train_proportion: 0.715 we have the 70/30% of actual splits on the 50 timesteps 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 500 #number of passes though the data 25 | steps_accum_gradients: 1 26 | learning_rate: 0.005 27 | learning_rate_min: 0.005 28 | learning_rate_max: 0.05 29 | negative_mult_training: 100 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP, Loss 34 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 35 | early_stop_patience: 50 36 | 37 | eval_after_epochs: 5 38 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 39 | # adj_mat_time_window_min: 1 40 | # adj_mat_time_window_max: 10 41 | num_hist_steps: 10 # number of previous steps used for prediction 42 | num_hist_steps_min: 3 # only used if num_hist_steps: None 43 | num_hist_steps_max: 10 # only used if num_hist_steps: None 44 | 45 | data_loading_params: 46 | batch_size: 1 47 | num_workers: 8 48 | 49 | gcn_parameters: 50 | feats_per_node: 100 51 | feats_per_node_min: 50 52 | feats_per_node_max: 256 53 | layer_1_feats: 100 54 | layer_1_feats_min: 21 55 | layer_1_feats_max: 200 56 | layer_2_feats: 30 57 | layer_2_feats_same_as_l1: True 58 | k_top_grcu: 200 59 | num_layers: 2 60 | lstm_l1_layers: 1 61 | lstm_l1_feats: 100 # only used with sp_lstm_B_trainer 62 | lstm_l1_feats_min: 20 63 | lstm_l1_feats_max: 200 64 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 66 | lstm_l2_feats_same_as_l1: True 67 | cls_feats: 400 # Hidden size of the classifier 68 | cls_feats_min: 51 69 | cls_feats_max: 500 70 | comments: 71 | - comments 72 | -------------------------------------------------------------------------------- /experiments/parameters_bitcoin_otc_edgecls_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: bitcoinalpha # bitcoinalpha bitcoinotc 2 | 3 | bitcoinotc_args: 4 | folder: ./data/ 5 | edges_file: soc-sign-bitcoinotc.csv 6 | aggr_time: 1200000 #three weeks in seconds: 1200000 7 | feats_per_node: 3 8 | 9 | use_cuda: True 10 | use_logfile: True 11 | 12 | model: egcn_h 13 | 14 | task: edge_cls 15 | 16 | class_weights: [ 0.8, 0.2] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.7 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 1000 25 | steps_accum_gradients: 1 26 | learning_rate: 0.001 27 | learning_rate_min: 0.0001 28 | learning_rate_max: 0.1 29 | negative_mult_training: 20 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: F1 # measure to define the best epoch F1, Precision, Recall, MRR, MAP, Loss 34 | target_class: 0 # Target class to get the measure to define the best epoch (AVG, 0, 1) 35 | early_stop_patience: 100 36 | 37 | eval_after_epochs: 5 38 | adj_mat_time_window: 1 39 | num_hist_steps: 10 # number of previous steps used for prediction 40 | num_hist_steps_min: 3 # only used if num_hist_steps: None 41 | num_hist_steps_max: 10 # only used if num_hist_steps: None 42 | 43 | data_loading_params: 44 | batch_size: 1 45 | num_workers: 8 46 | 47 | gcn_parameters: 48 | feats_per_node: 100 49 | feats_per_node_min: 50 50 | feats_per_node_max: 256 51 | layer_1_feats: 50 52 | layer_1_feats_min: 20 53 | layer_1_feats_max: 200 54 | layer_2_feats: 100 55 | layer_2_feats_same_as_l1: True 56 | k_top_grcu: 200 57 | num_layers: 2 58 | lstm_l1_layers: 1 59 | lstm_l1_feats: 50 # only used with sp_lstm_B_trainer 60 | lstm_l1_feats_min: 20 61 | lstm_l1_feats_max: 200 62 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 63 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 64 | lstm_l2_feats_same_as_l1: True 65 | cls_feats: 100 # Hidden size of the classifier 66 | cls_feats_min: 50 67 | cls_feats_max: 500 68 | comments: 69 | - comments 70 | -------------------------------------------------------------------------------- 
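All of these configuration files follow the same convention for tunable values: each hyper-parameter comes as a triplet (value, value_min, value_max), and the bounds are used only when the main entry is set to None, in which case a random value between them is drawn for that run (handy for random hyper-parameter search); otherwise the fixed value is used and the bounds are ignored. The snippet below is a minimal sketch of that selection rule, assuming a log-uniform draw for learning rates and a uniform integer draw for sizes such as num_hist_steps or layer_1_feats; the repository's actual sampling code (in run_exp.py / utils.py) may differ in detail.

# Hypothetical helper illustrating the value / value_min / value_max convention.
# The function name and the exact distributions are assumptions for this sketch.
import math
import random

def pick_param(value, v_min, v_max, kind='int'):
    if value is not None and str(value) != 'None':
        return value                               # fixed setting: bounds are ignored
    if kind == 'logscale':                         # e.g. learning_rate in [1e-4, 1e-1]
        return 10 ** random.uniform(math.log10(v_min), math.log10(v_max))
    return random.randint(v_min, v_max)            # e.g. num_hist_steps in [3, 10]

print(pick_param('None', 3, 10))                   # random integer history length
print(pick_param(0.001, 0.0001, 0.1, 'logscale'))  # 0.001, the fixed value is kept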
/experiments/parameters_bitcoin_otc_edgecls_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: bitcoinalpha # bitcoinalpha bitcoinotc 2 | 3 | bitcoinotc_args: 4 | folder: ./data/ 5 | edges_file: soc-sign-bitcoinotc.csv 6 | aggr_time: 1200000 #three weeks in seconds: 1200000 7 | feats_per_node: 3 8 | 9 | use_cuda: True 10 | use_logfile: True 11 | 12 | model: egcn_o 13 | 14 | task: edge_cls # link_pred edge_cls 15 | 16 | class_weights: [ 0.8, 0.2] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.7 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 1000 #number of passes though the data 25 | steps_accum_gradients: 1 26 | learning_rate: 0.001 27 | learning_rate_min: 0.0001 28 | learning_rate_max: 0.1 29 | negative_mult_training: 20 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: F1 # measure to define the best epoch F1, Precision, Recall, MRR, MAP, Loss 34 | target_class: 0 # Target class to get the measure to define the best epoch (AVG, 0, 1) 35 | early_stop_patience: 100 36 | 37 | eval_after_epochs: 5 38 | adj_mat_time_window: 1 39 | num_hist_steps: 10 # number of previous steps used for prediction 40 | num_hist_steps_min: 3 # only used if num_hist_steps: None 41 | num_hist_steps_max: 10 # only used if num_hist_steps: None 42 | 43 | data_loading_params: 44 | batch_size: 1 45 | num_workers: 8 46 | 47 | gcn_parameters: 48 | feats_per_node: 100 49 | feats_per_node_min: 50 50 | feats_per_node_max: 256 51 | layer_1_feats: 32 52 | layer_1_feats_min: 20 53 | layer_1_feats_max: 200 54 | layer_2_feats: 100 55 | layer_2_feats_same_as_l1: True 56 | k_top_grcu: 200 57 | num_layers: 2 58 | lstm_l1_layers: 1 59 | lstm_l1_feats: 150 # only used with sp_lstm_B_trainer 60 | lstm_l1_feats_min: 20 61 | lstm_l1_feats_max: 200 62 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 63 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 64 | lstm_l2_feats_same_as_l1: True 65 | cls_feats: 316 # Hidden size of the classifier 66 | cls_feats_min: 50 67 | cls_feats_max: 500 68 | comments: 69 | - comments 70 | -------------------------------------------------------------------------------- /experiments/parameters_bitcoin_otc_linkpred_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: bitcoinalpha # bitcoinalpha bitcoinotc 2 | 3 | bitcoinotc_args: 4 | folder: ./data/ 5 | edges_file: soc-sign-bitcoinotc.csv 6 | aggr_time: 1200000 #three weeks in seconds: 1200000 7 | feats_per_node: 3 8 | 9 | use_cuda: True 10 | use_logfile: True 11 | 12 | model: egcn_h 13 | 14 | task: link_pred # link_pred edge_cls 15 | 16 | class_weights: [ 0.05, 0.95] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.7 # with train_proportion: 0.715 we have the 70/30% of actual splits on the 50 timesteps 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 500 #number of passes though the data 25 | steps_accum_gradients: 1 26 | learning_rate: 0.005 27 | learning_rate_min: 0.005 28 | learning_rate_max: 0.05 29 | negative_mult_training: 100 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP, Loss 34 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 
1) 35 | early_stop_patience: 50 36 | 37 | eval_after_epochs: 5 38 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. Use None to use all the history (from 0 to t) 39 | # adj_mat_time_window_min: 1 40 | # adj_mat_time_window_max: 10 41 | num_hist_steps: 10 # number of previous steps used for prediction 42 | num_hist_steps_min: 3 # only used if num_hist_steps: None 43 | num_hist_steps_max: 10 # only used if num_hist_steps: None 44 | 45 | data_loading_params: 46 | batch_size: 1 47 | num_workers: 8 48 | 49 | gcn_parameters: 50 | feats_per_node: 100 51 | feats_per_node_min: 50 52 | feats_per_node_max: 256 53 | layer_1_feats: 33 54 | layer_1_feats_min: 21 55 | layer_1_feats_max: 200 56 | layer_2_feats: 30 57 | layer_2_feats_same_as_l1: True 58 | k_top_grcu: 200 59 | num_layers: 2 60 | lstm_l1_layers: 1 61 | lstm_l1_feats: 107 # only used with sp_lstm_B_trainer 62 | lstm_l1_feats_min: 20 63 | lstm_l1_feats_max: 200 64 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 66 | lstm_l2_feats_same_as_l1: True 67 | cls_feats: 400 # Hidden size of the classifier 68 | cls_feats_min: 51 69 | cls_feats_max: 500 70 | comments: 71 | - comments 72 | -------------------------------------------------------------------------------- /experiments/parameters_bitcoin_otc_linkpred_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: bitcoinalpha # bitcoinalpha bitcoinotc 2 | 3 | bitcoinotc_args: 4 | folder: ./data/ 5 | edges_file: soc-sign-bitcoinotc.csv 6 | aggr_time: 1200000 #three weeks in seconds: 1200000 7 | feats_per_node: 3 8 | 9 | 10 | use_cuda: True 11 | use_logfile: True 12 | 13 | model: egcn_o 14 | 15 | task: link_pred # link_pred edge_cls 16 | 17 | class_weights: [ 0.05, 0.95] 18 | use_2_hot_node_feats: False 19 | use_1_hot_node_feats: True 20 | save_node_embeddings: False 21 | 22 | train_proportion: 0.7 # with train_proportion: 0.715 we have the 70/30% of actual splits on the 50 timesteps 23 | dev_proportion: 0.1 24 | 25 | num_epochs: 500 #number of passes though the data 26 | steps_accum_gradients: 1 27 | learning_rate: 0.01 28 | learning_rate_min: 0.005 29 | learning_rate_max: 0.05 30 | negative_mult_training: 100 31 | negative_mult_test: 100 32 | smart_neg_sampling: True 33 | seed: 1234 34 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP, Loss 35 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 50 37 | 38 | eval_after_epochs: 5 39 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 40 | # adj_mat_time_window_min: 1 41 | # adj_mat_time_window_max: 10 42 | num_hist_steps: 10 # number of previous steps used for prediction 43 | num_hist_steps_min: 3 # only used if num_hist_steps: None 44 | num_hist_steps_max: 10 # only used if num_hist_steps: None 45 | 46 | data_loading_params: 47 | batch_size: 1 48 | num_workers: 8 49 | 50 | gcn_parameters: 51 | feats_per_node: 100 52 | feats_per_node_min: 50 53 | feats_per_node_max: 256 54 | layer_1_feats: 100 55 | layer_1_feats_min: 21 56 | layer_1_feats_max: 200 57 | layer_2_feats: 30 58 | layer_2_feats_same_as_l1: True 59 | k_top_grcu: 200 60 | num_layers: 2 61 | lstm_l1_layers: 1 62 | lstm_l1_feats: 100 # only used with sp_lstm_B_trainer 63 | lstm_l1_feats_min: 20 64 | lstm_l1_feats_max: 200 65 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 66 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 67 | lstm_l2_feats_same_as_l1: True 68 | cls_feats: 400 # Hidden size of the classifier 69 | cls_feats_min: 51 70 | cls_feats_max: 500 71 | comments: 72 | - comments 73 | -------------------------------------------------------------------------------- /experiments/parameters_elliptic_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: elliptic_temporal #HELP: arxiv, bitcoin, aml_sim, dbg, elliptic, elliptic_temporal 2 | elliptic_args: 3 | folder: ./data/elliptic_temporal 4 | tar_file: elliptic_bitcoin_dataset_cont.tar.gz 5 | feats_file: elliptic_bitcoin_dataset_cont/elliptic_txs_features.csv 6 | edges_file: elliptic_bitcoin_dataset_cont/elliptic_txs_edgelist_timed.csv 7 | classes_file: elliptic_bitcoin_dataset_cont/elliptic_txs_classes.csv 8 | times_file: elliptic_bitcoin_dataset_cont/elliptic_txs_nodetime.csv 9 | aggr_time: 1 10 | 11 | use_cuda: True 12 | use_logfile: True 13 | 14 | model: egcn_h 15 | 16 | task: node_cls 17 | 18 | class_weights: [ 0.25, 0.75] 19 | use_2_hot_node_feats: False 20 | use_1_hot_node_feats: False 21 | save_node_embeddings: True 22 | 23 | train_proportion: 0.65 24 | dev_proportion: 0.1 25 | num_epochs: 1000 26 | steps_accum_gradients: 1 27 | learning_rate: 0.001 28 | learning_rate_min: 0.001 29 | learning_rate_max: 0.02 30 | negative_mult_training: 20 31 | negative_mult_test: 100 32 | smart_neg_sampling: False 33 | seed: 1234 34 | target_measure: F1 # measure to define the best epoch F1, Precision, Recall, MRR, MAP 35 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 100 37 | 38 | eval_after_epochs: 5 39 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 40 | adj_mat_time_window_min: 1 41 | adj_mat_time_window_max: 10 42 | num_hist_steps: 5 # number of previous steps used for prediction 43 | num_hist_steps_min: 3 # only used if num_hist_steps: None 44 | num_hist_steps_max: 10 # only used if num_hist_steps: None 45 | data_loading_params: 46 | batch_size: 1 47 | num_workers: 6 48 | gcn_parameters: 49 | feats_per_node: 50 50 | feats_per_node_min: 30 51 | feats_per_node_max: 312 52 | layer_1_feats: 76 53 | layer_1_feats_min: 30 54 | layer_1_feats_max: 500 55 | layer_2_feats: None 56 | layer_2_feats_same_as_l1: True 57 | k_top_grcu: 200 58 | num_layers: 2 59 | lstm_l1_layers: 1 60 | lstm_l1_feats: 125 # only used with sp_lstm_B_trainer 61 | lstm_l1_feats_min: 50 62 | lstm_l1_feats_max: 500 63 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 64 | lstm_l2_feats: 400 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats_same_as_l1: True 66 | cls_feats: 510 # Hidden size of the classifier 67 | cls_feats_min: 100 68 | cls_feats_max: 700 69 | comments: 70 | - comments 71 | -------------------------------------------------------------------------------- /experiments/parameters_elliptic_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: elliptic_temporal #HELP: arxiv, bitcoin, aml_sim, dbg, elliptic, elliptic_temporal 2 | elliptic_args: 3 | folder: ./data/elliptic_temporal 4 | tar_file: elliptic_bitcoin_dataset_cont.tar.gz 5 | feats_file: elliptic_bitcoin_dataset_cont/elliptic_txs_features.csv 6 | edges_file: elliptic_bitcoin_dataset_cont/elliptic_txs_edgelist_timed.csv 7 | classes_file: elliptic_bitcoin_dataset_cont/elliptic_txs_classes.csv 8 | times_file: elliptic_bitcoin_dataset_cont/elliptic_txs_nodetime.csv 9 | aggr_time: 1 10 | 11 | use_cuda: True 12 | use_logfile: True 13 | 14 | model: egcn_o 15 | 16 | task: node_cls 17 | 18 | class_weights: [ 0.35, 0.65] 19 | use_2_hot_node_feats: False 20 | use_1_hot_node_feats: False 21 | save_node_embeddings: True 22 | 23 | train_proportion: 0.65 24 | dev_proportion: 0.1 25 | num_epochs: 800 #number of passes though the data 26 | steps_accum_gradients: 1 27 | learning_rate: 0.001 28 | learning_rate_min: 0.001 29 | learning_rate_max: 0.02 30 | negative_mult_training: 20 31 | negative_mult_test: 100 32 | smart_neg_sampling: False 33 | seed: 1234 34 | target_measure: F1 # measure to define the best epoch F1, Precision, Recall, MRR, MAP 35 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 100 37 | 38 | eval_after_epochs: 5 39 | adj_mat_time_window: 1 40 | adj_mat_time_window_min: 1 41 | adj_mat_time_window_max: 10 42 | num_hist_steps: 5 # number of previous steps used for prediction 43 | num_hist_steps_min: 3 # only used if num_hist_steps: None 44 | num_hist_steps_max: 10 # only used if num_hist_steps: None 45 | data_loading_params: 46 | batch_size: 1 47 | num_workers: 6 48 | 49 | gcn_parameters: 50 | feats_per_node: 50 51 | feats_per_node_min: 30 52 | feats_per_node_max: 312 53 | layer_1_feats: 256 54 | layer_1_feats_min: 30 55 | layer_1_feats_max: 500 56 | layer_2_feats: None 57 | layer_2_feats_same_as_l1: True 58 | k_top_grcu: 200 59 | num_layers: 2 60 | lstm_l1_layers: 125 61 | lstm_l1_feats: 100 # only used with sp_lstm_B_trainer 62 | lstm_l1_feats_min: 50 63 | lstm_l1_feats_max: 500 64 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats: 
400 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 66 | lstm_l2_feats_same_as_l1: True 67 | cls_feats: 307 # Hidden size of the classifier 68 | cls_feats_min: 100 69 | cls_feats_max: 700 70 | comments: 71 | - comments 72 | -------------------------------------------------------------------------------- /experiments/parameters_example.yaml: -------------------------------------------------------------------------------- 1 | data: sbm50 2 | 3 | sbm50_args: 4 | folder: ./data/ 5 | edges_file: sbm_50t_1000n_adj.csv 6 | aggr_time: 1 # 7 | feats_per_node: 3 8 | 9 | 10 | use_cuda: True 11 | use_logfile: False # If True save the output in a log file, if False in stdout 12 | 13 | model: egcn_o #HELP: gcn 14 | # gruA 15 | # gruB 16 | # egcn_h 17 | # egcn_o 18 | 19 | task: link_pred # Help: link_pred, edge_cls or node_cls 20 | 21 | class_weights: [ 0.1, 0.9] 22 | use_2_hot_node_feats: False 23 | use_1_hot_node_feats: True 24 | save_node_embeddings: False 25 | 26 | train_proportion: 0.7 27 | dev_proportion: 0.1 28 | 29 | num_epochs: 100 30 | steps_accum_gradients: 1 31 | learning_rate: 0.005 # use None to pick a random number between learning_rate_min and learning_rate_max 32 | learning_rate_min: 0.0001 33 | learning_rate_max: 0.1 34 | negative_mult_training: 50 35 | negative_mult_test: 100 36 | smart_neg_sampling: True 37 | seed: 1234 38 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP 39 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 40 | early_stop_patience: 50 41 | 42 | 43 | eval_after_epochs: 5 # Epoch for the first validation (avoid the first epochs to save time) 44 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. Use None to use all the history (from 0 to t) 45 | num_hist_steps: 5 46 | num_hist_steps_min: 1 # only used if num_hist_steps: None 47 | num_hist_steps_max: 10 # only used if num_hist_steps: None 48 | 49 | data_loading_params: 50 | batch_size: 1 # should be always 1 51 | num_workers: 8 52 | 53 | gcn_parameters: 54 | feats_per_node: 100 55 | feats_per_node_min: 50 56 | feats_per_node_max: 256 57 | layer_1_feats: 100 58 | layer_1_feats_min: 10 59 | layer_1_feats_max: 200 60 | layer_2_feats: 100 61 | layer_2_feats_same_as_l1: True 62 | k_top_grcu: 200 63 | num_layers: 2 64 | lstm_l1_layers: 1 65 | lstm_l1_feats: 100 # only used with sp_lstm_B_trainer 66 | lstm_l1_feats_min: 10 67 | lstm_l1_feats_max: 200 68 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 69 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 70 | lstm_l2_feats_same_as_l1: True 71 | cls_feats: 100 # Hidden size of the classifier 72 | cls_feats_min: 100 73 | cls_feats_max: 800 74 | comments: 75 | - comments 76 | -------------------------------------------------------------------------------- /experiments/parameters_reddit_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: reddit 2 | 3 | reddit_args: 4 | folder: ./data/reddit 5 | title_edges_file: soc-redditHyperlinks-title.tsv 6 | body_edges_file: soc-redditHyperlinks-body.tsv 7 | nodes_file: web-redditEmbeddings-subreddits.csv 8 | aggr_time: 7 #number of days 9 | 10 | use_cuda: True 11 | use_logfile: True 12 | 13 | model: egcn_h 14 | 15 | task: edge_cls 16 | 17 | class_weights: [ 0.9, 0.1] 18 | use_2_hot_node_feats: False 19 | use_1_hot_node_feats: False 20 | save_node_embeddings: False 21 | 22 | train_proportion: 0.71 23 | 
dev_proportion: 0.1 24 | 25 | num_epochs: 500 26 | steps_accum_gradients: 1 27 | learning_rate: 0.001 28 | learning_rate_min: 0.0005 29 | learning_rate_max: 0.1 30 | negative_mult_training: 20 31 | negative_mult_test: 100 32 | smart_neg_sampling: False 33 | seed: 1234 34 | target_measure: F1 # measure to define the best epoch F1, Precision, Recall, MRR, MAP 35 | target_class: 0 # Target class to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 50 37 | 38 | 39 | eval_after_epochs: 1 40 | adj_mat_time_window: 10 # Time window to create the adj matrix for each timestep. Use None to use all the history (from 0 to t) 41 | num_hist_steps: 10 # number of previous steps used for prediction 42 | num_hist_steps_min: 3 # only used if num_hist_steps: None 43 | num_hist_steps_max: 10 # only used if num_hist_steps: None 44 | 45 | data_loading_params: 46 | batch_size: 1 47 | num_workers: 8 48 | 49 | gcn_parameters: 50 | feats_per_node: 100 51 | feats_per_node_min: 100 52 | feats_per_node_max: 256 53 | layer_1_feats: 100 54 | layer_1_feats_min: 10 55 | layer_1_feats_max: 200 56 | layer_2_feats: 20 57 | layer_2_feats_same_as_l1: True 58 | k_top_grcu: 200 59 | num_layers: 2 60 | lstm_l1_layers: 1 61 | lstm_l1_feats: 100 # only used with sp_lstm_B_trainer 62 | lstm_l1_feats_min: 10 63 | lstm_l1_feats_max: 100 64 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 66 | lstm_l2_feats_same_as_l1: True 67 | cls_feats: 100 # Hidden size of the classifier 68 | cls_feats_min: 100 69 | cls_feats_max: 512 70 | comments: 71 | - comments 72 | -------------------------------------------------------------------------------- /experiments/parameters_reddit_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: reddit 2 | 3 | reddit_args: 4 | folder: ./data/reddit 5 | title_edges_file: soc-redditHyperlinks-title.tsv 6 | body_edges_file: soc-redditHyperlinks-body.tsv 7 | nodes_file: web-redditEmbeddings-subreddits.csv 8 | aggr_time: 7 #number of days 9 | 10 | use_cuda: True 11 | use_logfile: True 12 | 13 | model: egcn_o 14 | 15 | task: edge_cls 16 | 17 | class_weights: [ 0.9, 0.1] 18 | use_2_hot_node_feats: False 19 | use_1_hot_node_feats: False 20 | save_node_embeddings: False 21 | 22 | train_proportion: 0.71 # with train_proportion: 0.715 we have the 70/30% of actual splits on the 50 timesteps 23 | dev_proportion: 0.1 24 | 25 | num_epochs: 500 #number of passes though the data 26 | steps_accum_gradients: 1 27 | learning_rate: 0.085167092 28 | learning_rate_min: 0.0005 29 | learning_rate_max: 0.1 30 | negative_mult_training: 20 31 | negative_mult_test: 100 32 | smart_neg_sampling: False 33 | seed: 1234 34 | target_measure: F1 # measure to define the best epoch F1, Precision, Recall, MRR, MAP 35 | target_class: 0 # Target class to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 50 37 | 38 | 39 | eval_after_epochs: 1 40 | adj_mat_time_window: 10 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 41 | # adj_mat_time_window_min: 1 42 | # adj_mat_time_window_max: 10 43 | num_hist_steps: 10 # number of previous steps used for prediction 44 | num_hist_steps_min: 3 # only used if num_hist_steps: None 45 | num_hist_steps_max: 10 # only used if num_hist_steps: None 46 | 47 | data_loading_params: 48 | batch_size: 1 49 | num_workers: 8 50 | 51 | gcn_parameters: 52 | feats_per_node: 100 53 | feats_per_node_min: 100 54 | feats_per_node_max: 256 55 | layer_1_feats: 152 56 | layer_1_feats_min: 10 57 | layer_1_feats_max: 200 58 | layer_2_feats: 20 59 | layer_2_feats_same_as_l1: True 60 | k_top_grcu: 200 61 | num_layers: 2 62 | lstm_l1_layers: 1 63 | lstm_l1_feats: 35 # only used with sp_lstm_B_trainer 64 | lstm_l1_feats_min: 10 65 | lstm_l1_feats_max: 100 66 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 67 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 68 | lstm_l2_feats_same_as_l1: True 69 | cls_feats: 147 # Hidden size of the classifier 70 | cls_feats_min: 100 71 | cls_feats_max: 512 72 | comments: 73 | - comments 74 | -------------------------------------------------------------------------------- /experiments/parameters_sbm_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: sbm50 2 | 3 | sbm50_args: 4 | folder: ./data/ 5 | edges_file: sbm_50t_1000n_adj.csv 6 | aggr_time: 1 # 7 | feats_per_node: 3 # 8 | 9 | use_cuda: True 10 | use_logfile: True 11 | 12 | model: egcn_h 13 | 14 | task: link_pred 15 | 16 | class_weights: [ 0.1, 0.9] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.7 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 100 25 | steps_accum_gradients: 1 26 | learning_rate: 0.01 27 | learning_rate_min: 0.0001 28 | learning_rate_max: 0.1 29 | negative_mult_training: 50 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: MAP # measure to define the best epoch, can be either F1, Precision, Recall, MRR, MAP 34 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 35 | early_stop_patience: 50 36 | 37 | 38 | eval_after_epochs: 5 39 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 40 | num_hist_steps: 5 # number of previous steps used for prediction 41 | num_hist_steps_min: 2 # only used if num_hist_steps: None 42 | num_hist_steps_max: 10 # only used if num_hist_steps: None 43 | 44 | data_loading_params: 45 | batch_size: 1 46 | num_workers: 8 47 | 48 | gcn_parameters: 49 | feats_per_node: 100 50 | feats_per_node_min: 50 51 | feats_per_node_max: 256 52 | layer_1_feats: 50 53 | layer_1_feats_min: 10 54 | layer_1_feats_max: 200 55 | layer_2_feats: 100 56 | layer_2_feats_same_as_l1: True 57 | k_top_grcu: 200 58 | num_layers: 2 59 | lstm_l1_layers: 1 60 | lstm_l1_feats: 50 # only used with sp_lstm_B_trainer 61 | lstm_l1_feats_min: 10 62 | lstm_l1_feats_max: 200 63 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 64 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats_same_as_l1: True 66 | cls_feats: 100 # Hidden size of the classifier 67 | cls_feats_min: 100 68 | cls_feats_max: 812 69 | comments: 70 | - comments 71 | -------------------------------------------------------------------------------- /experiments/parameters_sbm_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: sbm50 2 | 3 | sbm50_args: 4 | folder: ./data/ 5 | edges_file: sbm_50t_1000n_adj.csv 6 | aggr_time: 1 # 7 | feats_per_node: 3 8 | 9 | 10 | use_cuda: True 11 | use_logfile: True 12 | 13 | model: egcn_o 14 | 15 | task: link_pred 16 | 17 | class_weights: [ 0.15, 0.85] 18 | use_2_hot_node_feats: False 19 | use_1_hot_node_feats: True 20 | save_node_embeddings: False 21 | 22 | train_proportion: 0.7 23 | dev_proportion: 0.1 24 | 25 | num_epochs: 100 26 | steps_accum_gradients: 1 27 | learning_rate: 0.005 28 | learning_rate_min: 0.0001 29 | learning_rate_max: 0.1 30 | negative_mult_training: 50 31 | negative_mult_test: 100 32 | smart_neg_sampling: True 33 | seed: 1234 34 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP 35 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 50 37 | 38 | 39 | eval_after_epochs: 5 40 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 41 | num_hist_steps: 5 # number of previous steps used for prediction 42 | num_hist_steps_min: 2 # only used if num_hist_steps: None 43 | num_hist_steps_max: 10 # only used if num_hist_steps: None 44 | 45 | data_loading_params: 46 | batch_size: 1 47 | num_workers: 8 48 | 49 | gcn_parameters: 50 | feats_per_node: 100 51 | feats_per_node_min: 50 52 | feats_per_node_max: 256 53 | layer_1_feats: 51 54 | layer_1_feats_min: 10 55 | layer_1_feats_max: 200 56 | layer_2_feats: 100 57 | layer_2_feats_same_as_l1: True 58 | k_top_grcu: 200 59 | num_layers: 2 60 | lstm_l1_layers: 1 61 | lstm_l1_feats: 157 # only used with sp_lstm_B_trainer 62 | lstm_l1_feats_min: 10 63 | lstm_l1_feats_max: 200 64 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 66 | lstm_l2_feats_same_as_l1: True 67 | cls_feats: 565 # Hidden size of the classifier 68 | cls_feats_min: 100 69 | cls_feats_max: 812 70 | comments: 71 | - comments 72 | -------------------------------------------------------------------------------- /experiments/parameters_uc_irv_mess_egcn_h.yaml: -------------------------------------------------------------------------------- 1 | data: uc_irv_mess 2 | 3 | uc_irc_args: 4 | folder: ./data 5 | tar_file: opsahl-ucsocial.tar.bz2 6 | edges_file: opsahl-ucsocial/out.opsahl-ucsocial 7 | aggr_time: 190080 #216000 #172800, 86400 smaller numbers yields days with no edges 8 | 9 | 10 | use_cuda: True 11 | use_logfile: False 12 | 13 | model: egcn_h 14 | task: link_pred 15 | 16 | class_weights: [ 0.1, 0.9] 17 | use_2_hot_node_feats: False 18 | use_1_hot_node_feats: True 19 | save_node_embeddings: False 20 | 21 | train_proportion: 0.71 22 | dev_proportion: 0.1 23 | 24 | num_epochs: 1000 #number of passes though the data 25 | steps_accum_gradients: 1 26 | learning_rate: 0.001 27 | learning_rate_min: 0.0005 28 | learning_rate_max: 0.1 29 | negative_mult_training: 50 30 | negative_mult_test: 100 31 | smart_neg_sampling: True 32 | seed: 1234 33 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP 34 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 35 | early_stop_patience: 50 36 | 37 | eval_after_epochs: 5 38 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 39 | num_hist_steps: 10 # number of previous steps used for prediction 40 | num_hist_steps_min: 3 # only used if num_hist_steps: None 41 | num_hist_steps_max: 10 # only used if num_hist_steps: None 42 | 43 | data_loading_params: 44 | batch_size: 1 45 | num_workers: 0 46 | 47 | gcn_parameters: 48 | feats_per_node: 100 49 | feats_per_node_min: 100 50 | feats_per_node_max: 256 51 | layer_1_feats: 100 52 | layer_1_feats_min: 20 53 | layer_1_feats_max: 200 54 | layer_2_feats: 20 55 | layer_2_feats_same_as_l1: True 56 | k_top_grcu: 200 57 | num_layers: 2 58 | lstm_l1_layers: 1 59 | lstm_l1_feats: 50 # only used with sp_lstm_B_trainer 60 | lstm_l1_feats_min: 20 61 | lstm_l1_feats_max: 200 62 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 63 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 64 | lstm_l2_feats_same_as_l1: True 65 | cls_feats: 100 # Hidden size of the classifier 66 | cls_feats_min: 100 67 | cls_feats_max: 800 68 | comments: 69 | - comments 70 | -------------------------------------------------------------------------------- /experiments/parameters_uc_irv_mess_egcn_o.yaml: -------------------------------------------------------------------------------- 1 | data: uc_irv_mess 2 | 3 | uc_irc_args: 4 | folder: ./data 5 | tar_file: opsahl-ucsocial.tar.bz2 6 | edges_file: opsahl-ucsocial/out.opsahl-ucsocial 7 | aggr_time: 190080 #216000 #172800, 86400 smaller numbers yields days with no edges 8 | 9 | 10 | use_cuda: True 11 | use_logfile: False 12 | 13 | model: egcn_o 14 | 15 | task: link_pred 16 | 17 | class_weights: [ 0.1, 0.9] 18 | use_2_hot_node_feats: False 19 | use_1_hot_node_feats: True 20 | save_node_embeddings: False 21 | 22 | train_proportion: 0.71 23 | dev_proportion: 0.1 24 | 25 | num_epochs: 1000 #number of passes though the data 26 | steps_accum_gradients: 1 27 | learning_rate: 0.001 28 | learning_rate_min: 0.0005 29 | learning_rate_max: 0.1 30 | negative_mult_training: 50 31 | negative_mult_test: 100 32 | smart_neg_sampling: True 33 | seed: 1234 34 | target_measure: MAP # measure to define the best epoch F1, Precision, Recall, MRR, MAP 35 | target_class: 1 # Target class to get the measure to define the best epoch (all, 0, 1) 36 | early_stop_patience: 50 37 | 38 | eval_after_epochs: 5 39 | adj_mat_time_window: 1 # Time window to create the adj matrix for each timestep. 
Use None to use all the history (from 0 to t) 40 | num_hist_steps: 10 # number of previous steps used for prediction 41 | num_hist_steps_min: 3 # only used if num_hist_steps: None 42 | num_hist_steps_max: 10 # only used if num_hist_steps: None 43 | 44 | data_loading_params: 45 | batch_size: 1 46 | num_workers: 0 47 | 48 | gcn_parameters: 49 | feats_per_node: 100 50 | feats_per_node_min: 100 51 | feats_per_node_max: 256 52 | layer_1_feats: 100 53 | layer_1_feats_min: 20 54 | layer_1_feats_max: 200 55 | layer_2_feats: 20 56 | layer_2_feats_same_as_l1: True 57 | k_top_grcu: 200 58 | num_layers: 2 59 | lstm_l1_layers: 1 60 | lstm_l1_feats: 100 # only used with sp_lstm_B_trainer 61 | lstm_l1_feats_min: 20 62 | lstm_l1_feats_max: 200 63 | lstm_l2_layers: 1 # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 64 | lstm_l2_feats: None # only used with both sp_lstm_A_trainer and sp_lstm_B_trainer 65 | lstm_l2_feats_same_as_l1: True 66 | cls_feats: 100 # Hidden size of the classifier 67 | cls_feats_min: 100 68 | cls_feats_max: 800 69 | comments: 70 | - comments 71 | -------------------------------------------------------------------------------- /link_pred_tasker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import taskers_utils as tu 3 | import utils as u 4 | 5 | 6 | class Link_Pred_Tasker(): 7 | ''' 8 | Creates a tasker object which computes the required inputs for training on a link prediction 9 | task. It receives a dataset object which should have two attributes: nodes_feats and edges, this 10 | makes the tasker independent of the dataset being used (as long as mentioned attributes have the same 11 | structure). 12 | 13 | Based on the dataset it implements the get_sample function required by edge_cls_trainer. 14 | This is a dictionary with: 15 | - time_step: the time_step of the prediction 16 | - hist_adj_list: the input adjacency matrices until t, each element of the list 17 | is a sparse tensor with the current edges. For link_pred they're 18 | unweighted 19 | - nodes_feats_list: the input nodes for the GCN models, each element of the list is a tensor 20 | two dimmensions: node_idx and node_feats 21 | - label_adj: a sparse representation of the target edges. A dict with two keys: idx: M by 2 22 | matrix with the indices of the nodes conforming each edge, vals: 1 if the node exists 23 | , 0 if it doesn't 24 | 25 | There's a test difference in the behavior, on test (or development), the number of sampled non existing 26 | edges should be higher. 
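	For reference, the dictionary actually returned by get_sample further down in this
	file uses the following keys (an illustrative schematic, not literal output):

	    {'idx': t,                                 # prediction time step
	     'hist_adj_list': [A_{t-h}, ..., A_t],     # normalized sparse adjacencies
	     'hist_ndFeats_list': [X_{t-h}, ..., X_t], # node features per time step
	     'label_sp': {'idx': ..., 'vals': ...},    # edges at t+1 plus sampled negatives
	     'node_mask_list': [...]}                  # mask of available nodes per step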
27 | ''' 28 | def __init__(self,args,dataset): 29 | self.data = dataset 30 | #max_time for link pred should be one before 31 | self.max_time = dataset.max_time - 1 32 | self.args = args 33 | self.num_classes = 2 34 | 35 | if not (args.use_2_hot_node_feats or args.use_1_hot_node_feats): 36 | self.feats_per_node = dataset.feats_per_node 37 | 38 | self.get_node_feats = self.build_get_node_feats(args,dataset) 39 | self.prepare_node_feats = self.build_prepare_node_feats(args,dataset) 40 | self.is_static = False 41 | 42 | '''TO CREATE THE CSV DATASET TO USE IN DynGEM 43 | print ('min max time:', self.data.min_time, self.data.max_time) 44 | file = open('data/autonomous_syst100_adj.csv','w') 45 | file.write ('source,target,weight,time\n') 46 | for time in range(self.data.min_time, self.data.max_time): 47 | adj_t = tu.get_sp_adj(edges = self.data.edges, 48 | time = time, 49 | weighted = True, 50 | time_window = 1) 51 | #node_feats = self.get_node_feats(adj_t) 52 | print (time, len(adj_t)) 53 | idx = adj_t['idx'] 54 | vals = adj_t['vals'] 55 | num_nodes = self.data.num_nodes 56 | sp_tensor = torch.sparse.FloatTensor(idx.t(),vals.type(torch.float),torch.Size([num_nodes,num_nodes])) 57 | dense_tensor = sp_tensor.to_dense() 58 | idx = sp_tensor._indices() 59 | for i in range(idx.size()[1]): 60 | i0=idx[0,i] 61 | i1=idx[1,i] 62 | w = dense_tensor[i0,i1] 63 | file.write(str(i0.item())+','+str(i1.item())+','+str(w.item())+','+str(time)+'\n') 64 | 65 | #for i, v in zip(idx, vals): 66 | # file.write(str(i[0].item())+','+str(i[1].item())+','+str(v.item())+','+str(time)+'\n') 67 | 68 | file.close() 69 | exit''' 70 | 71 | # def build_get_non_existing(args): 72 | # if args.use_smart_neg_sampling: 73 | # else: 74 | # return tu.get_non_existing_edges 75 | 76 | def build_prepare_node_feats(self,args,dataset): 77 | if args.use_2_hot_node_feats or args.use_1_hot_node_feats: 78 | def prepare_node_feats(node_feats): 79 | return u.sparse_prepare_tensor(node_feats, 80 | torch_size= [dataset.num_nodes, 81 | self.feats_per_node]) 82 | else: 83 | prepare_node_feats = self.data.prepare_node_feats 84 | 85 | return prepare_node_feats 86 | 87 | 88 | def build_get_node_feats(self,args,dataset): 89 | if args.use_2_hot_node_feats: 90 | max_deg_out, max_deg_in = tu.get_max_degs(args,dataset) 91 | self.feats_per_node = max_deg_out + max_deg_in 92 | def get_node_feats(adj): 93 | return tu.get_2_hot_deg_feats(adj, 94 | max_deg_out, 95 | max_deg_in, 96 | dataset.num_nodes) 97 | elif args.use_1_hot_node_feats: 98 | max_deg,_ = tu.get_max_degs(args,dataset) 99 | self.feats_per_node = max_deg 100 | def get_node_feats(adj): 101 | return tu.get_1_hot_deg_feats(adj, 102 | max_deg, 103 | dataset.num_nodes) 104 | else: 105 | def get_node_feats(adj): 106 | return dataset.nodes_feats 107 | 108 | return get_node_feats 109 | 110 | 111 | def get_sample(self,idx,test, **kwargs): 112 | hist_adj_list = [] 113 | hist_ndFeats_list = [] 114 | hist_mask_list = [] 115 | existing_nodes = [] 116 | for i in range(idx - self.args.num_hist_steps, idx+1): 117 | cur_adj = tu.get_sp_adj(edges = self.data.edges, 118 | time = i, 119 | weighted = True, 120 | time_window = self.args.adj_mat_time_window) 121 | 122 | if self.args.smart_neg_sampling: 123 | existing_nodes.append(cur_adj['idx'].unique()) 124 | else: 125 | existing_nodes = None 126 | 127 | node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes) 128 | 129 | node_feats = self.get_node_feats(cur_adj) 130 | 131 | cur_adj = tu.normalize_adj(adj = cur_adj, num_nodes = self.data.num_nodes) 132 | 133 | 
hist_adj_list.append(cur_adj) 134 | hist_ndFeats_list.append(node_feats) 135 | hist_mask_list.append(node_mask) 136 | 137 | # This would be if we were training on all the edges in the time_window 138 | label_adj = tu.get_sp_adj(edges = self.data.edges, 139 | time = idx+1, 140 | weighted = False, 141 | time_window = self.args.adj_mat_time_window) 142 | if test: 143 | neg_mult = self.args.negative_mult_test 144 | else: 145 | neg_mult = self.args.negative_mult_training 146 | 147 | if self.args.smart_neg_sampling: 148 | existing_nodes = torch.cat(existing_nodes) 149 | 150 | 151 | if 'all_edges' in kwargs.keys() and kwargs['all_edges'] == True: 152 | non_exisiting_adj = tu.get_all_non_existing_edges(adj = label_adj, tot_nodes = self.data.num_nodes) 153 | else: 154 | non_exisiting_adj = tu.get_non_existing_edges(adj = label_adj, 155 | number = label_adj['vals'].size(0) * neg_mult, 156 | tot_nodes = self.data.num_nodes, 157 | smart_sampling = self.args.smart_neg_sampling, 158 | existing_nodes = existing_nodes) 159 | 160 | # label_adj = tu.get_sp_adj_only_new(edges = self.data.edges, 161 | # weighted = False, 162 | # time = idx) 163 | 164 | label_adj['idx'] = torch.cat([label_adj['idx'],non_exisiting_adj['idx']]) 165 | label_adj['vals'] = torch.cat([label_adj['vals'],non_exisiting_adj['vals']]) 166 | return {'idx': idx, 167 | 'hist_adj_list': hist_adj_list, 168 | 'hist_ndFeats_list': hist_ndFeats_list, 169 | 'label_sp': label_adj, 170 | 'node_mask_list': hist_mask_list} 171 | 172 | -------------------------------------------------------------------------------- /log_analyzer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | from pylab import * 5 | import pprint 6 | 7 | ##### Parameters ###### 8 | filename = sys.argv[-1] # log filename 9 | cl_to_plot_id = 1 # Target class, typically the low frequent one 10 | if 'reddit' in filename or ('bitcoin' in filename and 'edge' in filename): 11 | cl_to_plot_id = 0 # 0 for reddit dataset or bitcoin edge cls 12 | 13 | simulate_early_stop = 0 # Early stop patience 14 | eval_k = 1000 # to compute metrics @K (for instance precision@1000) 15 | print_params = True # Print the parameters of each simulation 16 | ##### End parameters ###### 17 | 18 | if 'elliptic' in filename or 'reddit' in filename or ('bitcoin' in filename and 'edge' in filename): 19 | target_measure='f1' # map mrr f1 p r loss avg_p avg_r avg_f1 20 | else: 21 | target_measure='map' # map mrr f1 p r loss avg_p avg_r avg_f1 22 | 23 | 24 | # Hyper parameters to analyze 25 | params = [] 26 | params.append('learning_rate') 27 | params.append('num_hist_steps') 28 | params.append('layer_1_feats') 29 | params.append('lstm_l1_feats') 30 | params.append('class_weights') 31 | params.append('adj_mat_time_window') 32 | params.append('cls_feats') 33 | params.append('model') 34 | 35 | 36 | res_map={} 37 | errors = {} 38 | losses = {} 39 | MRRs = {} 40 | MAPs = {} 41 | prec = {} 42 | rec = {} 43 | f1 = {} 44 | prec_at_k = {} 45 | rec_at_k = {} 46 | f1_at_k = {} 47 | prec_cl = {} 48 | rec_cl = {} 49 | f1_cl = {} 50 | prec_at_k_cl = {} 51 | rec_at_k_cl = {} 52 | f1_at_k_cl = {} 53 | best_measure = {} 54 | best_epoch = {} 55 | 56 | last_test_ep={} 57 | last_test_ep['precision'] = '-' 58 | last_test_ep['recall'] = '-' 59 | last_test_ep['F1'] = '-' 60 | last_test_ep['AVG-precision'] = '-' 61 | last_test_ep['AVG-recall'] = '-' 62 | last_test_ep['AVG-F1'] = '-' 63 | last_test_ep['precision@'+str(eval_k)] = '-' 
64 | last_test_ep['recall@'+str(eval_k)] = '-' 65 | last_test_ep['F1@'+str(eval_k)] = '-' 66 | last_test_ep['AVG-precision@'+str(eval_k)] = '-' 67 | last_test_ep['AVG-recall@'+str(eval_k)] = '-' 68 | last_test_ep['AVG-F1@'+str(eval_k)] = '-' 69 | last_test_ep['MRR'] = '-' 70 | last_test_ep['MAP'] = '-' 71 | last_test_ep['best_epoch'] = -1 72 | 73 | sets = ['TRAIN', 'VALID', 'TEST'] 74 | 75 | for s in sets: 76 | errors[s] = {} 77 | losses[s] = {} 78 | MRRs[s] = {} 79 | MAPs[s] = {} 80 | prec[s] = {} 81 | rec[s] = {} 82 | f1[s] = {} 83 | prec_at_k[s] = {} 84 | rec_at_k[s] = {} 85 | f1_at_k[s] = {} 86 | prec_cl[s] = {} 87 | rec_cl[s] = {} 88 | f1_cl[s] = {} 89 | prec_at_k_cl[s] = {} 90 | rec_at_k_cl[s] = {} 91 | f1_at_k_cl[s] = {} 92 | 93 | best_measure[s] = 0 94 | best_epoch[s] = -1 95 | 96 | str_comments='' 97 | str_comments1='' 98 | 99 | exp_params={} 100 | 101 | print ("Start parsing: ",filename) 102 | with open(filename) as f: 103 | params_line=True 104 | readlr=False 105 | for line in f: 106 | line=line.replace('INFO:root:','').replace('\n','') 107 | if params_line: #print parameters 108 | if "'learning_rate':" in line: 109 | readlr=True 110 | if not readlr: 111 | str_comments+=line+'\n' 112 | else: 113 | str_comments1+=line+'\n' 114 | if params_line: #print parameters 115 | for p in params: 116 | str_p='\''+p+'\': ' 117 | if str_p in line: 118 | exp_params[p]=line.split(str_p)[1].split(',')[0] 119 | if line=='': 120 | params_line=False 121 | 122 | if 'TRAIN epoch' in line or 'VALID epoch' in line or 'TEST epoch' in line: 123 | set = line.split(' ')[1] 124 | epoch = int(line.split(' ')[3])+1 125 | if set=='TEST': 126 | last_test_ep['best_epoch'] = epoch 127 | if epoch==50000: 128 | break 129 | elif 'mean errors' in line: 130 | v=float(line.split('mean errors ')[1])#float(line.split('(')[1].split(')')[0]) 131 | errors[set][epoch]=v 132 | if target_measure=='errors': 133 | if vbest_measure[set]: 141 | best_measure[set]=v 142 | best_epoch[set]=epoch 143 | elif 'mean MRR' in line: 144 | v = float(line.split('mean MRR ')[1].split(' ')[0]) 145 | MRRs[set][epoch]=v 146 | if set=='TEST': 147 | last_test_ep['MRR'] = v 148 | if target_measure=='mrr': 149 | if v>best_measure[set]: 150 | best_measure[set]=v 151 | best_epoch[set]=epoch 152 | if 'mean MAP' in line: 153 | v=float(line.split('mean MAP ')[1].split(' ')[0]) 154 | MAPs[set][epoch]=v 155 | if target_measure=='map': 156 | if v>best_measure[set]: 157 | best_measure[set]=v 158 | best_epoch[set]=epoch 159 | if set=='TEST': 160 | last_test_ep['MAP'] = v 161 | elif 'measures microavg' in line: 162 | prec[set][epoch]=float(line.split('precision ')[1].split(' ')[0]) 163 | rec[set][epoch]=float(line.split('recall ')[1].split(' ')[0]) 164 | f1[set][epoch]=float(line.split('f1 ')[1].split(' ')[0]) 165 | if (target_measure=='avg_p' or target_measure=='avg_r' or target_measure=='avg_f1'): 166 | if target_measure=='avg_p': 167 | v=prec[set][epoch] 168 | elif target_measure=='avg_r': 169 | v=rec[set][epoch] 170 | else: #F1 171 | v=f1[set][epoch] 172 | if v>best_measure[set]: 173 | best_measure[set]=v 174 | best_epoch[set]=epoch 175 | if set=='TEST': 176 | last_test_ep['AVG-precision'] = prec[set][epoch] 177 | last_test_ep['AVG-recall'] = rec[set][epoch] 178 | last_test_ep['AVG-F1'] = f1[set][epoch] 179 | 180 | elif 'measures@'+str(eval_k)+' microavg' in line: 181 | prec_at_k[set][epoch]=float(line.split('precision ')[1].split(' ')[0]) 182 | rec_at_k[set][epoch]=float(line.split('recall ')[1].split(' ')[0]) 183 | f1_at_k[set][epoch]=float(line.split('f1 
')[1].split(' ')[0]) 184 | if set=='TEST': 185 | last_test_ep['AVG-precision@'+str(eval_k)] = prec_at_k[set][epoch] 186 | last_test_ep['AVG-recall@'+str(eval_k)] = rec_at_k[set][epoch] 187 | last_test_ep['AVG-F1@'+str(eval_k)] = f1_at_k[set][epoch] 188 | elif 'measures for class ' in line: 189 | cl=int(line.split('class ')[1].split(' ')[0]) 190 | if cl not in prec_cl[set]: 191 | prec_cl[set][cl] = {} 192 | rec_cl[set][cl] = {} 193 | f1_cl[set][cl] = {} 194 | prec_cl[set][cl][epoch]=float(line.split('precision ')[1].split(' ')[0]) 195 | rec_cl[set][cl][epoch]=float(line.split('recall ')[1].split(' ')[0]) 196 | f1_cl[set][cl][epoch]=float(line.split('f1 ')[1].split(' ')[0]) 197 | if (target_measure=='p' or target_measure=='r' or target_measure=='f1') and cl==cl_to_plot_id: 198 | if target_measure=='p': 199 | v=prec_cl[set][cl][epoch] 200 | elif target_measure=='r': 201 | v=rec_cl[set][cl][epoch] 202 | else: #F1 203 | v=f1_cl[set][cl][epoch] 204 | if v>best_measure[set]: 205 | best_measure[set]=v 206 | best_epoch[set]=epoch 207 | if set=='TEST': 208 | last_test_ep['precision'] = prec_cl[set][cl][epoch] 209 | last_test_ep['recall'] = rec_cl[set][cl][epoch] 210 | last_test_ep['F1'] = f1_cl[set][cl][epoch] 211 | elif 'measures@'+str(eval_k)+' for class ' in line: 212 | cl=int(line.split('class ')[1].split(' ')[0]) 213 | if cl not in prec_at_k_cl[set]: 214 | prec_at_k_cl[set][cl] = {} 215 | rec_at_k_cl[set][cl] = {} 216 | f1_at_k_cl[set][cl] = {} 217 | prec_at_k_cl[set][cl][epoch]=float(line.split('precision ')[1].split(' ')[0]) 218 | rec_at_k_cl[set][cl][epoch]=float(line.split('recall ')[1].split(' ')[0]) 219 | f1_at_k_cl[set][cl][epoch]=float(line.split('f1 ')[1].split(' ')[0]) 220 | if (target_measure=='p@k' or target_measure=='r@k' or target_measure=='f1@k') and cl==cl_to_plot_id: 221 | if target_measure=='p@k': 222 | v=prec_at_k_cl[set][cl][epoch] 223 | elif target_measure=='r@k': 224 | v=rec_at_k_cl[set][cl][epoch] 225 | else: 226 | v=f1_at_k_cl[set][cl][epoch] 227 | if v>best_measure[set]: 228 | best_measure[set]=v 229 | best_epoch[set]=epoch 230 | if set=='TEST': 231 | last_test_ep['precision@'+str(eval_k)] = prec_at_k_cl[set][cl][epoch] 232 | last_test_ep['recall@'+str(eval_k)] = rec_at_k_cl[set][cl][epoch] 233 | last_test_ep['F1@'+str(eval_k)] = f1_at_k_cl[set][cl][epoch] 234 | 235 | 236 | 237 | if best_epoch['TEST']<0 and best_epoch['VALID']<0 or last_test_ep['best_epoch']<1: 238 | print ('best_epoch<0: -> skip') 239 | exit(0) 240 | 241 | try: 242 | res_map['model'] = exp_params['model'].replace("'","") 243 | str_params=(pprint.pformat(exp_params)) 244 | if print_params: 245 | print ('str_params:\n', str_params) 246 | if best_epoch['VALID']>=0: 247 | best_ep = best_epoch['VALID'] 248 | print ('Highest %s values among all epochs: TRAIN %0.4f\tVALID %0.4f\tTEST %0.4f' % (target_measure, best_measure['TRAIN'], best_measure['VALID'], best_measure['TEST'])) 249 | else: 250 | best_ep = best_epoch['TEST'] 251 | print ('Highest %s values among all epochs:\tTRAIN F1 %0.4f\tTEST %0.4f' % (target_measure, best_measure['TRAIN'], best_measure['TEST'])) 252 | 253 | use_latest_ep = True 254 | try: 255 | print ('Values at best Valid Epoch (%d) for target class: TEST Precision %0.4f - Recall %0.4f - F1 %0.4f' % (best_ep, prec_cl['TEST'][cl_to_plot_id][best_ep],rec_cl['TEST'][cl_to_plot_id][best_ep],f1_cl['TEST'][cl_to_plot_id][best_ep])) 256 | print ('Values at best Valid Epoch (%d) micro-AVG: TEST Precision %0.4f - Recall %0.4f - F1 %0.4f' % (best_ep, 
prec['TEST'][best_ep],rec['TEST'][best_ep],f1['TEST'][best_ep])) 257 | res_map['precision'] = prec_cl['TEST'][cl_to_plot_id][best_ep] 258 | res_map['recall'] = rec_cl['TEST'][cl_to_plot_id][best_ep] 259 | res_map['F1'] = f1_cl['TEST'][cl_to_plot_id][best_ep] 260 | res_map['AVG-precision'] = prec['TEST'][best_ep] 261 | res_map['AVG-recall'] = rec['TEST'][best_ep] 262 | res_map['AVG-F1'] = f1['TEST'][best_ep] 263 | except: 264 | res_map['precision'] = last_test_ep['precision'] 265 | res_map['recall'] = last_test_ep['recall'] 266 | res_map['F1'] = last_test_ep['F1'] 267 | res_map['AVG-precision'] = last_test_ep['AVG-precision'] 268 | res_map['AVG-recall'] = last_test_ep['AVG-F1'] 269 | res_map['AVG-F1'] = last_test_ep['AVG-F1'] 270 | use_latest_ep = False 271 | print ('WARNING: last epoch not finished, use the previous one.') 272 | 273 | try: 274 | print ('Values at best Valid Epoch (%d) for target class@%d: TEST Precision %0.4f - Recall %0.4f - F1 %0.4f' % (best_ep, eval_k, prec_at_k_cl['TEST'][cl_to_plot_id][best_ep],rec_at_k_cl['TEST'][cl_to_plot_id][best_ep],f1_at_k_cl['TEST'][cl_to_plot_id][best_ep])) 275 | res_map['precision@'+str(eval_k)] = prec_at_k_cl['TEST'][cl_to_plot_id][best_ep] 276 | res_map['recall@'+str(eval_k)] = rec_at_k_cl['TEST'][cl_to_plot_id][best_ep] 277 | res_map['F1@'+str(eval_k)] = f1_at_k_cl['TEST'][cl_to_plot_id][best_ep] 278 | 279 | print ('Values at best Valid Epoch (%d) micro-AVG@%d: TEST Precision %0.4f - Recall %0.4f - F1 %0.4f' % (best_ep, eval_k, prec_at_k['TEST'][best_ep],rec_at_k['TEST'][best_ep],f1_at_k['TEST'][best_ep])) 280 | res_map['AVG-precision@'+str(eval_k)] = prec_at_k['TEST'][best_ep] 281 | res_map['AVG-recall@'+str(eval_k)] = rec_at_k['TEST'][best_ep] 282 | res_map['AVG-F1@'+str(eval_k)] = f1_at_k['TEST'][best_ep] 283 | 284 | except: 285 | res_map['precision@'+str(eval_k)] = last_test_ep['precision@'+str(eval_k)] 286 | res_map['recall@'+str(eval_k)] = last_test_ep['recall@'+str(eval_k)] 287 | res_map['F1@'+str(eval_k)] = last_test_ep['F1@'+str(eval_k)] 288 | res_map['AVG-precision@'+str(eval_k)] = last_test_ep['AVG-precision@'+str(eval_k)] 289 | res_map['AVG-recall@'+str(eval_k)] = last_test_ep['AVG-recall@'+str(eval_k)] 290 | res_map['AVG-F1@'+str(eval_k)] = last_test_ep['AVG-F1@'+str(eval_k)] 291 | 292 | try: 293 | print ('Values at best Valid Epoch (%d) MAP: TRAIN %0.8f - VALID %0.8f - TEST %0.8f' % (best_ep, MAPs['TRAIN'][best_ep], MAPs['VALID'][best_ep], MAPs['TEST'][best_ep])) 294 | res_map['MAP'] = MAPs['TEST'][best_ep] 295 | except: 296 | res_map['MAP'] = last_test_ep['MAP'] 297 | try: 298 | print ('Values at best Valid Epoch (%d) MRR: TRAIN %0.8f - VALID %0.8f - TEST %0.8f' % (best_ep, MRRs['TRAIN'][best_ep], MRRs['VALID'][best_ep], MRRs['TEST'][best_ep])) 299 | res_map['MRR'] = MRRs['TEST'][best_ep] 300 | except: 301 | res_map['MRR'] = last_test_ep['MRR'] 302 | 303 | if use_latest_ep: 304 | res_map['best_epoch'] = best_ep 305 | else: 306 | res_map['best_epoch'] = last_test_ep['best_epoch'] 307 | 308 | except: 309 | print('Some error occurred in', filename,' - Epochs read: ',epoch) 310 | exit(0) 311 | 312 | str_results = '' 313 | str_legend = '' 314 | for k, v in res_map.items(): 315 | str_results+=str(v)+',' 316 | str_legend+=str(k)+',' 317 | for k, v in exp_params.items(): 318 | str_results+=str(v)+',' 319 | str_legend+=str(k)+',' 320 | str_results+=filename.split('/')[1].split('.log')[0] 321 | str_legend+='log_file' 322 | print ('\n\nCSV-like output:') 323 | print (str_legend) 324 | print (str_results) 325 | 
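A usage note on log_analyzer.py: the script takes the path of a training log as its last command-line argument (filename = sys.argv[-1] above), reconstructs the per-epoch metrics, reports the test values at the best validation epoch for the chosen target_measure, and finally prints a CSV-like line that can be concatenated across runs for comparison. Log files are produced when use_logfile: True and, as set up in logger.py below, follow the pattern log/log_<data>_<task>_<model>_<timestamp>_r<rank>.log, so a typical invocation looks like

python log_analyzer.py log/log_sbm50_link_pred_egcn_o_<timestamp>_r0.log

where <timestamp> stands for the actual run timestamp of the log being analyzed.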
-------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pprint 3 | import sys 4 | import datetime 5 | import torch 6 | import utils 7 | import matplotlib.pyplot as plt 8 | import time 9 | from sklearn.metrics import average_precision_score 10 | from scipy.sparse import coo_matrix 11 | import numpy as np 12 | 13 | 14 | 15 | 16 | class Logger(): 17 | def __init__(self, args, num_classes, minibatch_log_interval=10): 18 | 19 | if args is not None: 20 | currdate=str(datetime.datetime.today().strftime('%Y%m%d%H%M%S')) 21 | self.log_name= 'log/log_'+args.data+'_'+args.task+'_'+args.model+'_'+currdate+'_r'+str(args.rank)+'.log' 22 | 23 | if args.use_logfile: 24 | print ("Log file:", self.log_name) 25 | logging.basicConfig(filename=self.log_name, level=logging.INFO) 26 | else: 27 | print ("Log: STDOUT") 28 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 29 | 30 | logging.info ('*** PARAMETERS ***') 31 | logging.info (pprint.pformat(args.__dict__)) # displays the string 32 | logging.info ('') 33 | else: 34 | print ("Log: STDOUT") 35 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 36 | 37 | self.num_classes = num_classes 38 | self.minibatch_log_interval = minibatch_log_interval 39 | self.eval_k_list = [10, 100, 1000] 40 | self.args = args 41 | 42 | 43 | def get_log_file_name(self): 44 | return self.log_name 45 | 46 | def log_epoch_start(self, epoch, num_minibatches, set, minibatch_log_interval=None): 47 | #ALDO 48 | self.epoch = epoch 49 | ###### 50 | self.set = set 51 | self.losses = [] 52 | self.errors = [] 53 | self.MRRs = [] 54 | self.MAPs = [] 55 | #self.time_step_sizes = [] 56 | self.conf_mat_tp = {} 57 | self.conf_mat_fn = {} 58 | self.conf_mat_fp = {} 59 | self.conf_mat_tp_at_k = {} 60 | self.conf_mat_fn_at_k = {} 61 | self.conf_mat_fp_at_k = {} 62 | for k in self.eval_k_list: 63 | self.conf_mat_tp_at_k[k] = {} 64 | self.conf_mat_fn_at_k[k] = {} 65 | self.conf_mat_fp_at_k[k] = {} 66 | 67 | for cl in range(self.num_classes): 68 | self.conf_mat_tp[cl]=0 69 | self.conf_mat_fn[cl]=0 70 | self.conf_mat_fp[cl]=0 71 | for k in self.eval_k_list: 72 | self.conf_mat_tp_at_k[k][cl]=0 73 | self.conf_mat_fn_at_k[k][cl]=0 74 | self.conf_mat_fp_at_k[k][cl]=0 75 | 76 | if self.set == "TEST": 77 | self.conf_mat_tp_list = {} 78 | self.conf_mat_fn_list = {} 79 | self.conf_mat_fp_list = {} 80 | for cl in range(self.num_classes): 81 | self.conf_mat_tp_list[cl]=[] 82 | self.conf_mat_fn_list[cl]=[] 83 | self.conf_mat_fp_list[cl]=[] 84 | 85 | self.batch_sizes=[] 86 | self.minibatch_done = 0 87 | self.num_minibatches = num_minibatches 88 | if minibatch_log_interval is not None: 89 | self.minibatch_log_interval = minibatch_log_interval 90 | logging.info('################ '+set+' epoch '+str(epoch)+' ###################') 91 | self.lasttime = time.monotonic() 92 | self.ep_time = self.lasttime 93 | 94 | def log_minibatch(self, predictions, true_classes, loss, **kwargs): 95 | 96 | probs = torch.softmax(predictions,dim=1)[:,1] 97 | if self.set in ['TEST', 'VALID'] and self.args.task == 'link_pred': 98 | MRR = self.get_MRR(probs,true_classes, kwargs['adj'],do_softmax=False) 99 | else: 100 | MRR = torch.tensor([0.0]) 101 | 102 | MAP = torch.tensor(self.get_MAP(probs,true_classes, do_softmax=False)) 103 | 104 | error, conf_mat_per_class = self.eval_predicitions(predictions, true_classes, self.num_classes) 105 | conf_mat_per_class_at_k={} 106 | for k in 
self.eval_k_list: 107 | conf_mat_per_class_at_k[k] = self.eval_predicitions_at_k(predictions, true_classes, self.num_classes, k) 108 | 109 | batch_size = predictions.size(0) 110 | self.batch_sizes.append(batch_size) 111 | 112 | self.losses.append(loss) #loss.detach() 113 | self.errors.append(error) 114 | self.MRRs.append(MRR) 115 | self.MAPs.append(MAP) 116 | for cl in range(self.num_classes): 117 | self.conf_mat_tp[cl]+=conf_mat_per_class.true_positives[cl] 118 | self.conf_mat_fn[cl]+=conf_mat_per_class.false_negatives[cl] 119 | self.conf_mat_fp[cl]+=conf_mat_per_class.false_positives[cl] 120 | for k in self.eval_k_list: 121 | self.conf_mat_tp_at_k[k][cl]+=conf_mat_per_class_at_k[k].true_positives[cl] 122 | self.conf_mat_fn_at_k[k][cl]+=conf_mat_per_class_at_k[k].false_negatives[cl] 123 | self.conf_mat_fp_at_k[k][cl]+=conf_mat_per_class_at_k[k].false_positives[cl] 124 | if self.set == "TEST": 125 | self.conf_mat_tp_list[cl].append(conf_mat_per_class.true_positives[cl]) 126 | self.conf_mat_fn_list[cl].append(conf_mat_per_class.false_negatives[cl]) 127 | self.conf_mat_fp_list[cl].append(conf_mat_per_class.false_positives[cl]) 128 | 129 | self.minibatch_done+=1 130 | if self.minibatch_done%self.minibatch_log_interval==0: 131 | mb_error = self.calc_epoch_metric(self.batch_sizes, self.errors) 132 | mb_MRR = self.calc_epoch_metric(self.batch_sizes, self.MRRs) 133 | mb_MAP = self.calc_epoch_metric(self.batch_sizes, self.MAPs) 134 | partial_losses = torch.stack(self.losses) 135 | logging.info(self.set+ ' batch %d / %d - partial error %0.4f - partial loss %0.4f - partial MRR %0.4f - partial MAP %0.4f' % (self.minibatch_done, self.num_minibatches, mb_error, partial_losses.mean(), mb_MRR, mb_MAP)) 136 | 137 | tp=conf_mat_per_class.true_positives 138 | fn=conf_mat_per_class.false_negatives 139 | fp=conf_mat_per_class.false_positives 140 | logging.info(self.set+' batch %d / %d - partial tp %s,fn %s,fp %s' % (self.minibatch_done, self.num_minibatches, tp, fn, fp)) 141 | precision, recall, f1 = self.calc_microavg_eval_measures(tp, fn, fp) 142 | logging.info (self.set+' batch %d / %d - measures partial microavg - precision %0.4f - recall %0.4f - f1 %0.4f ' % (self.minibatch_done, self.num_minibatches, precision,recall,f1)) 143 | for cl in range(self.num_classes): 144 | cl_precision, cl_recall, cl_f1 = self.calc_eval_measures_per_class(tp, fn, fp, cl) 145 | logging.info (self.set+' batch %d / %d - measures partial for class %d - precision %0.4f - recall %0.4f - f1 %0.4f ' % (self.minibatch_done, self.num_minibatches, cl,cl_precision,cl_recall,cl_f1)) 146 | 147 | logging.info (self.set+' batch %d / %d - Batch time %d ' % (self.minibatch_done, self.num_minibatches, (time.monotonic()-self.lasttime) )) 148 | 149 | self.lasttime=time.monotonic() 150 | 151 | def log_epoch_done(self): 152 | eval_measure = 0 153 | 154 | self.losses = torch.stack(self.losses) 155 | logging.info(self.set+' mean losses '+ str(self.losses.mean())) 156 | if self.args.target_measure=='loss' or self.args.target_measure=='Loss': 157 | eval_measure = self.losses.mean() 158 | 159 | epoch_error = self.calc_epoch_metric(self.batch_sizes, self.errors) 160 | logging.info(self.set+' mean errors '+ str(epoch_error)) 161 | 162 | epoch_MRR = self.calc_epoch_metric(self.batch_sizes, self.MRRs) 163 | epoch_MAP = self.calc_epoch_metric(self.batch_sizes, self.MAPs) 164 | logging.info(self.set+' mean MRR '+ str(epoch_MRR)+' - mean MAP '+ str(epoch_MAP)) 165 | if self.args.target_measure=='MRR' or self.args.target_measure=='mrr': 166 | eval_measure = 
epoch_MRR 167 | if self.args.target_measure=='MAP' or self.args.target_measure=='map': 168 | eval_measure = epoch_MAP 169 | 170 | logging.info(self.set+' tp %s,fn %s,fp %s' % (self.conf_mat_tp, self.conf_mat_fn, self.conf_mat_fp)) 171 | precision, recall, f1 = self.calc_microavg_eval_measures(self.conf_mat_tp, self.conf_mat_fn, self.conf_mat_fp) 172 | logging.info (self.set+' measures microavg - precision %0.4f - recall %0.4f - f1 %0.4f ' % (precision,recall,f1)) 173 | if str(self.args.target_class) == 'AVG': 174 | if self.args.target_measure=='Precision' or self.args.target_measure=='prec': 175 | eval_measure = precision 176 | elif self.args.target_measure=='Recall' or self.args.target_measure=='rec': 177 | eval_measure = recall 178 | else: 179 | eval_measure = f1 180 | 181 | 182 | for cl in range(self.num_classes): 183 | cl_precision, cl_recall, cl_f1 = self.calc_eval_measures_per_class(self.conf_mat_tp, self.conf_mat_fn, self.conf_mat_fp, cl) 184 | logging.info (self.set+' measures for class %d - precision %0.4f - recall %0.4f - f1 %0.4f ' % (cl,cl_precision,cl_recall,cl_f1)) 185 | if str(cl) == str(self.args.target_class): 186 | if self.args.target_measure=='Precision' or self.args.target_measure=='prec': 187 | eval_measure = cl_precision 188 | elif self.args.target_measure=='Recall' or self.args.target_measure=='rec': 189 | eval_measure = cl_recall 190 | else: 191 | eval_measure = cl_f1 192 | 193 | for k in self.eval_k_list: #logging.info(self.set+' @%d tp %s,fn %s,fp %s' % (k, self.conf_mat_tp_at_k[k], self.conf_mat_fn_at_k[k], self.conf_mat_fp_at_k[k])) 194 | precision, recall, f1 = self.calc_microavg_eval_measures(self.conf_mat_tp_at_k[k], self.conf_mat_fn_at_k[k], self.conf_mat_fp_at_k[k]) 195 | logging.info (self.set+' measures@%d microavg - precision %0.4f - recall %0.4f - f1 %0.4f ' % (k,precision,recall,f1)) 196 | 197 | for cl in range(self.num_classes): 198 | cl_precision, cl_recall, cl_f1 = self.calc_eval_measures_per_class(self.conf_mat_tp_at_k[k], self.conf_mat_fn_at_k[k], self.conf_mat_fp_at_k[k], cl) 199 | logging.info (self.set+' measures@%d for class %d - precision %0.4f - recall %0.4f - f1 %0.4f ' % (k, cl,cl_precision,cl_recall,cl_f1)) 200 | 201 | 202 | logging.info (self.set+' Total epoch time: '+ str(((time.monotonic()-self.ep_time)))) 203 | 204 | return eval_measure 205 | 206 | def get_MRR(self,predictions,true_classes, adj ,do_softmax=False): 207 | if do_softmax: 208 | probs = torch.softmax(predictions,dim=1)[:,1] 209 | else: 210 | probs = predictions 211 | 212 | probs = probs.cpu().numpy() 213 | true_classes = true_classes.cpu().numpy() 214 | adj = adj.cpu().numpy() 215 | 216 | pred_matrix = coo_matrix((probs,(adj[0],adj[1]))).toarray() 217 | true_matrix = coo_matrix((true_classes,(adj[0],adj[1]))).toarray() 218 | 219 | row_MRRs = [] 220 | for i,pred_row in enumerate(pred_matrix): 221 | #check if there are any existing edges 222 | if np.isin(1,true_matrix[i]): 223 | row_MRRs.append(self.get_row_MRR(pred_row,true_matrix[i])) 224 | 225 | avg_MRR = torch.tensor(row_MRRs).mean() 226 | return avg_MRR 227 | 228 | def get_row_MRR(self,probs,true_classes): 229 | existing_mask = true_classes == 1 230 | #descending in probability 231 | ordered_indices = np.flip(probs.argsort()) 232 | 233 | ordered_existing_mask = existing_mask[ordered_indices] 234 | 235 | existing_ranks = np.arange(1, 236 | true_classes.shape[0]+1, 237 | dtype=np.float)[ordered_existing_mask] 238 | 239 | MRR = (1/existing_ranks).sum()/existing_ranks.shape[0] 240 | return MRR 241 | 242 | 243 | def 
get_MAP(self,predictions,true_classes, do_softmax=False): 244 | if do_softmax: 245 | probs = torch.softmax(predictions,dim=1)[:,1] 246 | else: 247 | probs = predictions 248 | 249 | predictions_np = probs.detach().cpu().numpy() 250 | true_classes_np = true_classes.detach().cpu().numpy() 251 | 252 | return average_precision_score(true_classes_np, predictions_np) 253 | 254 | def eval_predicitions(self, predictions, true_classes, num_classes): 255 | predicted_classes = predictions.argmax(dim=1) 256 | failures = (predicted_classes!=true_classes).sum(dtype=torch.float) 257 | error = failures/predictions.size(0) 258 | 259 | conf_mat_per_class = utils.Namespace({}) 260 | conf_mat_per_class.true_positives = {} 261 | conf_mat_per_class.false_negatives = {} 262 | conf_mat_per_class.false_positives = {} 263 | 264 | for cl in range(num_classes): 265 | cl_indices = true_classes == cl 266 | 267 | pos = predicted_classes == cl 268 | hits = (predicted_classes[cl_indices] == true_classes[cl_indices]) 269 | 270 | tp = hits.sum() 271 | fn = hits.size(0) - tp 272 | fp = pos.sum() - tp 273 | 274 | conf_mat_per_class.true_positives[cl] = tp 275 | conf_mat_per_class.false_negatives[cl] = fn 276 | conf_mat_per_class.false_positives[cl] = fp 277 | return error, conf_mat_per_class 278 | 279 | 280 | def eval_predicitions_at_k(self, predictions, true_classes, num_classes, k): 281 | conf_mat_per_class = utils.Namespace({}) 282 | conf_mat_per_class.true_positives = {} 283 | conf_mat_per_class.false_negatives = {} 284 | conf_mat_per_class.false_positives = {} 285 | 286 | if predictions.size(0) hits.size(0) - tp 303 | fp = pos.sum() - tp 304 | 305 | conf_mat_per_class.true_positives[cl] = tp 306 | conf_mat_per_class.false_negatives[cl] = fn 307 | conf_mat_per_class.false_positives[cl] = fp 308 | return conf_mat_per_class 309 | 310 | 311 | def calc_microavg_eval_measures(self, tp, fn, fp): 312 | tp_sum = sum(tp.values()).item() 313 | fn_sum = sum(fn.values()).item() 314 | fp_sum = sum(fp.values()).item() 315 | 316 | p = tp_sum*1.0 / (tp_sum+fp_sum) 317 | r = tp_sum*1.0 / (tp_sum+fn_sum) 318 | if (p+r)>0: 319 | f1 = 2.0 * (p*r) / (p+r) 320 | else: 321 | f1 = 0 322 | return p, r, f1 323 | 324 | def calc_eval_measures_per_class(self, tp, fn, fp, class_id): 325 | #ALDO 326 | if type(tp) is dict: 327 | tp_sum = tp[class_id].item() 328 | fn_sum = fn[class_id].item() 329 | fp_sum = fp[class_id].item() 330 | else: 331 | tp_sum = tp.item() 332 | fn_sum = fn.item() 333 | fp_sum = fp.item() 334 | ######## 335 | if tp_sum==0: 336 | return 0,0,0 337 | 338 | p = tp_sum*1.0 / (tp_sum+fp_sum) 339 | r = tp_sum*1.0 / (tp_sum+fn_sum) 340 | if (p+r)>0: 341 | f1 = 2.0 * (p*r) / (p+r) 342 | else: 343 | f1 = 0 344 | return p, r, f1 345 | 346 | def calc_epoch_metric(self,batch_sizes, metric_val): 347 | batch_sizes = torch.tensor(batch_sizes, dtype = torch.float) 348 | epoch_metric_val = torch.stack(metric_val).cpu() * batch_sizes 349 | epoch_metric_val = epoch_metric_val.sum()/batch_sizes.sum() 350 | 351 | return epoch_metric_val.detach().item() 352 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils as u 3 | from argparse import Namespace 4 | from torch.nn.parameter import Parameter 5 | from torch.nn import functional as F 6 | import torch.nn as nn 7 | import math 8 | 9 | class Sp_GCN(torch.nn.Module): 10 | def __init__(self,args,activation): 11 | super().__init__() 12 | self.activation = 
activation 13 | self.num_layers = args.num_layers 14 | 15 | self.w_list = nn.ParameterList() 16 | for i in range(self.num_layers): 17 | if i==0: 18 | w_i = Parameter(torch.Tensor(args.feats_per_node, args.layer_1_feats)) 19 | u.reset_param(w_i) 20 | else: 21 | w_i = Parameter(torch.Tensor(args.layer_1_feats, args.layer_2_feats)) 22 | u.reset_param(w_i) 23 | self.w_list.append(w_i) 24 | 25 | 26 | def forward(self,A_list, Nodes_list, nodes_mask_list): 27 | node_feats = Nodes_list[-1] 28 | #A_list: T, each element sparse tensor 29 | #take only last adj matrix in time 30 | Ahat = A_list[-1] 31 | #Ahat: NxN ~ 30k 32 | #sparse multiplication 33 | 34 | # Ahat NxN 35 | # self.node_embs = Nxk 36 | # 37 | # note(bwheatman, tfk): change order of matrix multiply 38 | last_l = self.activation(Ahat.matmul(node_feats.matmul(self.w_list[0]))) 39 | for i in range(1, self.num_layers): 40 | last_l = self.activation(Ahat.matmul(last_l.matmul(self.w_list[i]))) 41 | return last_l 42 | 43 | 44 | class Sp_Skip_GCN(Sp_GCN): 45 | def __init__(self,args,activation): 46 | super().__init__(args,activation) 47 | self.W_feat = Parameter(torch.Tensor(args.feats_per_node, args.layer_1_feats)) 48 | 49 | def forward(self,A_list, Nodes_list = None): 50 | node_feats = Nodes_list[-1] 51 | #A_list: T, each element sparse tensor 52 | #take only last adj matrix in time 53 | Ahat = A_list[-1] 54 | #Ahat: NxN ~ 30k 55 | #sparse multiplication 56 | 57 | # Ahat NxN 58 | # self.node_feats = Nxk 59 | # 60 | # note(bwheatman, tfk): change order of matrix multiply 61 | l1 = self.activation(Ahat.matmul(node_feats.matmul(self.W1))) 62 | l2 = self.activation(Ahat.matmul(l1.matmul(self.W2)) + (node_feats.matmul(self.W3))) 63 | 64 | return l2 65 | 66 | class Sp_Skip_NodeFeats_GCN(Sp_GCN): 67 | def __init__(self,args,activation): 68 | super().__init__(args,activation) 69 | 70 | def forward(self,A_list, Nodes_list = None): 71 | node_feats = Nodes_list[-1] 72 | Ahat = A_list[-1] 73 | last_l = self.activation(Ahat.matmul(node_feats.matmul(self.w_list[0]))) 74 | for i in range(1, self.num_layers): 75 | last_l = self.activation(Ahat.matmul(last_l.matmul(self.w_list[i]))) 76 | skip_last_l = torch.cat((last_l,node_feats), dim=1) # use node_feats.to_dense() if 2hot encoded input 77 | return skip_last_l 78 | 79 | class Sp_GCN_LSTM_A(Sp_GCN): 80 | def __init__(self,args,activation): 81 | super().__init__(args,activation) 82 | self.rnn = nn.LSTM( 83 | input_size=args.layer_2_feats, 84 | hidden_size=args.lstm_l2_feats, 85 | num_layers=args.lstm_l2_layers 86 | ) 87 | 88 | def forward(self,A_list, Nodes_list = None, nodes_mask_list = None): 89 | last_l_seq=[] 90 | for t,Ahat in enumerate(A_list): 91 | node_feats = Nodes_list[t] 92 | #A_list: T, each element sparse tensor 93 | #note(bwheatman, tfk): change order of matrix multiply 94 | last_l = self.activation(Ahat.matmul(node_feats.matmul(self.w_list[0]))) 95 | for i in range(1, self.num_layers): 96 | last_l = self.activation(Ahat.matmul(last_l.matmul(self.w_list[i]))) 97 | last_l_seq.append(last_l) 98 | 99 | last_l_seq = torch.stack(last_l_seq) 100 | 101 | out, _ = self.rnn(last_l_seq, None) 102 | return out[-1] 103 | 104 | 105 | class Sp_GCN_GRU_A(Sp_GCN_LSTM_A): 106 | def __init__(self,args,activation): 107 | super().__init__(args,activation) 108 | self.rnn = nn.GRU( 109 | input_size=args.layer_2_feats, 110 | hidden_size=args.lstm_l2_feats, 111 | num_layers=args.lstm_l2_layers 112 | ) 113 | 114 | class Sp_GCN_LSTM_B(Sp_GCN): 115 | def __init__(self,args,activation): 116 | 
super().__init__(args,activation) 117 | assert args.num_layers == 2, 'GCN-LSTM and GCN-GRU requires 2 conv layers.' 118 | self.rnn_l1 = nn.LSTM( 119 | input_size=args.layer_1_feats, 120 | hidden_size=args.lstm_l1_feats, 121 | num_layers=args.lstm_l1_layers 122 | ) 123 | 124 | self.rnn_l2 = nn.LSTM( 125 | input_size=args.layer_2_feats, 126 | hidden_size=args.lstm_l2_feats, 127 | num_layers=args.lstm_l2_layers 128 | ) 129 | self.W2 = Parameter(torch.Tensor(args.lstm_l1_feats, args.layer_2_feats)) 130 | u.reset_param(self.W2) 131 | 132 | def forward(self,A_list, Nodes_list = None, nodes_mask_list = None): 133 | l1_seq=[] 134 | l2_seq=[] 135 | for t,Ahat in enumerate(A_list): 136 | node_feats = Nodes_list[t] 137 | l1 = self.activation(Ahat.matmul(node_feats.matmul(self.w_list[0]))) 138 | l1_seq.append(l1) 139 | 140 | l1_seq = torch.stack(l1_seq) 141 | 142 | out_l1, _ = self.rnn_l1(l1_seq, None) 143 | 144 | for i in range(len(A_list)): 145 | Ahat = A_list[i] 146 | out_t_l1 = out_l1[i] 147 | #A_list: T, each element sparse tensor 148 | l2 = self.activation(Ahat.matmul(out_t_l1).matmul(self.w_list[1])) 149 | l2_seq.append(l2) 150 | 151 | l2_seq = torch.stack(l2_seq) 152 | 153 | out, _ = self.rnn_l2(l2_seq, None) 154 | return out[-1] 155 | 156 | 157 | class Sp_GCN_GRU_B(Sp_GCN_LSTM_B): 158 | def __init__(self,args,activation): 159 | super().__init__(args,activation) 160 | self.rnn_l1 = nn.GRU( 161 | input_size=args.layer_1_feats, 162 | hidden_size=args.lstm_l1_feats, 163 | num_layers=args.lstm_l1_layers 164 | ) 165 | 166 | self.rnn_l2 = nn.GRU( 167 | input_size=args.layer_2_feats, 168 | hidden_size=args.lstm_l2_feats, 169 | num_layers=args.lstm_l2_layers 170 | ) 171 | 172 | class Classifier(torch.nn.Module): 173 | def __init__(self,args,out_features=2, in_features = None): 174 | super(Classifier,self).__init__() 175 | activation = torch.nn.ReLU() 176 | 177 | if in_features is not None: 178 | num_feats = in_features 179 | elif args.experiment_type in ['sp_lstm_A_trainer', 'sp_lstm_B_trainer', 180 | 'sp_weighted_lstm_A', 'sp_weighted_lstm_B'] : 181 | num_feats = args.gcn_parameters['lstm_l2_feats'] * 2 182 | else: 183 | num_feats = args.gcn_parameters['layer_2_feats'] * 2 184 | print ('CLS num_feats',num_feats) 185 | 186 | self.mlp = torch.nn.Sequential(torch.nn.Linear(in_features = num_feats, 187 | out_features =args.gcn_parameters['cls_feats']), 188 | activation, 189 | torch.nn.Linear(in_features = args.gcn_parameters['cls_feats'], 190 | out_features = out_features)) 191 | 192 | def forward(self,x): 193 | return self.mlp(x) 194 | -------------------------------------------------------------------------------- /node_cls_tasker.py: -------------------------------------------------------------------------------- 1 | import taskers_utils as tu 2 | import torch 3 | import utils as u 4 | 5 | class Node_Cls_Tasker(): 6 | def __init__(self,args,dataset): 7 | self.data = dataset 8 | 9 | self.max_time = dataset.max_time 10 | 11 | self.args = args 12 | 13 | self.num_classes = 2 14 | 15 | self.feats_per_node = dataset.feats_per_node 16 | 17 | self.nodes_labels_times = dataset.nodes_labels_times 18 | 19 | self.get_node_feats = self.build_get_node_feats(args,dataset) 20 | 21 | self.prepare_node_feats = self.build_prepare_node_feats(args,dataset) 22 | 23 | self.is_static = False 24 | 25 | 26 | def build_get_node_feats(self,args,dataset): 27 | if args.use_2_hot_node_feats: 28 | max_deg_out, max_deg_in = tu.get_max_degs(args,dataset,all_window = True) 29 | self.feats_per_node = max_deg_out + max_deg_in 30 | def 
get_node_feats(i,adj): 31 | return tu.get_2_hot_deg_feats(adj, 32 | max_deg_out, 33 | max_deg_in, 34 | dataset.num_nodes) 35 | elif args.use_1_hot_node_feats: 36 | max_deg,_ = tu.get_max_degs(args,dataset) 37 | self.feats_per_node = max_deg 38 | def get_node_feats(i,adj): 39 | return tu.get_1_hot_deg_feats(adj, 40 | max_deg, 41 | dataset.num_nodes) 42 | else: 43 | def get_node_feats(i,adj): 44 | return dataset.nodes_feats#[i] I'm ignoring the index since the features for Elliptic are static 45 | 46 | return get_node_feats 47 | 48 | def build_prepare_node_feats(self,args,dataset): 49 | if args.use_2_hot_node_feats or args.use_1_hot_node_feats: 50 | def prepare_node_feats(node_feats): 51 | return u.sparse_prepare_tensor(node_feats, 52 | torch_size= [dataset.num_nodes, 53 | self.feats_per_node]) 54 | # elif args.use_1_hot_node_feats: 55 | 56 | else: 57 | def prepare_node_feats(node_feats): 58 | return node_feats[0] #I'll have to check this up 59 | 60 | return prepare_node_feats 61 | 62 | def get_sample(self,idx,test): 63 | hist_adj_list = [] 64 | hist_ndFeats_list = [] 65 | hist_mask_list = [] 66 | 67 | for i in range(idx - self.args.num_hist_steps, idx+1): 68 | #all edgess included from the beginning 69 | cur_adj = tu.get_sp_adj(edges = self.data.edges, 70 | time = i, 71 | weighted = True, 72 | time_window = self.args.adj_mat_time_window) #changed this to keep only a time window 73 | 74 | node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes) 75 | 76 | node_feats = self.get_node_feats(i,cur_adj) 77 | 78 | cur_adj = tu.normalize_adj(adj = cur_adj, num_nodes = self.data.num_nodes) 79 | 80 | hist_adj_list.append(cur_adj) 81 | hist_ndFeats_list.append(node_feats) 82 | hist_mask_list.append(node_mask) 83 | 84 | label_adj = self.get_node_labels(idx) 85 | 86 | return {'idx': idx, 87 | 'hist_adj_list': hist_adj_list, 88 | 'hist_ndFeats_list': hist_ndFeats_list, 89 | 'label_sp': label_adj, 90 | 'node_mask_list': hist_mask_list} 91 | 92 | 93 | def get_node_labels(self,idx): 94 | # window_nodes = tu.get_sp_adj(edges = self.data.edges, 95 | # time = idx, 96 | # weighted = False, 97 | # time_window = self.args.adj_mat_time_window) 98 | 99 | # window_nodes = window_nodes['idx'].unique() 100 | 101 | # fraud_times = self.data.nodes_labels_times[window_nodes] 102 | 103 | # non_fraudulent = ((fraud_times > idx) + (fraud_times == -1))>0 104 | # non_fraudulent = window_nodes[non_fraudulent] 105 | 106 | # fraudulent = (fraud_times <= idx) * (fraud_times > max(idx - self.args.adj_mat_time_window,0)) 107 | # fraudulent = window_nodes[fraudulent] 108 | 109 | # label_idx = torch.cat([non_fraudulent,fraudulent]).view(-1,1) 110 | # label_vals = torch.cat([torch.zeros(non_fraudulent.size(0)), 111 | # torch.ones(fraudulent.size(0))]) 112 | node_labels = self.nodes_labels_times 113 | subset = node_labels[:,2]==idx 114 | label_idx = node_labels[subset,0] 115 | label_vals = node_labels[subset,1] 116 | 117 | return {'idx': label_idx, 118 | 'vals': label_vals} 119 | 120 | 121 | 122 | 123 | class Static_Node_Cls_Tasker(Node_Cls_Tasker): 124 | def __init__(self,args,dataset): 125 | self.data = dataset 126 | 127 | self.args = args 128 | 129 | self.num_classes = 2 130 | 131 | 132 | 133 | self.adj_matrix = tu.get_static_sp_adj(edges = self.data.edges, weighted = False) 134 | 135 | if args.use_2_hot_node_feats: 136 | max_deg_out, max_deg_in = tu.get_max_degs_static(self.data.num_nodes,self.adj_matrix) 137 | self.feats_per_node = max_deg_out + max_deg_in 138 | #print ('feats_per_node',self.feats_per_node ,max_deg_out, 
max_deg_in) 139 | self.nodes_feats = tu.get_2_hot_deg_feats(self.adj_matrix , 140 | max_deg_out, 141 | max_deg_in, 142 | dataset.num_nodes) 143 | 144 | #print('XXXX self.nodes_feats',self.nodes_feats) 145 | self.nodes_feats = u.sparse_prepare_tensor(self.nodes_feats, torch_size= [self.data.num_nodes,self.feats_per_node], ignore_batch_dim = False) 146 | 147 | else: 148 | self.feats_per_node = dataset.feats_per_node 149 | self.nodes_feats = self.data.node_feats 150 | 151 | self.adj_matrix = tu.normalize_adj(adj = self.adj_matrix, num_nodes = self.data.num_nodes) 152 | self.is_static = True 153 | 154 | def get_sample(self,idx,test): 155 | #print ('self.adj_matrix',self.adj_matrix.size()) 156 | idx=int(idx) 157 | #node_feats = self.data.node_feats_dict[idx] 158 | label = self.data.nodes_labels[idx] 159 | 160 | 161 | return {'idx': idx, 162 | #'node_feats': self.data.node_feats, 163 | #'adj': self.adj_matrix, 164 | 'label': label 165 | } 166 | 167 | 168 | 169 | if __name__ == '__main__': 170 | fraud_times = torch.tensor([10,5,3,6,7,-1,-1]) 171 | idx = 6 172 | non_fraudulent = ((fraud_times > idx) + (fraud_times == -1))>0 173 | print(non_fraudulent) 174 | exit() 175 | -------------------------------------------------------------------------------- /reddit_dl.py: -------------------------------------------------------------------------------- 1 | import utils as u 2 | import os 3 | from datetime import datetime 4 | import torch 5 | 6 | class Reddit_Dataset(): 7 | def __init__(self,args): 8 | args.reddit_args = u.Namespace(args.reddit_args) 9 | folder = args.reddit_args.folder 10 | 11 | #load nodes 12 | cols = u.Namespace({'id': 0, 13 | 'feats': 1}) 14 | file = args.reddit_args.nodes_file 15 | file = os.path.join(folder,file) 16 | with open(file) as file: 17 | file = file.read().splitlines() 18 | 19 | ids_str_to_int = {} 20 | id_counter = 0 21 | 22 | feats = [] 23 | 24 | for line in file: 25 | line = line.split(',') 26 | #node id 27 | nd_id = line[0] 28 | if nd_id not in ids_str_to_int.keys(): 29 | ids_str_to_int[nd_id] = id_counter 30 | id_counter += 1 31 | nd_feats = [float(r) for r in line[1:]] 32 | feats.append(nd_feats) 33 | else: 34 | print('duplicate id', nd_id) 35 | raise Exception('duplicate_id') 36 | 37 | feats = torch.tensor(feats,dtype=torch.float) 38 | num_nodes = feats.size(0) 39 | 40 | edges = [] 41 | not_found = 0 42 | 43 | #load edges in title 44 | edges_tmp, not_found_tmp = self.load_edges_from_file(args.reddit_args.title_edges_file, 45 | folder, 46 | ids_str_to_int) 47 | edges.extend(edges_tmp) 48 | not_found += not_found_tmp 49 | 50 | #load edges in bodies 51 | 52 | edges_tmp, not_found_tmp = self.load_edges_from_file(args.reddit_args.body_edges_file, 53 | folder, 54 | ids_str_to_int) 55 | edges.extend(edges_tmp) 56 | not_found += not_found_tmp 57 | 58 | #min time should be 0 and time aggregation 59 | edges = torch.LongTensor(edges) 60 | edges[:,2] = u.aggregate_by_time(edges[:,2],args.reddit_args.aggr_time) 61 | max_time = edges[:,2].max() 62 | 63 | #separate classes 64 | sp_indices = edges[:,:3].t() 65 | sp_values = edges[:,3] 66 | 67 | # sp_edges = torch.sparse.LongTensor(sp_indices 68 | # ,sp_values, 69 | # torch.Size([num_nodes, 70 | # num_nodes, 71 | # max_time+1])).coalesce() 72 | # vals = sp_edges._values() 73 | # print(vals[vals>0].sum() + vals[vals<0].sum()*-1) 74 | # asdf 75 | 76 | pos_mask = sp_values == 1 77 | neg_mask = sp_values == -1 78 | 79 | neg_sp_indices = sp_indices[:,neg_mask] 80 | neg_sp_values = sp_values[neg_mask] 81 | neg_sp_edges = 
torch.sparse.LongTensor(neg_sp_indices 82 | ,neg_sp_values, 83 | torch.Size([num_nodes, 84 | num_nodes, 85 | max_time+1])).coalesce() 86 | 87 | pos_sp_indices = sp_indices[:,pos_mask] 88 | pos_sp_values = sp_values[pos_mask] 89 | 90 | pos_sp_edges = torch.sparse.LongTensor(pos_sp_indices 91 | ,pos_sp_values, 92 | torch.Size([num_nodes, 93 | num_nodes, 94 | max_time+1])).coalesce() 95 | 96 | #scale positive class to separate after adding 97 | pos_sp_edges *= 1000 98 | 99 | sp_edges = (pos_sp_edges - neg_sp_edges).coalesce() 100 | 101 | #separating negs and positive edges per edge/timestamp 102 | vals = sp_edges._values() 103 | neg_vals = vals%1000 104 | pos_vals = vals//1000 105 | #vals is simply the number of edges between two nodes at the same time_step, regardless of the edge label 106 | vals = pos_vals - neg_vals 107 | 108 | #creating labels new_vals -> the label of the edges 109 | new_vals = torch.zeros(vals.size(0),dtype=torch.long) 110 | new_vals[vals>0] = 1 111 | new_vals[vals<=0] = 0 112 | vals = pos_vals + neg_vals 113 | indices_labels = torch.cat([sp_edges._indices().t(),new_vals.view(-1,1)],dim=1) 114 | 115 | self.edges = {'idx': indices_labels, 'vals': vals} 116 | self.num_classes = 2 117 | self.feats_per_node = feats.size(1) 118 | self.num_nodes = num_nodes 119 | self.nodes_feats = feats 120 | self.max_time = max_time 121 | self.min_time = 0 122 | 123 | def prepare_node_feats(self,node_feats): 124 | node_feats = node_feats[0] 125 | return node_feats 126 | 127 | 128 | def load_edges_from_file(self,edges_file,folder,ids_str_to_int): 129 | edges = [] 130 | not_found = 0 131 | 132 | file = edges_file 133 | 134 | file = os.path.join(folder,file) 135 | with open(file) as file: 136 | file = file.read().splitlines() 137 | 138 | cols = u.Namespace({'source': 0, 139 | 'target': 1, 140 | 'time': 3, 141 | 'label': 4}) 142 | 143 | base_time = datetime.strptime("19800101", '%Y%m%d') 144 | 145 | 146 | for line in file[1:]: 147 | fields = line.split('\t') 148 | sr = fields[cols.source] 149 | tg = fields[cols.target] 150 | 151 | if sr in ids_str_to_int.keys() and tg in ids_str_to_int.keys(): 152 | sr = ids_str_to_int[sr] 153 | tg = ids_str_to_int[tg] 154 | 155 | time = fields[cols.time].split(' ')[0] 156 | time = datetime.strptime(time,'%Y-%m-%d') 157 | time = (time - base_time).days 158 | 159 | label = int(fields[cols.label]) 160 | edges.append([sr,tg,time,label]) 161 | #add the other edge to make it undirected 162 | edges.append([tg,sr,time,label]) 163 | else: 164 | not_found+=1 165 | 166 | return edges, not_found 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /run_exp.py: -------------------------------------------------------------------------------- 1 | import utils as u 2 | import torch 3 | import torch.distributed as dist 4 | import numpy as np 5 | import time 6 | import random 7 | 8 | #datasets 9 | import bitcoin_dl as bc 10 | import elliptic_temporal_dl as ell_temp 11 | import uc_irv_mess_dl as ucim 12 | import auto_syst_dl as aus 13 | import sbm_dl as sbm 14 | import reddit_dl as rdt 15 | 16 | 17 | #taskers 18 | import link_pred_tasker as lpt 19 | import edge_cls_tasker as ect 20 | import node_cls_tasker as nct 21 | 22 | #models 23 | import models as mls 24 | import egcn_h 25 | import egcn_o 26 | 27 | 28 | import splitter as sp 29 | import Cross_Entropy as ce 30 | 31 | import trainer as tr 
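# Overview of this script: the __main__ block at the bottom assembles the experiment in
# this order: u.create_parser()/u.parse_args -> build_random_hyper_params ->
# build_dataset -> build_tasker -> sp.splitter -> build_gcn and build_classifier ->
# ce.Cross_Entropy loss -> tr.Trainer(...).train().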
32 | 33 | import logger 34 | 35 | 36 | def random_param_value(param, param_min, param_max, type='int'): 37 | if str(param) is None or str(param).lower()=='none': 38 | if type=='int': 39 | return random.randrange(param_min, param_max+1) 40 | elif type=='logscale': 41 | interval=np.logspace(np.log10(param_min), np.log10(param_max), num=100) 42 | return np.random.choice(interval,1)[0] 43 | else: 44 | return random.uniform(param_min, param_max) 45 | else: 46 | return param 47 | 48 | def build_random_hyper_params(args): 49 | if args.model == 'all': 50 | model_types = ['gcn', 'egcn_o', 'egcn_h', 'gruA', 'gruB','egcn','lstmA', 'lstmB'] 51 | args.model=model_types[args.rank] 52 | elif args.model == 'all_nogcn': 53 | model_types = ['egcn_o', 'egcn_h', 'gruA', 'gruB','egcn','lstmA', 'lstmB'] 54 | args.model=model_types[args.rank] 55 | elif args.model == 'all_noegcn3': 56 | model_types = ['gcn', 'egcn_h', 'gruA', 'gruB','egcn','lstmA', 'lstmB'] 57 | args.model=model_types[args.rank] 58 | elif args.model == 'all_nogruA': 59 | model_types = ['gcn', 'egcn_o', 'egcn_h', 'gruB','egcn','lstmA', 'lstmB'] 60 | args.model=model_types[args.rank] 61 | args.model=model_types[args.rank] 62 | elif args.model == 'saveembs': 63 | model_types = ['gcn', 'gcn', 'skipgcn', 'skipgcn'] 64 | args.model=model_types[args.rank] 65 | 66 | args.learning_rate =random_param_value(args.learning_rate, args.learning_rate_min, args.learning_rate_max, type='logscale') 67 | # args.adj_mat_time_window = random_param_value(args.adj_mat_time_window, args.adj_mat_time_window_min, args.adj_mat_time_window_max, type='int') 68 | 69 | if args.model == 'gcn': 70 | args.num_hist_steps = 0 71 | else: 72 | args.num_hist_steps = random_param_value(args.num_hist_steps, args.num_hist_steps_min, args.num_hist_steps_max, type='int') 73 | 74 | args.gcn_parameters['feats_per_node'] =random_param_value(args.gcn_parameters['feats_per_node'], args.gcn_parameters['feats_per_node_min'], args.gcn_parameters['feats_per_node_max'], type='int') 75 | args.gcn_parameters['layer_1_feats'] =random_param_value(args.gcn_parameters['layer_1_feats'], args.gcn_parameters['layer_1_feats_min'], args.gcn_parameters['layer_1_feats_max'], type='int') 76 | if args.gcn_parameters['layer_2_feats_same_as_l1'] or args.gcn_parameters['layer_2_feats_same_as_l1'].lower()=='true': 77 | args.gcn_parameters['layer_2_feats'] = args.gcn_parameters['layer_1_feats'] 78 | else: 79 | args.gcn_parameters['layer_2_feats'] =random_param_value(args.gcn_parameters['layer_2_feats'], args.gcn_parameters['layer_1_feats_min'], args.gcn_parameters['layer_1_feats_max'], type='int') 80 | args.gcn_parameters['lstm_l1_feats'] =random_param_value(args.gcn_parameters['lstm_l1_feats'], args.gcn_parameters['lstm_l1_feats_min'], args.gcn_parameters['lstm_l1_feats_max'], type='int') 81 | if args.gcn_parameters['lstm_l2_feats_same_as_l1'] or args.gcn_parameters['lstm_l2_feats_same_as_l1'].lower()=='true': 82 | args.gcn_parameters['lstm_l2_feats'] = args.gcn_parameters['lstm_l1_feats'] 83 | else: 84 | args.gcn_parameters['lstm_l2_feats'] =random_param_value(args.gcn_parameters['lstm_l2_feats'], args.gcn_parameters['lstm_l1_feats_min'], args.gcn_parameters['lstm_l1_feats_max'], type='int') 85 | args.gcn_parameters['cls_feats']=random_param_value(args.gcn_parameters['cls_feats'], args.gcn_parameters['cls_feats_min'], args.gcn_parameters['cls_feats_max'], type='int') 86 | return args 87 | 88 | def build_dataset(args): 89 | if args.data == 'bitcoinotc' or args.data == 'bitcoinalpha': 90 | if args.data == 'bitcoinotc': 
91 | args.bitcoin_args = args.bitcoinotc_args 92 | elif args.data == 'bitcoinalpha': 93 | args.bitcoin_args = args.bitcoinalpha_args 94 | return bc.bitcoin_dataset(args) 95 | elif args.data == 'aml_sim': 96 | return aml.Aml_Dataset(args) 97 | elif args.data == 'elliptic': 98 | return ell.Elliptic_Dataset(args) 99 | elif args.data == 'elliptic_temporal': 100 | return ell_temp.Elliptic_Temporal_Dataset(args) 101 | elif args.data == 'uc_irv_mess': 102 | return ucim.Uc_Irvine_Message_Dataset(args) 103 | elif args.data == 'dbg': 104 | return dbg.dbg_dataset(args) 105 | elif args.data == 'colored_graph': 106 | return cg.Colored_Graph(args) 107 | elif args.data == 'autonomous_syst': 108 | return aus.Autonomous_Systems_Dataset(args) 109 | elif args.data == 'reddit': 110 | return rdt.Reddit_Dataset(args) 111 | elif args.data.startswith('sbm'): 112 | if args.data == 'sbm20': 113 | args.sbm_args = args.sbm20_args 114 | elif args.data == 'sbm50': 115 | args.sbm_args = args.sbm50_args 116 | return sbm.sbm_dataset(args) 117 | else: 118 | raise NotImplementedError('only arxiv has been implemented') 119 | 120 | def build_tasker(args,dataset): 121 | if args.task == 'link_pred': 122 | return lpt.Link_Pred_Tasker(args,dataset) 123 | elif args.task == 'edge_cls': 124 | return ect.Edge_Cls_Tasker(args,dataset) 125 | elif args.task == 'node_cls': 126 | return nct.Node_Cls_Tasker(args,dataset) 127 | elif args.task == 'static_node_cls': 128 | return nct.Static_Node_Cls_Tasker(args,dataset) 129 | 130 | else: 131 | raise NotImplementedError('still need to implement the other tasks') 132 | 133 | def build_gcn(args,tasker): 134 | gcn_args = u.Namespace(args.gcn_parameters) 135 | gcn_args.feats_per_node = tasker.feats_per_node 136 | if args.model == 'gcn': 137 | return mls.Sp_GCN(gcn_args,activation = torch.nn.RReLU()).to(args.device) 138 | elif args.model == 'skipgcn': 139 | return mls.Sp_Skip_GCN(gcn_args,activation = torch.nn.RReLU()).to(args.device) 140 | elif args.model == 'skipfeatsgcn': 141 | return mls.Sp_Skip_NodeFeats_GCN(gcn_args,activation = torch.nn.RReLU()).to(args.device) 142 | else: 143 | assert args.num_hist_steps > 0, 'more than one step is necessary to train LSTM' 144 | if args.model == 'lstmA': 145 | return mls.Sp_GCN_LSTM_A(gcn_args,activation = torch.nn.RReLU()).to(args.device) 146 | elif args.model == 'gruA': 147 | return mls.Sp_GCN_GRU_A(gcn_args,activation = torch.nn.RReLU()).to(args.device) 148 | elif args.model == 'lstmB': 149 | return mls.Sp_GCN_LSTM_B(gcn_args,activation = torch.nn.RReLU()).to(args.device) 150 | elif args.model == 'gruB': 151 | return mls.Sp_GCN_GRU_B(gcn_args,activation = torch.nn.RReLU()).to(args.device) 152 | elif args.model == 'egcn': 153 | return egcn.EGCN(gcn_args, activation = torch.nn.RReLU()).to(args.device) 154 | elif args.model == 'egcn_h': 155 | return egcn_h.EGCN(gcn_args, activation = torch.nn.RReLU(), device = args.device) 156 | elif args.model == 'skipfeatsegcn_h': 157 | return egcn_h.EGCN(gcn_args, activation = torch.nn.RReLU(), device = args.device, skipfeats=True) 158 | elif args.model == 'egcn_o': 159 | return egcn_o.EGCN(gcn_args, activation = torch.nn.RReLU(), device = args.device) 160 | else: 161 | raise NotImplementedError('need to finish modifying the models') 162 | 163 | def build_classifier(args,tasker): 164 | if 'node_cls' == args.task or 'static_node_cls' == args.task: 165 | mult = 1 166 | else: 167 | mult = 2 168 | if 'gru' in args.model or 'lstm' in args.model: 169 | in_feats = args.gcn_parameters['lstm_l2_feats'] * mult 170 | elif args.model 
== 'skipfeatsgcn' or args.model == 'skipfeatsegcn_h': 171 | in_feats = (args.gcn_parameters['layer_2_feats'] + args.gcn_parameters['feats_per_node']) * mult 172 | else: 173 | in_feats = args.gcn_parameters['layer_2_feats'] * mult 174 | 175 | return mls.Classifier(args,in_features = in_feats, out_features = tasker.num_classes).to(args.device) 176 | 177 | if __name__ == '__main__': 178 | parser = u.create_parser() 179 | args = u.parse_args(parser) 180 | 181 | global rank, wsize, use_cuda 182 | args.use_cuda = (torch.cuda.is_available() and args.use_cuda) 183 | args.device='cpu' 184 | if args.use_cuda: 185 | args.device='cuda' 186 | print ("use CUDA:", args.use_cuda, "- device:", args.device) 187 | try: 188 | dist.init_process_group(backend='mpi') #, world_size=4 189 | rank = dist.get_rank() 190 | wsize = dist.get_world_size() 191 | print('Hello from process {} (out of {})'.format(dist.get_rank(), dist.get_world_size())) 192 | if args.use_cuda: 193 | torch.cuda.set_device(rank ) # are we sure of the rank+1???? 194 | print('using the device {}'.format(torch.cuda.current_device())) 195 | except: 196 | rank = 0 197 | wsize = 1 198 | print(('MPI backend not preset. Set process rank to {} (out of {})'.format(rank, 199 | wsize))) 200 | 201 | if args.seed is None and args.seed!='None': 202 | seed = 123+rank#int(time.time())+rank 203 | else: 204 | seed=args.seed#+rank 205 | np.random.seed(seed) 206 | random.seed(seed) 207 | torch.manual_seed(seed) 208 | torch.cuda.manual_seed(seed) 209 | torch.cuda.manual_seed_all(seed) 210 | args.seed=seed 211 | args.rank=rank 212 | args.wsize=wsize 213 | 214 | # Assign the requested random hyper parameters 215 | args = build_random_hyper_params(args) 216 | 217 | #build the dataset 218 | dataset = build_dataset(args) 219 | #build the tasker 220 | tasker = build_tasker(args,dataset) 221 | #build the splitter 222 | splitter = sp.splitter(args,tasker) 223 | #build the models 224 | gcn = build_gcn(args, tasker) 225 | classifier = build_classifier(args,tasker) 226 | #build a loss 227 | cross_entropy = ce.Cross_Entropy(args,dataset).to(args.device) 228 | 229 | #trainer 230 | trainer = tr.Trainer(args, 231 | splitter = splitter, 232 | gcn = gcn, 233 | classifier = classifier, 234 | comp_loss = cross_entropy, 235 | dataset = dataset, 236 | num_classes = tasker.num_classes) 237 | 238 | trainer.train() 239 | -------------------------------------------------------------------------------- /sbm_dl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils as u 3 | import os 4 | 5 | class sbm_dataset(): 6 | def __init__(self,args): 7 | assert args.task in ['link_pred'], 'sbm only implements link_pred' 8 | self.ecols = u.Namespace({'FromNodeId': 0, 9 | 'ToNodeId': 1, 10 | 'Weight': 2, 11 | 'TimeStep': 3 12 | }) 13 | args.sbm_args = u.Namespace(args.sbm_args) 14 | 15 | #build edge data structure 16 | edges = self.load_edges(args.sbm_args) 17 | timesteps = u.aggregate_by_time(edges[:,self.ecols.TimeStep], args.sbm_args.aggr_time) 18 | self.max_time = timesteps.max() 19 | self.min_time = timesteps.min() 20 | print ('TIME', self.max_time, self.min_time ) 21 | edges[:,self.ecols.TimeStep] = timesteps 22 | 23 | edges[:,self.ecols.Weight] = self.cluster_negs_and_positives(edges[:,self.ecols.Weight]) 24 | self.num_classes = edges[:,self.ecols.Weight].unique().size(0) 25 | 26 | self.edges = self.edges_to_sp_dict(edges) 27 | 28 | #random node features 29 | self.num_nodes = int(self.get_num_nodes(edges)) 30 | self.feats_per_node = 
args.sbm_args.feats_per_node 31 | self.nodes_feats = torch.rand((self.num_nodes,self.feats_per_node)) 32 | 33 | self.num_non_existing = self.num_nodes ** 2 - edges.size(0) 34 | 35 | def cluster_negs_and_positives(self,ratings): 36 | pos_indices = ratings >= 0 37 | neg_indices = ratings < 0 38 | ratings[pos_indices] = 1 39 | ratings[neg_indices] = 0 40 | return ratings 41 | 42 | def prepare_node_feats(self,node_feats): 43 | node_feats = node_feats[0] 44 | return node_feats 45 | 46 | def edges_to_sp_dict(self,edges): 47 | idx = edges[:,[self.ecols.FromNodeId, 48 | self.ecols.ToNodeId, 49 | self.ecols.TimeStep]] 50 | 51 | vals = edges[:,self.ecols.Weight] 52 | return {'idx': idx, 53 | 'vals': vals} 54 | 55 | def get_num_nodes(self,edges): 56 | all_ids = edges[:,[self.ecols.FromNodeId,self.ecols.ToNodeId]] 57 | num_nodes = all_ids.max() + 1 58 | return num_nodes 59 | 60 | def load_edges(self,sbm_args, starting_line = 1): 61 | file = os.path.join(sbm_args.folder,sbm_args.edges_file) 62 | with open(file) as f: 63 | lines = f.read().splitlines() 64 | edges = [[float(r) for r in row.split(',')] for row in lines[starting_line:]] 65 | edges = torch.tensor(edges,dtype = torch.long) 66 | return edges 67 | 68 | def make_contigous_node_ids(self,edges): 69 | new_edges = edges[:,[self.ecols.FromNodeId,self.ecols.ToNodeId]] 70 | _, new_edges = new_edges.unique(return_inverse=True) 71 | edges[:,[self.ecols.FromNodeId,self.ecols.ToNodeId]] = new_edges 72 | return edges 73 | -------------------------------------------------------------------------------- /splitter.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torch 3 | import numpy as np 4 | import utils as u 5 | 6 | class splitter(): 7 | ''' 8 | creates 3 splits 9 | train 10 | dev 11 | test 12 | ''' 13 | def __init__(self,args,tasker): 14 | 15 | 16 | if tasker.is_static: #### For static datsets 17 | assert args.train_proportion + args.dev_proportion < 1, \ 18 | 'there\'s no space for test samples' 19 | #only the training one requires special handling on start, the others are fine with the split IDX. 
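# For example (illustrative figures, not taken from any config in this repo): with
# 1000 labelled nodes, train_proportion=0.7 and dev_proportion=0.1, the sorted index
# tensor is sliced below into train=perm_idx[:700], dev=perm_idx[700:800] and
# test=perm_idx[800:].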
20 | 21 | random_perm=False 22 | indexes = tasker.data.nodes_with_label 23 | 24 | if random_perm: 25 | perm_idx = torch.randperm(indexes.size(0)) 26 | perm_idx = indexes[perm_idx] 27 | else: 28 | print ('tasker.data.nodes',indexes.size()) 29 | perm_idx, _ = indexes.sort() 30 | #print ('perm_idx',perm_idx[:10]) 31 | 32 | self.train_idx = perm_idx[:int(args.train_proportion*perm_idx.size(0))] 33 | self.dev_idx = perm_idx[int(args.train_proportion*perm_idx.size(0)): int((args.train_proportion+args.dev_proportion)*perm_idx.size(0))] 34 | self.test_idx = perm_idx[int((args.train_proportion+args.dev_proportion)*perm_idx.size(0)):] 35 | # print ('train,dev,test',self.train_idx.size(), self.dev_idx.size(), self.test_idx.size()) 36 | 37 | train = static_data_split(tasker, self.train_idx, test = False) 38 | train = DataLoader(train, shuffle=True,**args.data_loading_params) 39 | 40 | dev = static_data_split(tasker, self.dev_idx, test = True) 41 | dev = DataLoader(dev, shuffle=False,**args.data_loading_params) 42 | 43 | test = static_data_split(tasker, self.test_idx, test = True) 44 | test = DataLoader(test, shuffle=False,**args.data_loading_params) 45 | 46 | self.tasker = tasker 47 | self.train = train 48 | self.dev = dev 49 | self.test = test 50 | 51 | 52 | else: #### For datsets with time 53 | assert args.train_proportion + args.dev_proportion < 1, \ 54 | 'there\'s no space for test samples' 55 | #only the training one requires special handling on start, the others are fine with the split IDX. 56 | start = tasker.data.min_time + args.num_hist_steps #-1 + args.adj_mat_time_window 57 | end = args.train_proportion 58 | 59 | end = int(np.floor(tasker.data.max_time.type(torch.float) * end)) 60 | train = data_split(tasker, start, end, test = False) 61 | train = DataLoader(train,**args.data_loading_params) 62 | 63 | start = end 64 | end = args.dev_proportion + args.train_proportion 65 | end = int(np.floor(tasker.data.max_time.type(torch.float) * end)) 66 | if args.task == 'link_pred': 67 | dev = data_split(tasker, start, end, test = True, all_edges=True) 68 | else: 69 | dev = data_split(tasker, start, end, test = True) 70 | 71 | dev = DataLoader(dev,num_workers=args.data_loading_params['num_workers']) 72 | 73 | start = end 74 | 75 | #the +1 is because I assume that max_time exists in the dataset 76 | end = int(tasker.max_time) + 1 77 | if args.task == 'link_pred': 78 | test = data_split(tasker, start, end, test = True, all_edges=True) 79 | else: 80 | test = data_split(tasker, start, end, test = True) 81 | 82 | test = DataLoader(test,num_workers=args.data_loading_params['num_workers']) 83 | 84 | print ('Dataset splits sizes: train',len(train), 'dev',len(dev), 'test',len(test)) 85 | 86 | 87 | 88 | self.tasker = tasker 89 | self.train = train 90 | self.dev = dev 91 | self.test = test 92 | 93 | 94 | 95 | class data_split(Dataset): 96 | def __init__(self, tasker, start, end, test, **kwargs): 97 | ''' 98 | start and end are indices indicating what items belong to this split 99 | ''' 100 | self.tasker = tasker 101 | self.start = start 102 | self.end = end 103 | self.test = test 104 | self.kwargs = kwargs 105 | 106 | def __len__(self): 107 | return self.end-self.start 108 | 109 | def __getitem__(self,idx): 110 | idx = self.start + idx 111 | t = self.tasker.get_sample(idx, test = self.test, **self.kwargs) 112 | return t 113 | 114 | 115 | class static_data_split(Dataset): 116 | def __init__(self, tasker, indexes, test): 117 | ''' 118 | start and end are indices indicating what items belong to this split 119 | ''' 
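# Note: unlike data_split above, this split is driven by an explicit tensor of node
# indexes rather than a start/end range; __getitem__ maps a position into that tensor
# and forwards the node id to tasker.get_sample.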
120 | self.tasker = tasker 121 | self.indexes = indexes 122 | self.test = test 123 | self.adj_matrix = tasker.adj_matrix 124 | 125 | def __len__(self): 126 | return len(self.indexes) 127 | 128 | def __getitem__(self,idx): 129 | idx = self.indexes[idx] 130 | return self.tasker.get_sample(idx,test = self.test) 131 | -------------------------------------------------------------------------------- /taskers_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils as u 3 | import numpy as np 4 | import time 5 | 6 | ECOLS = u.Namespace({'source': 0, 7 | 'target': 1, 8 | 'time': 2, 9 | 'label':3}) #--> added for edge_cls 10 | 11 | # def get_2_hot_deg_feats(adj,max_deg_out,max_deg_in,num_nodes): 12 | # #For now it'll just return a 2-hot vector 13 | # adj['vals'] = torch.ones(adj['idx'].size(0)) 14 | # degs_out, degs_in = get_degree_vects(adj,num_nodes) 15 | 16 | # degs_out = {'idx': torch.cat([torch.arange(num_nodes).view(-1,1), 17 | # degs_out.view(-1,1)],dim=1), 18 | # 'vals': torch.ones(num_nodes)} 19 | 20 | # # print ('XXX degs_out',degs_out['idx'].size(),degs_out['vals'].size()) 21 | # degs_out = u.make_sparse_tensor(degs_out,'long',[num_nodes,max_deg_out]) 22 | 23 | # degs_in = {'idx': torch.cat([torch.arange(num_nodes).view(-1,1), 24 | # degs_in.view(-1,1)],dim=1), 25 | # 'vals': torch.ones(num_nodes)} 26 | # degs_in = u.make_sparse_tensor(degs_in,'long',[num_nodes,max_deg_in]) 27 | 28 | # hot_2 = torch.cat([degs_out,degs_in],dim = 1) 29 | # hot_2 = {'idx': hot_2._indices().t(), 30 | # 'vals': hot_2._values()} 31 | 32 | # return hot_2 33 | 34 | def get_1_hot_deg_feats(adj,max_deg,num_nodes): 35 | #For now it'll just return a 2-hot vector 36 | new_vals = torch.ones(adj['idx'].size(0)) 37 | new_adj = {'idx':adj['idx'], 'vals': new_vals} 38 | degs_out, _ = get_degree_vects(new_adj,num_nodes) 39 | 40 | degs_out = {'idx': torch.cat([torch.arange(num_nodes).view(-1,1), 41 | degs_out.view(-1,1)],dim=1), 42 | 'vals': torch.ones(num_nodes)} 43 | 44 | # print ('XXX degs_out',degs_out['idx'].size(),degs_out['vals'].size()) 45 | degs_out = u.make_sparse_tensor(degs_out,'long',[num_nodes,max_deg]) 46 | 47 | hot_1 = {'idx': degs_out._indices().t(), 48 | 'vals': degs_out._values()} 49 | return hot_1 50 | 51 | def get_max_degs(args,dataset,all_window=False): 52 | max_deg_out = [] 53 | max_deg_in = [] 54 | for t in range(dataset.min_time, dataset.max_time): 55 | if all_window: 56 | window = t+1 57 | else: 58 | window = args.adj_mat_time_window 59 | 60 | cur_adj = get_sp_adj(edges = dataset.edges, 61 | time = t, 62 | weighted = False, 63 | time_window = window) 64 | # print(window) 65 | cur_out, cur_in = get_degree_vects(cur_adj,dataset.num_nodes) 66 | max_deg_out.append(cur_out.max()) 67 | max_deg_in.append(cur_in.max()) 68 | # max_deg_out = torch.stack([max_deg_out,cur_out.max()]).max() 69 | # max_deg_in = torch.stack([max_deg_in,cur_in.max()]).max() 70 | # exit() 71 | max_deg_out = torch.stack(max_deg_out).max() 72 | max_deg_in = torch.stack(max_deg_in).max() 73 | max_deg_out = int(max_deg_out) + 1 74 | max_deg_in = int(max_deg_in) + 1 75 | 76 | return max_deg_out, max_deg_in 77 | 78 | def get_max_degs_static(num_nodes, adj_matrix): 79 | cur_out, cur_in = get_degree_vects(adj_matrix, num_nodes) 80 | max_deg_out = int(cur_out.max().item()) + 1 81 | max_deg_in = int(cur_in.max().item()) + 1 82 | 83 | return max_deg_out, max_deg_in 84 | 85 | 86 | def get_degree_vects(adj,num_nodes): 87 | adj = u.make_sparse_tensor(adj,'long',[num_nodes]) 88 | 
degs_out = adj.matmul(torch.ones(num_nodes,1,dtype = torch.long)) 89 | degs_in = adj.t().matmul(torch.ones(num_nodes,1,dtype = torch.long)) 90 | return degs_out, degs_in 91 | 92 | def get_sp_adj(edges,time,weighted,time_window): 93 | idx = edges['idx'] 94 | subset = idx[:,ECOLS.time] <= time 95 | subset = subset * (idx[:,ECOLS.time] > (time - time_window)) 96 | idx = edges['idx'][subset][:,[ECOLS.source, ECOLS.target]] 97 | vals = edges['vals'][subset] 98 | out = torch.sparse.FloatTensor(idx.t(),vals).coalesce() 99 | 100 | 101 | idx = out._indices().t() 102 | if weighted: 103 | vals = out._values() 104 | else: 105 | vals = torch.ones(idx.size(0),dtype=torch.long) 106 | 107 | return {'idx': idx, 'vals': vals} 108 | 109 | def get_edge_labels(edges,time): 110 | idx = edges['idx'] 111 | subset = idx[:,ECOLS.time] == time 112 | idx = edges['idx'][subset][:,[ECOLS.source, ECOLS.target]] 113 | vals = edges['idx'][subset][:,ECOLS.label] 114 | 115 | return {'idx': idx, 'vals': vals} 116 | 117 | 118 | def get_node_mask(cur_adj,num_nodes): 119 | mask = torch.zeros(num_nodes) - float("Inf") 120 | non_zero = cur_adj['idx'].unique() 121 | 122 | mask[non_zero] = 0 123 | 124 | return mask 125 | 126 | def get_static_sp_adj(edges,weighted): 127 | idx = edges['idx'] 128 | #subset = idx[:,ECOLS.time] <= time 129 | #subset = subset * (idx[:,ECOLS.time] > (time - time_window)) 130 | 131 | #idx = edges['idx'][subset][:,[ECOLS.source, ECOLS.target]] 132 | if weighted: 133 | vals = edges['vals'][subset] 134 | else: 135 | vals = torch.ones(idx.size(0),dtype = torch.long) 136 | 137 | return {'idx': idx, 'vals': vals} 138 | 139 | def get_sp_adj_only_new(edges,time,weighted): 140 | return get_sp_adj(edges, time, weighted, time_window=1) 141 | 142 | def normalize_adj(adj,num_nodes): 143 | ''' 144 | takes an adj matrix as a dict with idx and vals and normalize it by: 145 | - adding an identity matrix, 146 | - computing the degree vector 147 | - multiplying each element of the adj matrix (aij) by (di*dj)^-1/2 148 | ''' 149 | idx = adj['idx'] 150 | vals = adj['vals'] 151 | 152 | 153 | sp_tensor = torch.sparse.FloatTensor(idx.t(),vals.type(torch.float),torch.Size([num_nodes,num_nodes])) 154 | 155 | sparse_eye = make_sparse_eye(num_nodes) 156 | sp_tensor = sparse_eye + sp_tensor 157 | 158 | idx = sp_tensor._indices() 159 | vals = sp_tensor._values() 160 | 161 | degree = torch.sparse.sum(sp_tensor,dim=1).to_dense() 162 | di = degree[idx[0]] 163 | dj = degree[idx[1]] 164 | 165 | vals = vals * ((di * dj) ** -0.5) 166 | 167 | return {'idx': idx.t(), 'vals': vals} 168 | 169 | def make_sparse_eye(size): 170 | eye_idx = torch.arange(size) 171 | eye_idx = torch.stack([eye_idx,eye_idx],dim=1).t() 172 | vals = torch.ones(size) 173 | eye = torch.sparse.FloatTensor(eye_idx,vals,torch.Size([size,size])) 174 | return eye 175 | 176 | def get_all_non_existing_edges(adj,tot_nodes): 177 | true_ids = adj['idx'].t().numpy() 178 | true_ids = get_edges_ids(true_ids,tot_nodes) 179 | 180 | all_edges_idx = np.arange(tot_nodes) 181 | all_edges_idx = np.array(np.meshgrid(all_edges_idx, 182 | all_edges_idx)).reshape(2,-1) 183 | 184 | all_edges_ids = get_edges_ids(all_edges_idx,tot_nodes) 185 | 186 | #only edges that are not in the true_ids should keep here 187 | mask = np.logical_not(np.isin(all_edges_ids,true_ids)) 188 | 189 | non_existing_edges_idx = all_edges_idx[:,mask] 190 | edges = torch.tensor(non_existing_edges_idx).t() 191 | vals = torch.zeros(edges.size(0), dtype = torch.long) 192 | return {'idx': edges, 'vals': vals} 193 | 194 | 195 | def 
get_non_existing_edges(adj,number, tot_nodes, smart_sampling, existing_nodes=None): 196 | # print('----------') 197 | t0 = time.time() 198 | idx = adj['idx'].t().numpy() 199 | true_ids = get_edges_ids(idx,tot_nodes) 200 | 201 | true_ids = set(true_ids) 202 | 203 | #the maximum of edges would be all edges that don't exist between nodes that have edges 204 | num_edges = min(number,idx.shape[1] * (idx.shape[1]-1) - len(true_ids)) 205 | 206 | if smart_sampling: 207 | #existing_nodes = existing_nodes.numpy() 208 | def sample_edges(num_edges): 209 | # print('smart_sampling') 210 | from_id = np.random.choice(idx[0],size = num_edges,replace = True) 211 | to_id = np.random.choice(existing_nodes,size = num_edges, replace = True) 212 | #print ('smart_sampling', from_id, to_id) 213 | 214 | if num_edges>1: 215 | edges = np.stack([from_id,to_id]) 216 | else: 217 | edges = np.concatenate([from_id,to_id]) 218 | return edges 219 | else: 220 | def sample_edges(num_edges): 221 | if num_edges > 1: 222 | edges = np.random.randint(0,tot_nodes,(2,num_edges)) 223 | else: 224 | edges = np.random.randint(0,tot_nodes,(2,)) 225 | return edges 226 | 227 | edges = sample_edges(num_edges*4) 228 | 229 | edge_ids = edges[0] * tot_nodes + edges[1] 230 | 231 | out_ids = set() 232 | num_sampled = 0 233 | sampled_indices = [] 234 | for i in range(num_edges*4): 235 | eid = edge_ids[i] 236 | #ignore if any of these conditions happen 237 | if eid in out_ids or edges[0,i] == edges[1,i] or eid in true_ids: 238 | continue 239 | 240 | #add the eid and the index to a list 241 | out_ids.add(eid) 242 | sampled_indices.append(i) 243 | num_sampled += 1 244 | 245 | #if we have sampled enough edges break 246 | if num_sampled >= num_edges: 247 | break 248 | 249 | edges = edges[:,sampled_indices] 250 | edges = torch.tensor(edges).t() 251 | vals = torch.zeros(edges.size(0),dtype = torch.long) 252 | return {'idx': edges, 'vals': vals} 253 | 254 | def get_edges_ids(sp_idx, tot_nodes): 255 | # print(sp_idx) 256 | # print(tot_nodes) 257 | # print(sp_idx[0]*tot_nodes) 258 | return sp_idx[0]*tot_nodes + sp_idx[1] 259 | 260 | 261 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils as u 3 | import logger 4 | import time 5 | import pandas as pd 6 | import numpy as np 7 | 8 | class Trainer(): 9 | def __init__(self,args, splitter, gcn, classifier, comp_loss, dataset, num_classes): 10 | self.args = args 11 | self.splitter = splitter 12 | self.tasker = splitter.tasker 13 | self.gcn = gcn 14 | self.classifier = classifier 15 | self.comp_loss = comp_loss 16 | 17 | self.num_nodes = dataset.num_nodes 18 | self.data = dataset 19 | self.num_classes = num_classes 20 | 21 | self.logger = logger.Logger(args, self.num_classes) 22 | 23 | self.init_optimizers(args) 24 | 25 | if self.tasker.is_static: 26 | adj_matrix = u.sparse_prepare_tensor(self.tasker.adj_matrix, torch_size = [self.num_nodes], ignore_batch_dim = False) 27 | self.hist_adj_list = [adj_matrix] 28 | self.hist_ndFeats_list = [self.tasker.nodes_feats.float()] 29 | 30 | def init_optimizers(self,args): 31 | params = self.gcn.parameters() 32 | self.gcn_opt = torch.optim.Adam(params, lr = args.learning_rate) 33 | params = self.classifier.parameters() 34 | self.classifier_opt = torch.optim.Adam(params, lr = args.learning_rate) 35 | self.gcn_opt.zero_grad() 36 | self.classifier_opt.zero_grad() 37 | 38 | def save_checkpoint(self, state, 
filename='checkpoint.pth.tar'): 39 | torch.save(state, filename) 40 | 41 | def load_checkpoint(self, filename, model): 42 | if os.path.isfile(filename): 43 | print("=> loading checkpoint '{}'".format(filename)) 44 | checkpoint = torch.load(filename) 45 | epoch = checkpoint['epoch'] 46 | self.gcn.load_state_dict(checkpoint['gcn_dict']) 47 | self.classifier.load_state_dict(checkpoint['classifier_dict']) 48 | self.gcn_opt.load_state_dict(checkpoint['gcn_optimizer']) 49 | self.classifier_opt.load_state_dict(checkpoint['classifier_optimizer']) 50 | self.logger.log_str("=> loaded checkpoint '{}' (epoch {})".format(filename, checkpoint['epoch'])) 51 | return epoch 52 | else: 53 | self.logger.log_str("=> no checkpoint found at '{}'".format(filename)) 54 | return 0 55 | 56 | def train(self): 57 | self.tr_step = 0 58 | best_eval_valid = 0 59 | eval_valid = 0 60 | epochs_without_impr = 0 61 | 62 | for e in range(self.args.num_epochs): 63 | eval_train, nodes_embs = self.run_epoch(self.splitter.train, e, 'TRAIN', grad = True) 64 | if len(self.splitter.dev)>0 and e>self.args.eval_after_epochs: 65 | eval_valid, _ = self.run_epoch(self.splitter.dev, e, 'VALID', grad = False) 66 | if eval_valid>best_eval_valid: 67 | best_eval_valid = eval_valid 68 | epochs_without_impr = 0 69 | print ('### w'+str(self.args.rank)+') ep '+str(e)+' - Best valid measure:'+str(eval_valid)) 70 | else: 71 | epochs_without_impr+=1 72 | if epochs_without_impr>self.args.early_stop_patience: 73 | print ('### w'+str(self.args.rank)+') ep '+str(e)+' - Early stop.') 74 | break 75 | 76 | if len(self.splitter.test)>0 and eval_valid==best_eval_valid and e>self.args.eval_after_epochs: 77 | eval_test, _ = self.run_epoch(self.splitter.test, e, 'TEST', grad = False) 78 | 79 | if self.args.save_node_embeddings: #NOTE: 'log_file' is not defined in this scope; it must be set to the desired output prefix (e.g. the logger's log file path) before enabling this option 80 | self.save_node_embs_csv(nodes_embs, self.splitter.train_idx, log_file+'_train_nodeembs.csv.gz') 81 | self.save_node_embs_csv(nodes_embs, self.splitter.dev_idx, log_file+'_valid_nodeembs.csv.gz') 82 | self.save_node_embs_csv(nodes_embs, self.splitter.test_idx, log_file+'_test_nodeembs.csv.gz') 83 | 84 | 85 | def run_epoch(self, split, epoch, set_name, grad): 86 | t0 = time.time() 87 | log_interval=999 88 | if set_name=='TEST': 89 | log_interval=1 90 | self.logger.log_epoch_start(epoch, len(split), set_name, minibatch_log_interval=log_interval) 91 | 92 | torch.set_grad_enabled(grad) 93 | for s in split: 94 | if self.tasker.is_static: 95 | s = self.prepare_static_sample(s) 96 | else: 97 | s = self.prepare_sample(s) 98 | 99 | predictions, nodes_embs = self.predict(s.hist_adj_list, 100 | s.hist_ndFeats_list, 101 | s.label_sp['idx'], 102 | s.node_mask_list) 103 | 104 | loss = self.comp_loss(predictions,s.label_sp['vals']) 105 | # print(loss) 106 | if set_name in ['TEST', 'VALID'] and self.args.task == 'link_pred': 107 | self.logger.log_minibatch(predictions, s.label_sp['vals'], loss.detach(), adj = s.label_sp['idx']) 108 | else: 109 | self.logger.log_minibatch(predictions, s.label_sp['vals'], loss.detach()) 110 | if grad: 111 | self.optim_step(loss) 112 | 113 | torch.set_grad_enabled(True) 114 | eval_measure = self.logger.log_epoch_done() 115 | 116 | return eval_measure, nodes_embs 117 | 118 | def predict(self,hist_adj_list,hist_ndFeats_list,node_indices,mask_list): 119 | nodes_embs = self.gcn(hist_adj_list, 120 | hist_ndFeats_list, 121 | mask_list) 122 | 123 | predict_batch_size = 100000 124 | gather_predictions=[] 125 | for i in range(1 +(node_indices.size(1)//predict_batch_size)): 126 | cls_input = self.gather_node_embs(nodes_embs, 
node_indices[:, i*predict_batch_size:(i+1)*predict_batch_size]) 127 | predictions = self.classifier(cls_input) 128 | gather_predictions.append(predictions) 129 | gather_predictions=torch.cat(gather_predictions, dim=0) 130 | return gather_predictions, nodes_embs 131 | 132 | def gather_node_embs(self,nodes_embs,node_indices): 133 | cls_input = [] 134 | 135 | for node_set in node_indices: 136 | cls_input.append(nodes_embs[node_set]) 137 | return torch.cat(cls_input,dim = 1) 138 | 139 | def optim_step(self,loss): 140 | self.tr_step += 1 141 | loss.backward() 142 | 143 | if self.tr_step % self.args.steps_accum_gradients == 0: 144 | self.gcn_opt.step() 145 | self.classifier_opt.step() 146 | 147 | self.gcn_opt.zero_grad() 148 | self.classifier_opt.zero_grad() 149 | 150 | 151 | def prepare_sample(self,sample): 152 | sample = u.Namespace(sample) 153 | for i,adj in enumerate(sample.hist_adj_list): 154 | adj = u.sparse_prepare_tensor(adj,torch_size = [self.num_nodes]) 155 | sample.hist_adj_list[i] = adj.to(self.args.device) 156 | 157 | nodes = self.tasker.prepare_node_feats(sample.hist_ndFeats_list[i]) 158 | 159 | sample.hist_ndFeats_list[i] = nodes.to(self.args.device) 160 | node_mask = sample.node_mask_list[i] 161 | sample.node_mask_list[i] = node_mask.to(self.args.device).t() #transposed to have same dimensions as scorer 162 | 163 | label_sp = self.ignore_batch_dim(sample.label_sp) 164 | 165 | if self.args.task in ["link_pred", "edge_cls"]: 166 | label_sp['idx'] = label_sp['idx'].to(self.args.device).t() #transpose so that idx has one row of sources and one row of targets: gather_node_embs then concatenates the two embeddings of each node pair along the feature dimension 167 | else: 168 | label_sp['idx'] = label_sp['idx'].to(self.args.device) 169 | 170 | label_sp['vals'] = label_sp['vals'].type(torch.long).to(self.args.device) 171 | sample.label_sp = label_sp 172 | 173 | return sample 174 | 175 | def prepare_static_sample(self,sample): 176 | sample = u.Namespace(sample) 177 | 178 | sample.hist_adj_list = self.hist_adj_list 179 | 180 | sample.hist_ndFeats_list = self.hist_ndFeats_list 181 | 182 | label_sp = {} 183 | label_sp['idx'] = [sample.idx] 184 | label_sp['vals'] = sample.label 185 | sample.label_sp = label_sp 186 | 187 | return sample 188 | 189 | def ignore_batch_dim(self,adj): 190 | if self.args.task in ["link_pred", "edge_cls"]: 191 | adj['idx'] = adj['idx'][0] 192 | adj['vals'] = adj['vals'][0] 193 | return adj 194 | 195 | def save_node_embs_csv(self, nodes_embs, indexes, file_name): 196 | csv_node_embs = [] 197 | for node_id in indexes: 198 | orig_ID = torch.DoubleTensor([self.tasker.data.contID_to_origID[node_id]]) 199 | 200 | csv_node_embs.append(torch.cat((orig_ID,nodes_embs[node_id].double())).detach().numpy()) 201 | 202 | pd.DataFrame(np.array(csv_node_embs)).to_csv(file_name, header=None, index=None, compression='gzip') 203 | #print ('Node embs saved in',file_name) 204 | -------------------------------------------------------------------------------- /uc_irv_mess_dl.py: -------------------------------------------------------------------------------- 1 | import utils as u 2 | import os 3 | 4 | import tarfile 5 | 6 | import torch 7 | 8 | 9 | class Uc_Irvine_Message_Dataset(): 10 | def __init__(self,args): 11 | args.uc_irc_args = u.Namespace(args.uc_irc_args) 12 | 13 | tar_file = os.path.join(args.uc_irc_args.folder, args.uc_irc_args.tar_file) 14 | tar_archive = tarfile.open(tar_file, 'r:bz2') 15 | 16 | self.edges = self.load_edges(args,tar_archive) 17 | 18 | def 
load_edges(self,args,tar_archive): 19 | data = u.load_data_from_tar(args.uc_irc_args.edges_file, 20 | tar_archive, 21 | starting_line=2, 22 | sep=' ') 23 | cols = u.Namespace({'source': 0, 24 | 'target': 1, 25 | 'weight': 2, 26 | 'time': 3}) 27 | 28 | data = data.long() 29 | 30 | self.num_nodes = int(data[:,[cols.source,cols.target]].max()) 31 | 32 | #first id should be 0 (ids are already contiguous) 33 | data[:,[cols.source,cols.target]] -= 1 34 | 35 | #add edges in the other direction (symmetric graph) 36 | data = torch.cat([data, 37 | data[:,[cols.target, 38 | cols.source, 39 | cols.weight, 40 | cols.time]]], 41 | dim=0) 42 | 43 | data[:,cols.time] = u.aggregate_by_time(data[:,cols.time], 44 | args.uc_irc_args.aggr_time) 45 | 46 | ids = data[:,cols.source] * self.num_nodes + data[:,cols.target] 47 | self.num_non_existing = float(self.num_nodes**2 - ids.unique().size(0)) 48 | 49 | idx = data[:,[cols.source, 50 | cols.target, 51 | cols.time]] 52 | 53 | self.max_time = data[:,cols.time].max() 54 | self.min_time = data[:,cols.time].min() 55 | 56 | 57 | return {'idx': idx, 'vals': torch.ones(idx.size(0))} -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import torch 4 | import numpy as np 5 | import time 6 | import random 7 | import math 8 | 9 | def pad_with_last_col(matrix,cols): 10 | out = [matrix] 11 | pad = [matrix[:,[-1]]] * (cols - matrix.size(1)) 12 | out.extend(pad) 13 | return torch.cat(out,dim=1) 14 | 15 | def pad_with_last_val(vect,k): 16 | device = 'cuda' if vect.is_cuda else 'cpu' 17 | pad = torch.ones(k - vect.size(0), 18 | dtype=torch.long, 19 | device = device) * vect[-1] 20 | vect = torch.cat([vect,pad]) 21 | return vect 22 | 23 | 24 | 25 | def sparse_prepare_tensor(tensor,torch_size, ignore_batch_dim = True): 26 | if ignore_batch_dim: 27 | tensor = sp_ignore_batch_dim(tensor) 28 | tensor = make_sparse_tensor(tensor, 29 | tensor_type = 'float', 30 | torch_size = torch_size) 31 | return tensor 32 | 33 | def sp_ignore_batch_dim(tensor_dict): 34 | tensor_dict['idx'] = tensor_dict['idx'][0] 35 | tensor_dict['vals'] = tensor_dict['vals'][0] 36 | return tensor_dict 37 | 38 | def aggregate_by_time(time_vector,time_win_aggr): 39 | time_vector = time_vector - time_vector.min() 40 | time_vector = time_vector // time_win_aggr 41 | return time_vector 42 | 43 | def sort_by_time(data,time_col): 44 | _, sort = torch.sort(data[:,time_col]) 45 | data = data[sort] 46 | return data 47 | 48 | def print_sp_tensor(sp_tensor,size): 49 | print(torch.sparse.FloatTensor(sp_tensor['idx'].t(),sp_tensor['vals'],torch.Size([size,size])).to_dense()) 50 | 51 | def reset_param(t): 52 | stdv = 2. 
/ math.sqrt(t.size(0)) 53 | t.data.uniform_(-stdv,stdv) 54 | 55 | def make_sparse_tensor(adj,tensor_type,torch_size): 56 | if len(torch_size) == 2: 57 | tensor_size = torch.Size(torch_size) 58 | elif len(torch_size) == 1: 59 | tensor_size = torch.Size(torch_size*2) 60 | 61 | if tensor_type == 'float': 62 | test = torch.sparse.FloatTensor(adj['idx'].t(), 63 | adj['vals'].type(torch.float), 64 | tensor_size) 65 | return torch.sparse.FloatTensor(adj['idx'].t(), 66 | adj['vals'].type(torch.float), 67 | tensor_size) 68 | elif tensor_type == 'long': 69 | return torch.sparse.LongTensor(adj['idx'].t(), 70 | adj['vals'].type(torch.long), 71 | tensor_size) 72 | else: 73 | raise NotImplementedError('only make floats or long sparse tensors') 74 | 75 | def sp_to_dict(sp_tensor): 76 | return {'idx': sp_tensor._indices().t(), 77 | 'vals': sp_tensor._values()} 78 | 79 | class Namespace(object): 80 | ''' 81 | helps referencing objects in a dictionary as dict.key instead of dict['key'] 82 | ''' 83 | def __init__(self, adict): 84 | self.__dict__.update(adict) 85 | 86 | def set_seeds(rank): 87 | seed = int(time.time())+rank 88 | np.random.seed(seed) 89 | random.seed(seed) 90 | torch.manual_seed(seed) 91 | torch.cuda.manual_seed(seed) 92 | torch.cuda.manual_seed_all(seed) 93 | 94 | 95 | def random_param_value(param, param_min, param_max, type='int'): 96 | if param is None or str(param).lower()=='none': 97 | if type=='int': 98 | return random.randrange(param_min, param_max+1) 99 | elif type=='logscale': 100 | interval=np.logspace(np.log10(param_min), np.log10(param_max), num=100) 101 | return np.random.choice(interval,1)[0] 102 | else: 103 | return random.uniform(param_min, param_max) 104 | else: 105 | return param 106 | 107 | def load_data(file): 108 | with open(file) as file: 109 | file = file.read().splitlines() 110 | data = torch.tensor([[float(r) for r in row.split(',')] for row in file[1:]]) 111 | return data 112 | 113 | def load_data_from_tar(file, tar_archive, replace_unknow=False, starting_line=1, sep=',', type_fn = float, tensor_const = torch.DoubleTensor): 114 | f = tar_archive.extractfile(file) 115 | lines = f.read() 116 | lines=lines.decode('utf-8') 117 | if replace_unknow: 118 | lines=lines.replace('unknow', '-1') 119 | lines=lines.replace('-1n', '-1') 120 | 121 | lines=lines.splitlines() 122 | 123 | data = [[type_fn(r) for r in row.split(sep)] for row in lines[starting_line:]] 124 | data = tensor_const(data) 125 | #print (file,'data size', data.size()) 126 | return data 127 | 128 | def create_parser(): 129 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 130 | parser.add_argument('--config_file',default='experiments/parameters_example.yaml', type=argparse.FileType(mode='r'), help='optional, yaml file containing parameters to be used, overrides command line parameters') 131 | return parser 132 | 133 | def parse_args(parser): 134 | args = parser.parse_args() 135 | if args.config_file: 136 | data = yaml.load(args.config_file, Loader=yaml.SafeLoader) #an explicit Loader is required by recent PyYAML versions 137 | delattr(args, 'config_file') 138 | # print(data) 139 | arg_dict = args.__dict__ 140 | for key, value in data.items(): 141 | arg_dict[key] = value 142 | 143 | args.learning_rate = random_param_value(args.learning_rate, args.learning_rate_min, args.learning_rate_max, type='logscale') 144 | # args.adj_mat_time_window = random_param_value(args.adj_mat_time_window, args.adj_mat_time_window_min, args.adj_mat_time_window_max, type='int') 145 | args.num_hist_steps = random_param_value(args.num_hist_steps, args.num_hist_steps_min, 
args.num_hist_steps_max, type='int') 146 | args.gcn_parameters['feats_per_node'] = random_param_value(args.gcn_parameters['feats_per_node'], args.gcn_parameters['feats_per_node_min'], args.gcn_parameters['feats_per_node_max'], type='int') 147 | args.gcn_parameters['layer_1_feats'] = random_param_value(args.gcn_parameters['layer_1_feats'], args.gcn_parameters['layer_1_feats_min'], args.gcn_parameters['layer_1_feats_max'], type='int') 148 | if str(args.gcn_parameters['layer_2_feats_same_as_l1']).lower()=='true': #accept YAML booleans as well as 'true'/'false' strings 149 | args.gcn_parameters['layer_2_feats'] = args.gcn_parameters['layer_1_feats'] 150 | else: 151 | args.gcn_parameters['layer_2_feats'] = random_param_value(args.gcn_parameters['layer_2_feats'], args.gcn_parameters['layer_1_feats_min'], args.gcn_parameters['layer_1_feats_max'], type='int') 152 | args.gcn_parameters['lstm_l1_feats'] = random_param_value(args.gcn_parameters['lstm_l1_feats'], args.gcn_parameters['lstm_l1_feats_min'], args.gcn_parameters['lstm_l1_feats_max'], type='int') 153 | if str(args.gcn_parameters['lstm_l2_feats_same_as_l1']).lower()=='true': 154 | args.gcn_parameters['lstm_l2_feats'] = args.gcn_parameters['lstm_l1_feats'] 155 | else: 156 | args.gcn_parameters['lstm_l2_feats'] = random_param_value(args.gcn_parameters['lstm_l2_feats'], args.gcn_parameters['lstm_l1_feats_min'], args.gcn_parameters['lstm_l1_feats_max'], type='int') 157 | args.gcn_parameters['cls_feats'] = random_param_value(args.gcn_parameters['cls_feats'], args.gcn_parameters['cls_feats_min'], args.gcn_parameters['cls_feats_max'], type='int') 158 | 159 | return args 160 | --------------------------------------------------------------------------------
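As a quick illustration of how the {'idx': ..., 'vals': ...} adjacency dictionaries returned by get_sp_adj, get_static_sp_adj and the negative-edge samplers are meant to be consumed, the short sketch below runs normalize_adj from taskers_utils.py on a toy 3-node graph. The graph, the node count and the printed check are invented for this example (it is not a file shipped with the repository); it assumes the repo root is on the Python path, and it uses the legacy torch.sparse.FloatTensor constructor only to stay consistent with the rest of the code base.

import torch
from taskers_utils import normalize_adj

#toy directed graph on 3 nodes with edges 0->1 and 1->2, unit weights
adj = {'idx': torch.tensor([[0, 1],
                            [1, 2]]),   #one (source, target) pair per row
       'vals': torch.tensor([1.0, 1.0])}

norm = normalize_adj(adj, num_nodes=3)

#normalize_adj adds the identity, so the row degrees of A+I are [2, 2, 1];
#each entry a_ij is then rescaled by (di*dj)**-0.5, e.g. the (0,1) entry
#becomes 1/sqrt(2*2) = 0.5 while the (2,2) self-loop keeps the value 1.0
dense = torch.sparse.FloatTensor(norm['idx'].t(),
                                 norm['vals'],
                                 torch.Size([3, 3])).to_dense()
print(dense)

The same dictionary format is produced throughout the taskers, so the snippet doubles as a small sanity check for any adjacency built by the functions above.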