├── LICENSE ├── README.md ├── data_config.json ├── data_loading.py ├── data_util.py ├── env.yml ├── format_kaggle_files.py ├── inference.py ├── main.py ├── model_settings.json ├── models.py ├── train_util.py ├── training.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-GNN 2 | This repository contains all models and adaptations needed to run Multi-GNN for Anti-Money Laundering. The repository consists of four Graph Neural Network model classes ([GIN](https://arxiv.org/abs/1810.00826), [GAT](https://arxiv.org/abs/1710.10903), [PNA](https://arxiv.org/abs/2004.05718), [RGCN](https://arxiv.org/abs/1703.06103)) and the below-described model adaptations utilized for financial crime detection in [Egressy et al.](https://arxiv.org/abs/2306.11586). Note that this repository solely focuses on the Anti-Money Laundering use case. This repository has been created for experiments in [Provably Powerful Graph Neural Networks for Directed Multigraphs](https://arxiv.org/abs/2306.11586) [AAAI 2024] and [Realistic Synthetic Financial Transactions for Anti-Money Laundering Models](https://arxiv.org/abs/2306.16424) [NeurIPS 2023]. 
3 | 4 | ## Setup 5 | To use the repository, you first need to install the conda environment via 6 | ``` 7 | conda env create -f env.yml 8 | ``` 9 | Then, the data needed for the experiments can be found on [Kaggle](https://www.kaggle.com/datasets/ealtman2019/ibm-transactions-for-anti-money-laundering-aml/data). To use this data with the provided training scripts, you first need to perform a pre-processing step for the downloaded transaction files (e.g. `HI-Small_Trans.csv`): 10 | ``` 11 | python format_kaggle_files.py /path/to/kaggle-files/HI-Small_Trans.csv 12 | ``` 13 | Make sure to change the file paths in the `data_config.json` file. The `aml_data` path should point to the directory that contains one sub-folder per dataset (e.g. `Small_HI`), each holding the `formatted_transactions.csv` file generated by the pre-processing step. 14 | 15 | ## Usage 16 | To run the experiments, you need to run the `main.py` script and specify any arguments you want to use. There are two required arguments, namely `--data` and `--model`. For the `--data` argument, make sure you store the different datasets in different folders. Then, specify the folder name, e.g. `--data Small_HI`. The `--model` parameter should be set to any of the available model classes, i.e. to one of `--model [gin, gat, rgcn, pna]`. Thus, to run a standard GNN, you need to run, e.g.: 17 | ``` 18 | python main.py --data Small_HI --model gin 19 | ``` 20 | Then you can add different adaptations to the models by selecting the respective arguments from: 21 | 22 |
23 | 24 | | Argument | Description | 25 | | -------------- | ---------------------------- | 26 | | `--emlps` | Edge updates via MLPs | 27 | | `--reverse_mp` | Reverse Message Passing | 28 | | `--ego` | Ego IDs for the center nodes | 29 | | `--ports` | Port numberings for edges | 30 | 31 |
32 | Thus, to run Multi-GIN with edge updates (i.e. with all of the above adaptations enabled), you would run the following command: 33 | 34 | ``` 35 | python main.py --data Small_HI --model gin --emlps --reverse_mp --ego --ports 36 | ``` 37 | 38 | ## Additional functionalities 39 | There are several arguments that can be set for additional functionality. They are listed below: 40 | 41 |
42 | 43 | | Argument | Description | 44 | | -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------| 45 | | `--tqdm` | Displays a progress bar during training and inference. | 46 | | `--save_model` | Saves the best model to the specified `model_to_save` path in the `data_config.json` file. Requires the argument `--unique_name` to be specified. | 47 | | `--finetune` | Loads a previously trained model (with name given by `--unique_name` and stored in `model_to_load` path in the `data_config.json`) to be finetuned. | 48 | | `--inference` | Loads a previously trained model (with name given by `--unique_name` and stored in `model_to_load` path in the `data_config.json`) to do inference only. | 49 | 50 |
51 | 52 | ## Licence 53 | Apache License 54 | Version 2.0, January 2004 -------------------------------------------------------------------------------- /data_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "paths": { 3 | "aml_data": "/path/to/aml_data", 4 | "model_to_load": "/path/to/model_you_want_to_load (e.g for inference or fine-tuning)", 5 | "model_to_save": "/path/to/model_save_location (where you want to store the model you are going to train)" 6 | } 7 | } -------------------------------------------------------------------------------- /data_loading.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | import logging 5 | import itertools 6 | from data_util import GraphData, HeteroData, z_norm, create_hetero_obj 7 | 8 | def get_data(args, data_config): 9 | '''Loads the AML transaction data. 10 | 11 | 1. The data is loaded from the csv and the necessary features are chosen. 12 | 2. The data is split into training, validation and test data. 13 | 3. PyG Data objects are created with the respective data splits. 
14 | ''' 15 | 16 | transaction_file = f"{data_config['paths']['aml_data']}/{args.data}/formatted_transactions.csv" #replace this with your path to the respective AML data objects 17 | df_edges = pd.read_csv(transaction_file) 18 | 19 | logging.info(f'Available Edge Features: {df_edges.columns.tolist()}') 20 | 21 | df_edges['Timestamp'] = df_edges['Timestamp'] - df_edges['Timestamp'].min() 22 | 23 | max_n_id = df_edges.loc[:, ['from_id', 'to_id']].to_numpy().max() + 1 24 | df_nodes = pd.DataFrame({'NodeID': np.arange(max_n_id), 'Feature': np.ones(max_n_id)}) 25 | timestamps = torch.Tensor(df_edges['Timestamp'].to_numpy()) 26 | y = torch.LongTensor(df_edges['Is Laundering'].to_numpy()) 27 | 28 | logging.info(f"Illicit ratio = {sum(y)} / {len(y)} = {sum(y) / len(y) * 100:.2f}%") 29 | logging.info(f"Number of nodes (holdings doing transcations) = {df_nodes.shape[0]}") 30 | logging.info(f"Number of transactions = {df_edges.shape[0]}") 31 | 32 | edge_features = ['Timestamp', 'Amount Received', 'Received Currency', 'Payment Format'] 33 | node_features = ['Feature'] 34 | 35 | logging.info(f'Edge features being used: {edge_features}') 36 | logging.info(f'Node features being used: {node_features} ("Feature" is a placeholder feature of all 1s)') 37 | 38 | x = torch.tensor(df_nodes.loc[:, node_features].to_numpy()).float() 39 | edge_index = torch.LongTensor(df_edges.loc[:, ['from_id', 'to_id']].to_numpy().T) 40 | edge_attr = torch.tensor(df_edges.loc[:, edge_features].to_numpy()).float() 41 | 42 | n_days = int(timestamps.max() / (3600 * 24) + 1) 43 | n_samples = y.shape[0] 44 | logging.info(f'number of days and transactions in the data: {n_days} days, {n_samples} transactions') 45 | 46 | #data splitting 47 | daily_irs, weighted_daily_irs, daily_inds, daily_trans = [], [], [], [] #irs = illicit ratios, inds = indices, trans = transactions 48 | for day in range(n_days): 49 | l = day * 24 * 3600 50 | r = (day + 1) * 24 * 3600 51 | day_inds = torch.where((timestamps >= l) & 
(timestamps < r))[0] 52 | daily_irs.append(y[day_inds].float().mean()) 53 | weighted_daily_irs.append(y[day_inds].float().mean() * day_inds.shape[0] / n_samples) 54 | daily_inds.append(day_inds) 55 | daily_trans.append(day_inds.shape[0]) 56 | 57 | split_per = [0.6, 0.2, 0.2] 58 | daily_totals = np.array(daily_trans) 59 | d_ts = daily_totals 60 | I = list(range(len(d_ts))) 61 | split_scores = dict() 62 | for i,j in itertools.combinations(I, 2): 63 | if j >= i: 64 | split_totals = [d_ts[:i].sum(), d_ts[i:j].sum(), d_ts[j:].sum()] 65 | split_totals_sum = np.sum(split_totals) 66 | split_props = [v/split_totals_sum for v in split_totals] 67 | split_error = [abs(v-t)/t for v,t in zip(split_props, split_per)] 68 | score = max(split_error) #- (split_totals_sum/total) + 1 69 | split_scores[(i,j)] = score 70 | else: 71 | continue 72 | 73 | i,j = min(split_scores, key=split_scores.get) 74 | #split contains a list for each split (train, validation and test) and each list contains the days that are part of the respective split 75 | split = [list(range(i)), list(range(i, j)), list(range(j, len(daily_totals)))] 76 | logging.info(f'Calculate split: {split}') 77 | 78 | #Now, we seperate the transactions based on their indices in the timestamp array 79 | split_inds = {k: [] for k in range(3)} 80 | for i in range(3): 81 | for day in split[i]: 82 | split_inds[i].append(daily_inds[day]) #split_inds contains a list for each split (tr,val,te) which contains the indices of each day seperately 83 | 84 | tr_inds = torch.cat(split_inds[0]) 85 | val_inds = torch.cat(split_inds[1]) 86 | te_inds = torch.cat(split_inds[2]) 87 | 88 | logging.info(f"Total train samples: {tr_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: " 89 | f"{y[tr_inds].float().mean() * 100 :.2f}% || Train days: {split[0][:5]}") 90 | logging.info(f"Total val samples: {val_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: " 91 | f"{y[val_inds].float().mean() * 100:.2f}% || Val days: {split[1][:5]}") 92 | logging.info(f"Total test 
samples: {te_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: " 93 | f"{y[te_inds].float().mean() * 100:.2f}% || Test days: {split[2][:5]}") 94 | 95 | #Creating the final data objects 96 | tr_x, val_x, te_x = x, x, x 97 | e_tr = tr_inds.numpy() 98 | e_val = np.concatenate([tr_inds, val_inds]) 99 | 100 | tr_edge_index, tr_edge_attr, tr_y, tr_edge_times = edge_index[:,e_tr], edge_attr[e_tr], y[e_tr], timestamps[e_tr] 101 | val_edge_index, val_edge_attr, val_y, val_edge_times = edge_index[:,e_val], edge_attr[e_val], y[e_val], timestamps[e_val] 102 | te_edge_index, te_edge_attr, te_y, te_edge_times = edge_index, edge_attr, y, timestamps 103 | 104 | tr_data = GraphData (x=tr_x, y=tr_y, edge_index=tr_edge_index, edge_attr=tr_edge_attr, timestamps=tr_edge_times ) 105 | val_data = GraphData(x=val_x, y=val_y, edge_index=val_edge_index, edge_attr=val_edge_attr, timestamps=val_edge_times) 106 | te_data = GraphData (x=te_x, y=te_y, edge_index=te_edge_index, edge_attr=te_edge_attr, timestamps=te_edge_times ) 107 | 108 | #Adding ports and time-deltas if applicable 109 | if args.ports: 110 | logging.info(f"Start: adding ports") 111 | tr_data.add_ports() 112 | val_data.add_ports() 113 | te_data.add_ports() 114 | logging.info(f"Done: adding ports") 115 | if args.tds: 116 | logging.info(f"Start: adding time-deltas") 117 | tr_data.add_time_deltas() 118 | val_data.add_time_deltas() 119 | te_data.add_time_deltas() 120 | logging.info(f"Done: adding time-deltas") 121 | 122 | #Normalize data 123 | tr_data.x = val_data.x = te_data.x = z_norm(tr_data.x) 124 | if not args.model == 'rgcn': 125 | tr_data.edge_attr, val_data.edge_attr, te_data.edge_attr = z_norm(tr_data.edge_attr), z_norm(val_data.edge_attr), z_norm(te_data.edge_attr) 126 | else: 127 | tr_data.edge_attr[:, :-1], val_data.edge_attr[:, :-1], te_data.edge_attr[:, :-1] = z_norm(tr_data.edge_attr[:, :-1]), z_norm(val_data.edge_attr[:, :-1]), z_norm(te_data.edge_attr[:, :-1]) 128 | 129 | #Create heterogenous if reverese MP is enabled 
130 | #TODO: if I observe wierd behaviour, maybe add .detach.clone() to all torch tensors, but I don't think they're attached to any computation graph just yet 131 | if args.reverse_mp: 132 | tr_data = create_hetero_obj(tr_data.x, tr_data.y, tr_data.edge_index, tr_data.edge_attr, tr_data.timestamps, args) 133 | val_data = create_hetero_obj(val_data.x, val_data.y, val_data.edge_index, val_data.edge_attr, val_data.timestamps, args) 134 | te_data = create_hetero_obj(te_data.x, te_data.y, te_data.edge_index, te_data.edge_attr, te_data.timestamps, args) 135 | 136 | logging.info(f'train data object: {tr_data}') 137 | logging.info(f'validation data object: {val_data}') 138 | logging.info(f'test data object: {te_data}') 139 | 140 | return tr_data, val_data, te_data, tr_inds, val_inds, te_inds 141 | -------------------------------------------------------------------------------- /data_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data, HeteroData 3 | from torch_geometric.typing import OptTensor 4 | import numpy as np 5 | 6 | def to_adj_nodes_with_times(data): 7 | num_nodes = data.num_nodes 8 | timestamps = torch.zeros((data.edge_index.shape[1], 1)) if data.timestamps is None else data.timestamps.reshape((-1,1)) 9 | edges = torch.cat((data.edge_index.T, timestamps), dim=1) if not isinstance(data, HeteroData) else torch.cat((data['node', 'to', 'node'].edge_index.T, timestamps), dim=1) 10 | adj_list_out = dict([(i, []) for i in range(num_nodes)]) 11 | adj_list_in = dict([(i, []) for i in range(num_nodes)]) 12 | for u,v,t in edges: 13 | u,v,t = int(u), int(v), int(t) 14 | adj_list_out[u] += [(v, t)] 15 | adj_list_in[v] += [(u, t)] 16 | return adj_list_in, adj_list_out 17 | 18 | def to_adj_edges_with_times(data): 19 | num_nodes = data.num_nodes 20 | timestamps = torch.zeros((data.edge_index.shape[1], 1)) if data.timestamps is None else data.timestamps.reshape((-1,1)) 21 | edges = 
torch.cat((data.edge_index.T, timestamps), dim=1) 22 | # calculate adjacent edges with times per node 23 | adj_edges_out = dict([(i, []) for i in range(num_nodes)]) 24 | adj_edges_in = dict([(i, []) for i in range(num_nodes)]) 25 | for i, (u,v,t) in enumerate(edges): 26 | u,v,t = int(u), int(v), int(t) 27 | adj_edges_out[u] += [(i, v, t)] 28 | adj_edges_in[v] += [(i, u, t)] 29 | return adj_edges_in, adj_edges_out 30 | 31 | def ports(edge_index, adj_list): 32 | ports = torch.zeros(edge_index.shape[1], 1) 33 | ports_dict = {} 34 | for v, nbs in adj_list.items(): 35 | if len(nbs) < 1: continue 36 | a = np.array(nbs) 37 | a = a[a[:, -1].argsort()] 38 | _, idx = np.unique(a[:,[0]],return_index=True,axis=0) 39 | nbs_unique = a[np.sort(idx)][:,0] 40 | for i, u in enumerate(nbs_unique): 41 | ports_dict[(u,v)] = i 42 | for i, e in enumerate(edge_index.T): 43 | ports[i] = ports_dict[tuple(e.numpy())] 44 | return ports 45 | 46 | def time_deltas(data, adj_edges_list): 47 | time_deltas = torch.zeros(data.edge_index.shape[1], 1) 48 | if data.timestamps is None: 49 | return time_deltas 50 | for v, edges in adj_edges_list.items(): 51 | if len(edges) < 1: continue 52 | a = np.array(edges) 53 | a = a[a[:, -1].argsort()] 54 | a_tds = [0] + [a[i+1,-1] - a[i,-1] for i in range(a.shape[0]-1)] 55 | tds = np.hstack((a[:,0].reshape(-1,1), np.array(a_tds).reshape(-1,1))) 56 | for i,td in tds: 57 | time_deltas[i] = td 58 | return time_deltas 59 | 60 | class GraphData(Data): 61 | '''This is the homogenous graph object we use for GNN training if reverse MP is not enabled''' 62 | def __init__( 63 | self, x: OptTensor = None, edge_index: OptTensor = None, edge_attr: OptTensor = None, y: OptTensor = None, pos: OptTensor = None, 64 | readout: str = 'edge', 65 | num_nodes: int = None, 66 | timestamps: OptTensor = None, 67 | node_timestamps: OptTensor = None, 68 | **kwargs 69 | ): 70 | super().__init__(x, edge_index, edge_attr, y, pos, **kwargs) 71 | self.readout = readout 72 | self.loss_fn = 'ce' 
73 | self.num_nodes = int(self.x.shape[0]) 74 | self.node_timestamps = node_timestamps 75 | if timestamps is not None: 76 | self.timestamps = timestamps 77 | elif edge_attr is not None: 78 | self.timestamps = edge_attr[:,0].clone() 79 | else: 80 | self.timestamps = None 81 | 82 | def add_ports(self): 83 | '''Adds port numberings to the edge features''' 84 | reverse_ports = True 85 | adj_list_in, adj_list_out = to_adj_nodes_with_times(self) 86 | in_ports = ports(self.edge_index, adj_list_in) 87 | out_ports = [ports(self.edge_index.flipud(), adj_list_out)] if reverse_ports else [] 88 | self.edge_attr = torch.cat([self.edge_attr, in_ports] + out_ports, dim=1) 89 | return self 90 | 91 | def add_time_deltas(self): 92 | '''Adds time deltas (i.e. the time between subsequent transactions) to the edge features''' 93 | reverse_tds = True 94 | adj_list_in, adj_list_out = to_adj_edges_with_times(self) 95 | in_tds = time_deltas(self, adj_list_in) 96 | out_tds = [time_deltas(self, adj_list_out)] if reverse_tds else [] 97 | self.edge_attr = torch.cat([self.edge_attr, in_tds] + out_tds, dim=1) 98 | return self 99 | 100 | class HeteroGraphData(HeteroData): 101 | '''This is the heterogenous graph object we use for GNN training if reverse MP is enabled''' 102 | def __init__( 103 | self, 104 | readout: str = 'edge', 105 | **kwargs 106 | ): 107 | super().__init__(**kwargs) 108 | self.readout = readout 109 | 110 | @property 111 | def num_nodes(self): 112 | return self['node'].x.shape[0] 113 | 114 | @property 115 | def timestamps(self): 116 | return self['node', 'to', 'node'].timestamps 117 | 118 | def add_ports(self): 119 | '''Adds port numberings to the edge features''' 120 | adj_list_in, adj_list_out = to_adj_nodes_with_times(self) 121 | in_ports = ports(self['node', 'to', 'node'].edge_index, adj_list_in) 122 | out_ports = ports(self['node', 'rev_to', 'node'].edge_index, adj_list_out) 123 | self['node', 'to', 'node'].edge_attr = torch.cat([self['node', 'to', 'node'].edge_attr, 
in_ports], dim=1) 124 | self['node', 'rev_to', 'node'].edge_attr = torch.cat([self['node', 'rev_to', 'node'].edge_attr, out_ports], dim=1) 125 | return self 126 | 127 | def add_time_deltas(self): 128 | '''Adds time deltas (i.e. the time between subsequent transactions) to the edge features''' 129 | adj_list_in, adj_list_out = to_adj_edges_with_times(self) 130 | in_tds = time_deltas(self, adj_list_in) 131 | out_tds = time_deltas(self, adj_list_out) 132 | self['node', 'to', 'node'].edge_attr = torch.cat([self['node', 'to', 'node'].edge_attr, in_tds], dim=1) 133 | self['node', 'rev_to', 'node'].edge_attr = torch.cat([self['node', 'rev_to', 'node'].edge_attr, out_tds], dim=1) 134 | return self 135 | 136 | def z_norm(data): 137 | std = data.std(0).unsqueeze(0) 138 | std = torch.where(std == 0, torch.tensor(1, dtype=torch.float32).cpu(), std) 139 | return (data - data.mean(0).unsqueeze(0)) / std 140 | 141 | def create_hetero_obj(x, y, edge_index, edge_attr, timestamps, args): 142 | '''Creates a heterogenous graph object for reverse message passing''' 143 | data = HeteroGraphData() 144 | 145 | data['node'].x = x 146 | data['node', 'to', 'node'].edge_index = edge_index 147 | data['node', 'rev_to', 'node'].edge_index = edge_index.flipud() 148 | data['node', 'to', 'node'].edge_attr = edge_attr 149 | data['node', 'rev_to', 'node'].edge_attr = edge_attr 150 | if args.ports: 151 | #swap the in- and outgoing port numberings for the reverse edges 152 | data['node', 'rev_to', 'node'].edge_attr[:, [-1, -2]] = data['node', 'rev_to', 'node'].edge_attr[:, [-2, -1]] 153 | data['node', 'to', 'node'].y = y 154 | data['node', 'to', 'node'].timestamps = timestamps 155 | 156 | return data -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: multignn 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - pyg 6 | - conda-forge 7 | dependencies: 8 | - python=3.9 9 | - pip 10 
| - pytorch::pytorch 11 | - pytorch::torchvision 12 | - pytorch::pytorch-cuda=11.8 13 | - cudatoolkit=11.8 14 | - pyg::pyg 15 | - pyg::pytorch-scatter 16 | - bzip2=1.0.8 17 | - ipykernel 18 | - ipython 19 | - ipywidgets 20 | - lz4 21 | - matplotlib 22 | - munch 23 | - numpy>=1.25.2 24 | - pandas=2.0.3 25 | - scikit-learn=1.3.0 26 | - scipy 27 | - tqdm 28 | - wandb 29 | - zstandard 30 | - zstd 31 | - datatable 32 | - pyg::pytorch-sparse 33 | - tabulate=0.9.0 -------------------------------------------------------------------------------- /format_kaggle_files.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import datatable as dt 3 | from datetime import datetime 4 | from datatable import f,join,sort 5 | import sys 6 | import os 7 | 8 | n = len(sys.argv) 9 | 10 | if n == 1: 11 | print("No input path") 12 | sys.exit() 13 | 14 | inPath = sys.argv[1] 15 | outPath = os.path.dirname(inPath) + "/formatted_transactions.csv" 16 | 17 | raw = dt.fread(inPath, columns = dt.str32) 18 | 19 | currency = dict() 20 | paymentFormat = dict() 21 | bankAcc = dict() 22 | account = dict() 23 | 24 | def get_dict_val(name, collection): 25 | if name in collection: 26 | val = collection[name] 27 | else: 28 | val = len(collection) 29 | collection[name] = val 30 | return val 31 | 32 | header = "EdgeID,from_id,to_id,Timestamp,\ 33 | Amount Sent,Sent Currency,Amount Received,Received Currency,\ 34 | Payment Format,Is Laundering\n" 35 | 36 | firstTs = -1 37 | 38 | with open(outPath, 'w') as writer: 39 | writer.write(header) 40 | for i in range(raw.nrows): 41 | datetime_object = datetime.strptime(raw[i,"Timestamp"], '%Y/%m/%d %H:%M') 42 | ts = datetime_object.timestamp() 43 | day = datetime_object.day 44 | month = datetime_object.month 45 | year = datetime_object.year 46 | hour = datetime_object.hour 47 | minute = datetime_object.minute 48 | 49 | if firstTs == -1: 50 | startTime = datetime(year, month, day) 51 | firstTs = 
startTime.timestamp() - 10 52 | 53 | ts = ts - firstTs 54 | 55 | cur1 = get_dict_val(raw[i,"Receiving Currency"], currency) 56 | cur2 = get_dict_val(raw[i,"Payment Currency"], currency) 57 | 58 | fmt = get_dict_val(raw[i,"Payment Format"], paymentFormat) 59 | 60 | fromAccIdStr = raw[i,"From Bank"] + raw[i,2] 61 | fromId = get_dict_val(fromAccIdStr, account) 62 | 63 | toAccIdStr = raw[i,"To Bank"] + raw[i,4] 64 | toId = get_dict_val(toAccIdStr, account) 65 | 66 | amountReceivedOrig = float(raw[i,"Amount Received"]) 67 | amountPaidOrig = float(raw[i,"Amount Paid"]) 68 | 69 | isl = int(raw[i,"Is Laundering"]) 70 | 71 | line = '%d,%d,%d,%d,%f,%d,%f,%d,%d,%d\n' % \ 72 | (i,fromId,toId,ts,amountPaidOrig,cur2, amountReceivedOrig,cur1,fmt,isl) 73 | 74 | writer.write(line) 75 | 76 | formatted = dt.fread(outPath) 77 | formatted = formatted[:,:,sort(3)] 78 | 79 | formatted.to_csv(outPath) -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from train_util import AddEgoIds, extract_param, add_arange_ids, get_loaders, evaluate_homo, evaluate_hetero 4 | from training import get_model 5 | from torch_geometric.nn import to_hetero, summary 6 | import wandb 7 | import logging 8 | import os 9 | import sys 10 | import time 11 | 12 | script_start = time.time() 13 | 14 | def infer_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config): 15 | #set device 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | 18 | #define a model config dictionary and wandb logging at the same time 19 | wandb.init( 20 | mode="disabled" if args.testing else "online", 21 | project="your_proj_name", 22 | 23 | config={ 24 | "epochs": args.n_epochs, 25 | "batch_size": args.batch_size, 26 | "model": args.model, 27 | "data": args.data, 28 | "num_neighbors": args.num_neighs, 29 | "lr": 
extract_param("lr", args), 30 | "n_hidden": extract_param("n_hidden", args), 31 | "n_gnn_layers": extract_param("n_gnn_layers", args), 32 | "loss": "ce", 33 | "w_ce1": extract_param("w_ce1", args), 34 | "w_ce2": extract_param("w_ce2", args), 35 | "dropout": extract_param("dropout", args), 36 | "final_dropout": extract_param("final_dropout", args), 37 | "n_heads": extract_param("n_heads", args) if args.model == 'gat' else None 38 | } 39 | ) 40 | 41 | config = wandb.config 42 | 43 | #set the transform if ego ids should be used 44 | if args.ego: 45 | transform = AddEgoIds() 46 | else: 47 | transform = None 48 | 49 | #add the unique ids to later find the seed edges 50 | add_arange_ids([tr_data, val_data, te_data]) 51 | 52 | tr_loader, val_loader, te_loader = get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args) 53 | 54 | #get the model 55 | sample_batch = next(iter(tr_loader)) 56 | model = get_model(sample_batch, config, args) 57 | 58 | if args.reverse_mp: 59 | model = to_hetero(model, te_data.metadata(), aggr='mean') 60 | 61 | if not (args.avg_tps or args.finetune): 62 | command = " ".join(sys.argv) 63 | name = "" 64 | name = '-'.join(name.split('-')[3:]) 65 | args.unique_name = name 66 | 67 | logging.info("=> loading model checkpoint") 68 | checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_{args.unique_name}.tar') 69 | start_epoch = checkpoint['epoch'] 70 | model.load_state_dict(checkpoint['model_state_dict']) 71 | model.to(device) 72 | 73 | logging.info("=> loaded checkpoint (epoch {})".format(start_epoch)) 74 | 75 | if not args.reverse_mp: 76 | te_f1, te_prec, te_rec = evaluate_homo(te_loader, te_inds, model, te_data, device, args, precrec=True) 77 | else: 78 | te_f1, te_prec, te_rec = evaluate_hetero(te_loader, te_inds, model, te_data, device, args, precrec=True) 79 | 80 | wandb.finish() -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from util import create_parser, set_seed, logger_setup 4 | from data_loading import get_data 5 | from training import train_gnn 6 | from inference import infer_gnn 7 | import json 8 | 9 | def main(): 10 | parser = create_parser() 11 | args = parser.parse_args() 12 | 13 | with open('data_config.json', 'r') as config_file: 14 | data_config = json.load(config_file) 15 | 16 | # Setup logging 17 | logger_setup() 18 | 19 | #set seed 20 | set_seed(args.seed) 21 | 22 | #get data 23 | logging.info("Retrieving data") 24 | t1 = time.perf_counter() 25 | 26 | tr_data, val_data, te_data, tr_inds, val_inds, te_inds = get_data(args, data_config) 27 | 28 | t2 = time.perf_counter() 29 | logging.info(f"Retrieved data in {t2-t1:.2f}s") 30 | 31 | if not args.inference: 32 | #Training 33 | logging.info(f"Running Training") 34 | train_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config) 35 | else: 36 | #Inference 37 | logging.info(f"Running Inference") 38 | infer_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config) 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /model_settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "gin": { 3 | "params": { 4 | "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", 5 | "w_ce1": 1.0000182882773443, "w_ce2": 6.275014431494497, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304 6 | }, 7 | "bayes_opt_params": { 8 | "lr": [0.002, 0.007], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1], 9 | "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], 
"final_dropout": [0, 0.2] 10 | }, 11 | "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n" 12 | }, 13 | "pna": { 14 | "params": { 15 | "lr": 0.0006116418195373612, "n_hidden": 20, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0003967674742307, 16 | "w_ce2": 7.077633468006714, "norm_method": "z_normalize", "dropout": 0.08340440094051481, "final_dropout": 0.28812979737686323 17 | }, 18 | "bayes_opt_params": { 19 | "lr": [0.0001, 0.001], "n_hidden": [16, 64], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.00, 2.01], "loss": [0.0, 0.1], 20 | "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.1], "dropout": [0.0, 0.2], "final_dropout": [0.0, 0.4] 21 | }, 22 | "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n" 23 | }, 24 | "gat": { 25 | "params": { 26 | "lr": 0.006, "n_hidden": 64, "n_heads": 4, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1, "w_ce2": 6, 27 | "norm_method": "z_normalize", "dropout": 0.009, "final_dropout": 0.1 28 | }, 29 | "bayes_opt_params": { 30 | "lr": [0.01, 0.04], "n_hidden": [4, 24], "n_heads": [1.5, 4.5], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [3, 7], 31 | "loss": [0, 0.1], "w_ce1": [1, 1.001], "w_ce2": [1, 10], "norm_method": [0, 0.1], "dropout": [0, 0.5], "final_dropout": [0, 0.8] 32 | }, 33 | "header": "run,tb,lr,n_hidden,n_heads,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n" 34 | }, 35 | "mlp": { 36 | "params": { 37 | "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, 
"n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0000182882773443, 38 | "w_ce2": 9.23, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304 39 | }, 40 | "bayes_opt_params": { 41 | "lr": [0.006, 0.0064], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1], 42 | "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], "final_dropout": [0, 0.2] 43 | }, 44 | "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n" 45 | }, 46 | "rgcn": { 47 | "params": { 48 | "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0000182882773443, 49 | "w_ce2": 9.23, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304 50 | }, 51 | "bayes_opt_params": { 52 | "lr": [0.006, 0.0064], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1], 53 | "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], "final_dropout": [0, 0.2] 54 | }, 55 | "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n" 56 | } 57 | } -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.nn import GINEConv, BatchNorm, Linear, GATConv, PNAConv, RGCNConv 3 | import torch.nn.functional as F 4 | import torch 5 | import logging 6 | 7 | class GINe(torch.nn.Module): 8 | def __init__(self, num_features, 
num_gnn_layers, n_classes=2, 9 | n_hidden=100, edge_updates=False, residual=True, 10 | edge_dim=None, dropout=0.0, final_dropout=0.5): 11 | super().__init__() 12 | self.n_hidden = n_hidden 13 | self.num_gnn_layers = num_gnn_layers 14 | self.edge_updates = edge_updates 15 | self.final_dropout = final_dropout 16 | 17 | self.node_emb = nn.Linear(num_features, n_hidden) 18 | self.edge_emb = nn.Linear(edge_dim, n_hidden) 19 | 20 | self.convs = nn.ModuleList() 21 | self.emlps = nn.ModuleList() 22 | self.batch_norms = nn.ModuleList() 23 | for _ in range(self.num_gnn_layers): 24 | conv = GINEConv(nn.Sequential( 25 | nn.Linear(self.n_hidden, self.n_hidden), 26 | nn.ReLU(), 27 | nn.Linear(self.n_hidden, self.n_hidden) 28 | ), edge_dim=self.n_hidden) 29 | if self.edge_updates: self.emlps.append(nn.Sequential( 30 | nn.Linear(3 * self.n_hidden, self.n_hidden), 31 | nn.ReLU(), 32 | nn.Linear(self.n_hidden, self.n_hidden), 33 | )) 34 | self.convs.append(conv) 35 | self.batch_norms.append(BatchNorm(n_hidden)) 36 | 37 | self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout), 38 | Linear(25, n_classes)) 39 | 40 | def forward(self, x, edge_index, edge_attr): 41 | src, dst = edge_index 42 | 43 | x = self.node_emb(x) 44 | edge_attr = self.edge_emb(edge_attr) 45 | 46 | for i in range(self.num_gnn_layers): 47 | x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2 48 | if self.edge_updates: 49 | edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2 50 | 51 | x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu() 52 | x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1) 53 | out = x 54 | 55 | return self.mlp(out) 56 | 57 | class GATe(torch.nn.Module): 58 | def __init__(self, num_features, num_gnn_layers, n_classes=2, n_hidden=100, n_heads=4, edge_updates=False, edge_dim=None, dropout=0.0, final_dropout=0.5): 59 | 
super().__init__() 60 | # GAT specific code 61 | tmp_out = n_hidden // n_heads 62 | n_hidden = tmp_out * n_heads 63 | 64 | self.n_hidden = n_hidden 65 | self.n_heads = n_heads 66 | self.num_gnn_layers = num_gnn_layers 67 | self.edge_updates = edge_updates 68 | self.dropout = dropout 69 | self.final_dropout = final_dropout 70 | 71 | self.node_emb = nn.Linear(num_features, n_hidden) 72 | self.edge_emb = nn.Linear(edge_dim, n_hidden) 73 | 74 | self.convs = nn.ModuleList() 75 | self.emlps = nn.ModuleList() 76 | self.batch_norms = nn.ModuleList() 77 | 78 | for _ in range(self.num_gnn_layers): 79 | conv = GATConv(self.n_hidden, tmp_out, self.n_heads, concat = True, dropout = self.dropout, add_self_loops = True, edge_dim=self.n_hidden) 80 | if self.edge_updates: self.emlps.append(nn.Sequential(nn.Linear(3 * self.n_hidden, self.n_hidden),nn.ReLU(),nn.Linear(self.n_hidden, self.n_hidden),)) 81 | self.convs.append(conv) 82 | self.batch_norms.append(BatchNorm(n_hidden)) 83 | 84 | self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(25, n_classes)) 85 | 86 | def forward(self, x, edge_index, edge_attr): 87 | src, dst = edge_index 88 | 89 | x = self.node_emb(x) 90 | edge_attr = self.edge_emb(edge_attr) 91 | 92 | for i in range(self.num_gnn_layers): 93 | x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2 94 | if self.edge_updates: 95 | edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2 96 | 97 | logging.debug(f"x.shape = {x.shape}, x[edge_index.T].shape = {x[edge_index.T].shape}") 98 | x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu() 99 | logging.debug(f"x.shape = {x.shape}") 100 | x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1) 101 | logging.debug(f"x.shape = {x.shape}") 102 | out = x 103 | 104 | return self.mlp(out) 105 | 106 | class PNA(torch.nn.Module): 107 | def __init__(self, 
num_features, num_gnn_layers, n_classes=2, 108 | n_hidden=100, edge_updates=True, 109 | edge_dim=None, dropout=0.0, final_dropout=0.5, deg=None): 110 | super().__init__() 111 | n_hidden = int((n_hidden // 5) * 5) 112 | self.n_hidden = n_hidden 113 | self.num_gnn_layers = num_gnn_layers 114 | self.edge_updates = edge_updates 115 | self.final_dropout = final_dropout 116 | 117 | aggregators = ['mean', 'min', 'max', 'std'] 118 | scalers = ['identity', 'amplification', 'attenuation'] 119 | 120 | self.node_emb = nn.Linear(num_features, n_hidden) 121 | self.edge_emb = nn.Linear(edge_dim, n_hidden) 122 | 123 | self.convs = nn.ModuleList() 124 | self.emlps = nn.ModuleList() 125 | self.batch_norms = nn.ModuleList() 126 | for _ in range(self.num_gnn_layers): 127 | conv = PNAConv(in_channels=n_hidden, out_channels=n_hidden, 128 | aggregators=aggregators, scalers=scalers, deg=deg, 129 | edge_dim=n_hidden, towers=5, pre_layers=1, post_layers=1, 130 | divide_input=False) 131 | if self.edge_updates: self.emlps.append(nn.Sequential( 132 | nn.Linear(3 * self.n_hidden, self.n_hidden), 133 | nn.ReLU(), 134 | nn.Linear(self.n_hidden, self.n_hidden), 135 | )) 136 | self.convs.append(conv) 137 | self.batch_norms.append(BatchNorm(n_hidden)) 138 | 139 | self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout), 140 | Linear(25, n_classes)) 141 | 142 | def forward(self, x, edge_index, edge_attr): 143 | src, dst = edge_index 144 | 145 | x = self.node_emb(x) 146 | edge_attr = self.edge_emb(edge_attr) 147 | 148 | for i in range(self.num_gnn_layers): 149 | x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2 150 | if self.edge_updates: 151 | edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2 152 | 153 | logging.debug(f"x.shape = {x.shape}, x[edge_index.T].shape = {x[edge_index.T].shape}") 154 | x = x[edge_index.T].reshape(-1, 2 * 
self.n_hidden).relu() 155 | logging.debug(f"x.shape = {x.shape}") 156 | x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1) 157 | logging.debug(f"x.shape = {x.shape}") 158 | out = x 159 | return self.mlp(out) 160 | 161 | class RGCN(nn.Module): 162 | def __init__(self, num_features, edge_dim, num_relations, num_gnn_layers, n_classes=2, 163 | n_hidden=100, edge_update=False, 164 | residual=True, 165 | dropout=0.0, final_dropout=0.5, n_bases=-1): 166 | super(RGCN, self).__init__() 167 | 168 | self.num_features = num_features 169 | self.num_gnn_layers = num_gnn_layers 170 | self.n_hidden = n_hidden 171 | self.residual = residual 172 | self.dropout = dropout 173 | self.final_dropout = final_dropout 174 | self.n_classes = n_classes 175 | self.edge_update = edge_update 176 | self.num_relations = num_relations 177 | self.n_bases = n_bases 178 | 179 | self.node_emb = nn.Linear(num_features, n_hidden) 180 | self.edge_emb = nn.Linear(edge_dim, n_hidden) 181 | 182 | self.convs = nn.ModuleList() 183 | self.bns = nn.ModuleList() 184 | self.mlp = nn.ModuleList() 185 | 186 | if self.edge_update: 187 | self.emlps = nn.ModuleList() 188 | self.emlps.append(nn.Sequential( 189 | nn.Linear(3 * self.n_hidden, self.n_hidden), 190 | nn.ReLU(), 191 | nn.Linear(self.n_hidden, self.n_hidden), 192 | )) 193 | 194 | for _ in range(self.num_gnn_layers): 195 | conv = RGCNConv(self.n_hidden, self.n_hidden, num_relations, num_bases=self.n_bases) 196 | self.convs.append(conv) 197 | self.bns.append(nn.BatchNorm1d(self.n_hidden)) 198 | 199 | if self.edge_update: 200 | self.emlps.append(nn.Sequential( 201 | nn.Linear(3 * self.n_hidden, self.n_hidden), 202 | nn.ReLU(), 203 | nn.Linear(self.n_hidden, self.n_hidden), 204 | )) 205 | 206 | self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout), Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout), 207 | Linear(25, n_classes)) 208 | 209 | def reset_parameters(self): 210 | for m in self.modules(): 211 | if 
isinstance(m, nn.Linear): 212 | m.reset_parameters() 213 | elif isinstance(m, RGCNConv): 214 | m.reset_parameters() 215 | elif isinstance(m, nn.BatchNorm1d): 216 | m.reset_parameters() 217 | 218 | def forward(self, x, edge_index, edge_attr): 219 | edge_type = edge_attr[:, -1].long() 220 | #edge_attr = edge_attr[:, :-1] 221 | src, dst = edge_index 222 | 223 | x = self.node_emb(x) 224 | edge_attr = self.edge_emb(edge_attr) 225 | 226 | for i in range(self.num_gnn_layers): 227 | x = (x + F.relu(self.bns[i](self.convs[i](x, edge_index, edge_type)))) / 2 228 | if self.edge_update: 229 | edge_attr = (edge_attr + F.relu(self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)))) / 2 230 | 231 | x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu() 232 | x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1) 233 | x = self.mlp(x) 234 | out = x 235 | 236 | return x -------------------------------------------------------------------------------- /train_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from torch_geometric.transforms import BaseTransform 4 | from typing import Union 5 | from torch_geometric.data import Data, HeteroData 6 | from torch_geometric.loader import LinkNeighborLoader 7 | from sklearn.metrics import f1_score 8 | import json 9 | 10 | class AddEgoIds(BaseTransform): 11 | r"""Add IDs to the centre nodes of the batch. 
12 | """ 13 | def __init__(self): 14 | pass 15 | 16 | def __call__(self, data: Union[Data, HeteroData]): 17 | x = data.x if not isinstance(data, HeteroData) else data['node'].x 18 | device = x.device 19 | ids = torch.zeros((x.shape[0], 1), device=device) 20 | if not isinstance(data, HeteroData): 21 | nodes = torch.unique(data.edge_label_index.view(-1)).to(device) 22 | else: 23 | nodes = torch.unique(data['node', 'to', 'node'].edge_label_index.view(-1)).to(device) 24 | ids[nodes] = 1 25 | if not isinstance(data, HeteroData): 26 | data.x = torch.cat([x, ids], dim=1) 27 | else: 28 | data['node'].x = torch.cat([x, ids], dim=1) 29 | 30 | return data 31 | 32 | def extract_param(parameter_name: str, args) -> float: 33 | """ 34 | Extract the value of the specified parameter for the given model. 35 | 36 | Args: 37 | - parameter_name (str): Name of the parameter (e.g., "lr"). 38 | - args (argparser): Arguments given to this specific run. 39 | 40 | Returns: 41 | - float: Value of the specified parameter. 42 | """ 43 | file_path = './model_settings.json' 44 | with open(file_path, "r") as file: 45 | data = json.load(file) 46 | 47 | return data.get(args.model, {}).get("params", {}).get(parameter_name, None) 48 | 49 | def add_arange_ids(data_list): 50 | ''' 51 | Add the index as an id to the edge features to find seed edges in training, validation and testing. 52 | 53 | Args: 54 | - data_list (str): List of tr_data, val_data and te_data. 
55 | ''' 56 | for data in data_list: 57 | if isinstance(data, HeteroData): 58 | data['node', 'to', 'node'].edge_attr = torch.cat([torch.arange(data['node', 'to', 'node'].edge_attr.shape[0]).view(-1, 1), data['node', 'to', 'node'].edge_attr], dim=1) 59 | offset = data['node', 'to', 'node'].edge_attr.shape[0] 60 | data['node', 'rev_to', 'node'].edge_attr = torch.cat([torch.arange(offset, data['node', 'rev_to', 'node'].edge_attr.shape[0] + offset).view(-1, 1), data['node', 'rev_to', 'node'].edge_attr], dim=1) 61 | else: 62 | data.edge_attr = torch.cat([torch.arange(data.edge_attr.shape[0]).view(-1, 1), data.edge_attr], dim=1) 63 | 64 | def get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args): 65 | if isinstance(tr_data, HeteroData): 66 | tr_edge_label_index = tr_data['node', 'to', 'node'].edge_index 67 | tr_edge_label = tr_data['node', 'to', 'node'].y 68 | 69 | 70 | tr_loader = LinkNeighborLoader(tr_data, num_neighbors=args.num_neighs, 71 | edge_label_index=(('node', 'to', 'node'), tr_edge_label_index), 72 | edge_label=tr_edge_label, batch_size=args.batch_size, shuffle=True, transform=transform) 73 | 74 | val_edge_label_index = val_data['node', 'to', 'node'].edge_index[:,val_inds] 75 | val_edge_label = val_data['node', 'to', 'node'].y[val_inds] 76 | 77 | 78 | val_loader = LinkNeighborLoader(val_data, num_neighbors=args.num_neighs, 79 | edge_label_index=(('node', 'to', 'node'), val_edge_label_index), 80 | edge_label=val_edge_label, batch_size=args.batch_size, shuffle=False, transform=transform) 81 | 82 | te_edge_label_index = te_data['node', 'to', 'node'].edge_index[:,te_inds] 83 | te_edge_label = te_data['node', 'to', 'node'].y[te_inds] 84 | 85 | 86 | te_loader = LinkNeighborLoader(te_data, num_neighbors=args.num_neighs, 87 | edge_label_index=(('node', 'to', 'node'), te_edge_label_index), 88 | edge_label=te_edge_label, batch_size=args.batch_size, shuffle=False, transform=transform) 89 | else: 90 | tr_loader = LinkNeighborLoader(tr_data, 
num_neighbors=args.num_neighs, batch_size=args.batch_size, shuffle=True, transform=transform) 91 | val_loader = LinkNeighborLoader(val_data,num_neighbors=args.num_neighs, edge_label_index=val_data.edge_index[:, val_inds], 92 | edge_label=val_data.y[val_inds], batch_size=args.batch_size, shuffle=False, transform=transform) 93 | te_loader = LinkNeighborLoader(te_data,num_neighbors=args.num_neighs, edge_label_index=te_data.edge_index[:, te_inds], 94 | edge_label=te_data.y[te_inds], batch_size=args.batch_size, shuffle=False, transform=transform) 95 | 96 | return tr_loader, val_loader, te_loader 97 | 98 | @torch.no_grad() 99 | def evaluate_homo(loader, inds, model, data, device, args): 100 | '''Evaluates the model performane for homogenous graph data.''' 101 | preds = [] 102 | ground_truths = [] 103 | for batch in tqdm.tqdm(loader, disable=not args.tqdm): 104 | #select the seed edges from which the batch was created 105 | inds = inds.detach().cpu() 106 | batch_edge_inds = inds[batch.input_id.detach().cpu()] 107 | batch_edge_ids = loader.data.edge_attr.detach().cpu()[batch_edge_inds, 0] 108 | mask = torch.isin(batch.edge_attr[:, 0].detach().cpu(), batch_edge_ids) 109 | 110 | #add the seed edges that have not been sampled to the batch 111 | missing = ~torch.isin(batch_edge_ids, batch.edge_attr[:, 0].detach().cpu()) 112 | 113 | if missing.sum() != 0 and (args.data == 'Small_J' or args.data == 'Small_Q'): 114 | missing_ids = batch_edge_ids[missing].int() 115 | n_ids = batch.n_id 116 | add_edge_index = data.edge_index[:, missing_ids].detach().clone() 117 | node_mapping = {value.item(): idx for idx, value in enumerate(n_ids)} 118 | add_edge_index = torch.tensor([[node_mapping[val.item()] for val in row] for row in add_edge_index]) 119 | add_edge_attr = data.edge_attr[missing_ids, :].detach().clone() 120 | add_y = data.y[missing_ids].detach().clone() 121 | 122 | batch.edge_index = torch.cat((batch.edge_index, add_edge_index), 1) 123 | batch.edge_attr = 
torch.cat((batch.edge_attr, add_edge_attr), 0) 124 | batch.y = torch.cat((batch.y, add_y), 0) 125 | 126 | mask = torch.cat((mask, torch.ones(add_y.shape[0], dtype=torch.bool))) 127 | 128 | #remove the unique edge id from the edge features, as it's no longer needed 129 | batch.edge_attr = batch.edge_attr[:, 1:] 130 | 131 | with torch.no_grad(): 132 | batch.to(device) 133 | out = model(batch.x, batch.edge_index, batch.edge_attr) 134 | out = out[mask] 135 | pred = out.argmax(dim=-1) 136 | preds.append(pred) 137 | ground_truths.append(batch.y[mask]) 138 | pred = torch.cat(preds, dim=0).cpu().numpy() 139 | ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy() 140 | f1 = f1_score(ground_truth, pred) 141 | 142 | return f1 143 | 144 | @torch.no_grad() 145 | def evaluate_hetero(loader, inds, model, data, device, args): 146 | '''Evaluates the model performane for heterogenous graph data.''' 147 | preds = [] 148 | ground_truths = [] 149 | for batch in tqdm.tqdm(loader, disable=not args.tqdm): 150 | #select the seed edges from which the batch was created 151 | inds = inds.detach().cpu() 152 | batch_edge_inds = inds[batch['node', 'to', 'node'].input_id.detach().cpu()] 153 | batch_edge_ids = loader.data['node', 'to', 'node'].edge_attr.detach().cpu()[batch_edge_inds, 0] 154 | mask = torch.isin(batch['node', 'to', 'node'].edge_attr[:, 0].detach().cpu(), batch_edge_ids) 155 | 156 | #add the seed edges that have not been sampled to the batch 157 | missing = ~torch.isin(batch_edge_ids, batch['node', 'to', 'node'].edge_attr[:, 0].detach().cpu()) 158 | 159 | if missing.sum() != 0 and (args.data == 'Small_J' or args.data == 'Small_Q'): 160 | missing_ids = batch_edge_ids[missing].int() 161 | n_ids = batch['node'].n_id 162 | add_edge_index = data['node', 'to', 'node'].edge_index[:, missing_ids].detach().clone() 163 | node_mapping = {value.item(): idx for idx, value in enumerate(n_ids)} 164 | add_edge_index = torch.tensor([[node_mapping[val.item()] for val in row] for row in 
add_edge_index]) 165 | add_edge_attr = data['node', 'to', 'node'].edge_attr[missing_ids, :].detach().clone() 166 | add_y = data['node', 'to', 'node'].y[missing_ids].detach().clone() 167 | 168 | batch['node', 'to', 'node'].edge_index = torch.cat((batch['node', 'to', 'node'].edge_index, add_edge_index), 1) 169 | batch['node', 'to', 'node'].edge_attr = torch.cat((batch['node', 'to', 'node'].edge_attr, add_edge_attr), 0) 170 | batch['node', 'to', 'node'].y = torch.cat((batch['node', 'to', 'node'].y, add_y), 0) 171 | 172 | mask = torch.cat((mask, torch.ones(add_y.shape[0], dtype=torch.bool))) 173 | 174 | #remove the unique edge id from the edge features, as it's no longer needed 175 | batch['node', 'to', 'node'].edge_attr = batch['node', 'to', 'node'].edge_attr[:, 1:] 176 | batch['node', 'rev_to', 'node'].edge_attr = batch['node', 'rev_to', 'node'].edge_attr[:, 1:] 177 | 178 | with torch.no_grad(): 179 | batch.to(device) 180 | out = model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict) 181 | out = out[('node', 'to', 'node')] 182 | out = out[mask] 183 | pred = out.argmax(dim=-1) 184 | preds.append(pred) 185 | ground_truths.append(batch['node', 'to', 'node'].y[mask]) 186 | pred = torch.cat(preds, dim=0).cpu().numpy() 187 | ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy() 188 | f1 = f1_score(ground_truth, pred) 189 | 190 | return f1 191 | 192 | def save_model(model, optimizer, epoch, args, data_config): 193 | # Save the model in a dictionary 194 | torch.save({ 195 | 'epoch': epoch + 1, 196 | 'model_state_dict': model.state_dict(), 197 | 'optimizer_state_dict': optimizer.state_dict() 198 | }, f'{data_config["paths"]["model_to_save"]}/checkpoint_{args.unique_name}{"" if not args.finetune else "_finetuned"}.tar') 199 | 200 | def load_model(model, device, args, config, data_config): 201 | checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_{args.unique_name}.tar') 202 | model.load_state_dict(checkpoint['model_state_dict']) 203 | 
model.to(device) 204 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 205 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 206 | 207 | return model, optimizer -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from sklearn.metrics import f1_score 4 | from train_util import AddEgoIds, extract_param, add_arange_ids, get_loaders, evaluate_homo, evaluate_hetero, save_model, load_model 5 | from models import GINe, PNA, GATe, RGCN 6 | from torch_geometric.data import Data, HeteroData 7 | from torch_geometric.nn import to_hetero, summary 8 | from torch_geometric.utils import degree 9 | import wandb 10 | import logging 11 | 12 | def train_homo(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config): 13 | #training 14 | best_val_f1 = 0 15 | for epoch in range(config.epochs): 16 | total_loss = total_examples = 0 17 | preds = [] 18 | ground_truths = [] 19 | for batch in tqdm.tqdm(tr_loader, disable=not args.tqdm): 20 | optimizer.zero_grad() 21 | #select the seed edges from which the batch was created 22 | inds = tr_inds.detach().cpu() 23 | batch_edge_inds = inds[batch.input_id.detach().cpu()] 24 | batch_edge_ids = tr_loader.data.edge_attr.detach().cpu()[batch_edge_inds, 0] 25 | mask = torch.isin(batch.edge_attr[:, 0].detach().cpu(), batch_edge_ids) 26 | 27 | #remove the unique edge id from the edge features, as it's no longer needed 28 | batch.edge_attr = batch.edge_attr[:, 1:] 29 | 30 | batch.to(device) 31 | out = model(batch.x, batch.edge_index, batch.edge_attr) 32 | pred = out[mask] 33 | ground_truth = batch.y[mask] 34 | preds.append(pred.argmax(dim=-1)) 35 | ground_truths.append(ground_truth) 36 | loss = loss_fn(pred, ground_truth) 37 | 38 | loss.backward() 39 | optimizer.step() 40 | 41 | 
total_loss += float(loss) * pred.numel() 42 | total_examples += pred.numel() 43 | 44 | pred = torch.cat(preds, dim=0).detach().cpu().numpy() 45 | ground_truth = torch.cat(ground_truths, dim=0).detach().cpu().numpy() 46 | f1 = f1_score(ground_truth, pred) 47 | wandb.log({"f1/train": f1}, step=epoch) 48 | logging.info(f'Train F1: {f1:.4f}') 49 | 50 | #evaluate 51 | val_f1 = evaluate_homo(val_loader, val_inds, model, val_data, device, args) 52 | te_f1 = evaluate_homo(te_loader, te_inds, model, te_data, device, args) 53 | 54 | wandb.log({"f1/validation": val_f1}, step=epoch) 55 | wandb.log({"f1/test": te_f1}, step=epoch) 56 | logging.info(f'Validation F1: {val_f1:.4f}') 57 | logging.info(f'Test F1: {te_f1:.4f}') 58 | 59 | if epoch == 0: 60 | wandb.log({"best_test_f1": te_f1}, step=epoch) 61 | elif val_f1 > best_val_f1: 62 | best_val_f1 = val_f1 63 | wandb.log({"best_test_f1": te_f1}, step=epoch) 64 | if args.save_model: 65 | save_model(model, optimizer, epoch, args, data_config) 66 | 67 | return model 68 | 69 | def train_hetero(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config): 70 | #training 71 | best_val_f1 = 0 72 | for epoch in range(config.epochs): 73 | total_loss = total_examples = 0 74 | preds = [] 75 | ground_truths = [] 76 | for batch in tqdm.tqdm(tr_loader, disable=not args.tqdm): 77 | optimizer.zero_grad() 78 | #select the seed edges from which the batch was created 79 | inds = tr_inds.detach().cpu() 80 | batch_edge_inds = inds[batch['node', 'to', 'node'].input_id.detach().cpu()] 81 | batch_edge_ids = tr_loader.data['node', 'to', 'node'].edge_attr.detach().cpu()[batch_edge_inds, 0] 82 | mask = torch.isin(batch['node', 'to', 'node'].edge_attr[:, 0].detach().cpu(), batch_edge_ids) 83 | 84 | #remove the unique edge id from the edge features, as it's no longer needed 85 | batch['node', 'to', 'node'].edge_attr = batch['node', 'to', 'node'].edge_attr[:, 1:] 86 | 
batch['node', 'rev_to', 'node'].edge_attr = batch['node', 'rev_to', 'node'].edge_attr[:, 1:] 87 | 88 | batch.to(device) 89 | out = model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict) 90 | out = out[('node', 'to', 'node')] 91 | pred = out[mask] 92 | ground_truth = batch['node', 'to', 'node'].y[mask] 93 | preds.append(pred.argmax(dim=-1)) 94 | ground_truths.append(batch['node', 'to', 'node'].y[mask]) 95 | loss = loss_fn(pred, ground_truth) 96 | 97 | loss.backward() 98 | optimizer.step() 99 | 100 | total_loss += float(loss) * pred.numel() 101 | total_examples += pred.numel() 102 | 103 | pred = torch.cat(preds, dim=0).detach().cpu().numpy() 104 | ground_truth = torch.cat(ground_truths, dim=0).detach().cpu().numpy() 105 | f1 = f1_score(ground_truth, pred) 106 | wandb.log({"f1/train": f1}, step=epoch) 107 | logging.info(f'Train F1: {f1:.4f}') 108 | 109 | #evaluate 110 | val_f1 = evaluate_hetero(val_loader, val_inds, model, val_data, device, args) 111 | te_f1 = evaluate_hetero(te_loader, te_inds, model, te_data, device, args) 112 | 113 | wandb.log({"f1/validation": val_f1}, step=epoch) 114 | wandb.log({"f1/test": te_f1}, step=epoch) 115 | logging.info(f'Validation F1: {val_f1:.4f}') 116 | logging.info(f'Test F1: {te_f1:.4f}') 117 | 118 | if epoch == 0: 119 | wandb.log({"best_test_f1": te_f1}, step=epoch) 120 | elif val_f1 > best_val_f1: 121 | best_val_f1 = val_f1 122 | wandb.log({"best_test_f1": te_f1}, step=epoch) 123 | if args.save_model: 124 | save_model(model, optimizer, epoch, args, data_config) 125 | 126 | return model 127 | 128 | def get_model(sample_batch, config, args): 129 | n_feats = sample_batch.x.shape[1] if not isinstance(sample_batch, HeteroData) else sample_batch['node'].x.shape[1] 130 | e_dim = (sample_batch.edge_attr.shape[1] - 1) if not isinstance(sample_batch, HeteroData) else (sample_batch['node', 'to', 'node'].edge_attr.shape[1] - 1) 131 | 132 | if args.model == "gin": 133 | model = GINe( 134 | num_features=n_feats, 
num_gnn_layers=config.n_gnn_layers, n_classes=2, 135 | n_hidden=round(config.n_hidden), residual=False, edge_updates=args.emlps, edge_dim=e_dim, 136 | dropout=config.dropout, final_dropout=config.final_dropout 137 | ) 138 | elif args.model == "gat": 139 | model = GATe( 140 | num_features=n_feats, num_gnn_layers=config.n_gnn_layers, n_classes=2, 141 | n_hidden=round(config.n_hidden), n_heads=round(config.n_heads), 142 | edge_updates=args.emlps, edge_dim=e_dim, 143 | dropout=config.dropout, final_dropout=config.final_dropout 144 | ) 145 | elif args.model == "pna": 146 | if not isinstance(sample_batch, HeteroData): 147 | d = degree(sample_batch.edge_index[1], dtype=torch.long) 148 | else: 149 | index = torch.cat((sample_batch['node', 'to', 'node'].edge_index[1], sample_batch['node', 'rev_to', 'node'].edge_index[1]), 0) 150 | d = degree(index, dtype=torch.long) 151 | deg = torch.bincount(d, minlength=1) 152 | model = PNA( 153 | num_features=n_feats, num_gnn_layers=config.n_gnn_layers, n_classes=2, 154 | n_hidden=round(config.n_hidden), edge_updates=args.emlps, edge_dim=e_dim, 155 | dropout=config.dropout, deg=deg, final_dropout=config.final_dropout 156 | ) 157 | elif config.model == "rgcn": 158 | model = RGCN( 159 | num_features=n_feats, edge_dim=e_dim, num_relations=8, num_gnn_layers=round(config.n_gnn_layers), 160 | n_classes=2, n_hidden=round(config.n_hidden), 161 | edge_update=args.emlps, dropout=config.dropout, final_dropout=config.final_dropout, n_bases=None #(maybe) 162 | ) 163 | 164 | return model 165 | 166 | def train_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config): 167 | #set device 168 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 169 | 170 | #define a model config dictionary and wandb logging at the same time 171 | wandb.init( 172 | mode="disabled" if args.testing else "online", 173 | project="your_proj_name", #replace this with your wandb project name if you want to use wandb logging 174 | 175 | 
config={ 176 | "epochs": args.n_epochs, 177 | "batch_size": args.batch_size, 178 | "model": args.model, 179 | "data": args.data, 180 | "num_neighbors": args.num_neighs, 181 | "lr": extract_param("lr", args), 182 | "n_hidden": extract_param("n_hidden", args), 183 | "n_gnn_layers": extract_param("n_gnn_layers", args), 184 | "loss": "ce", 185 | "w_ce1": extract_param("w_ce1", args), 186 | "w_ce2": extract_param("w_ce2", args), 187 | "dropout": extract_param("dropout", args), 188 | "final_dropout": extract_param("final_dropout", args), 189 | "n_heads": extract_param("n_heads", args) if args.model == 'gat' else None 190 | } 191 | ) 192 | 193 | config = wandb.config 194 | 195 | #set the transform if ego ids should be used 196 | if args.ego: 197 | transform = AddEgoIds() 198 | else: 199 | transform = None 200 | 201 | #add the unique ids to later find the seed edges 202 | add_arange_ids([tr_data, val_data, te_data]) 203 | 204 | tr_loader, val_loader, te_loader = get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args) 205 | 206 | #get the model 207 | sample_batch = next(iter(tr_loader)) 208 | model = get_model(sample_batch, config, args) 209 | 210 | if args.reverse_mp: 211 | model = to_hetero(model, te_data.metadata(), aggr='mean') 212 | 213 | if args.finetune: 214 | model, optimizer = load_model(model, device, args, config, data_config) 215 | else: 216 | model.to(device) 217 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 218 | 219 | sample_batch.to(device) 220 | sample_x = sample_batch.x if not isinstance(sample_batch, HeteroData) else sample_batch.x_dict 221 | sample_edge_index = sample_batch.edge_index if not isinstance(sample_batch, HeteroData) else sample_batch.edge_index_dict 222 | if isinstance(sample_batch, HeteroData): 223 | sample_batch['node', 'to', 'node'].edge_attr = sample_batch['node', 'to', 'node'].edge_attr[:, 1:] 224 | sample_batch['node', 'rev_to', 'node'].edge_attr = sample_batch['node', 'rev_to', 
'node'].edge_attr[:, 1:] 225 | else: 226 | sample_batch.edge_attr = sample_batch.edge_attr[:, 1:] 227 | sample_edge_attr = sample_batch.edge_attr if not isinstance(sample_batch, HeteroData) else sample_batch.edge_attr_dict 228 | logging.info(summary(model, sample_x, sample_edge_index, sample_edge_attr)) 229 | 230 | loss_fn = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([config.w_ce1, config.w_ce2]).to(device)) 231 | 232 | if args.reverse_mp: 233 | model = train_hetero(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config) 234 | else: 235 | model = train_homo(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config) 236 | 237 | wandb.finish() -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import random 5 | import logging 6 | import os 7 | import sys 8 | 9 | def logger_setup(): 10 | # Setup logging 11 | log_directory = "logs" 12 | if not os.path.exists(log_directory): 13 | os.makedirs(log_directory) 14 | logging.basicConfig( 15 | level=logging.INFO, 16 | format="%(asctime)s [%(levelname)-5.5s] %(message)s", 17 | handlers=[ 18 | logging.FileHandler(os.path.join(log_directory, "logs.log")), ## log to local log file 19 | logging.StreamHandler(sys.stdout) ## log also to stdout (i.e., print to screen) 20 | ] 21 | ) 22 | 23 | def create_parser(): 24 | parser = argparse.ArgumentParser() 25 | 26 | #Adaptations 27 | parser.add_argument("--emlps", action='store_true', help="Use emlps in GNN training") 28 | parser.add_argument("--reverse_mp", action='store_true', help="Use reverse MP in GNN training") 29 | parser.add_argument("--ports", action='store_true', help="Use port numberings in GNN training") 30 | 
parser.add_argument("--tds", action='store_true', help="Use time deltas (i.e. the time between subsequent transactions) in GNN training") 31 | parser.add_argument("--ego", action='store_true', help="Use ego IDs in GNN training") 32 | 33 | #Model parameters 34 | parser.add_argument("--batch_size", default=8192, type=int, help="Select the batch size for GNN training") 35 | parser.add_argument("--n_epochs", default=100, type=int, help="Select the number of epochs for GNN training") 36 | parser.add_argument('--num_neighs', nargs='+', default=[100,100], help='Pass the number of neighors to be sampled in each hop (descending).') 37 | 38 | #Misc 39 | parser.add_argument("--seed", default=1, type=int, help="Select the random seed for reproducability") 40 | parser.add_argument("--tqdm", action='store_true', help="Use tqdm logging (when running interactively in terminal)") 41 | parser.add_argument("--data", default=None, type=str, help="Select the AML dataset. Needs to be either small or medium.", required=True) 42 | parser.add_argument("--model", default=None, type=str, help="Select the model architecture. Needs to be one of [gin, gat, rgcn, pna]", required=True) 43 | parser.add_argument("--testing", action='store_true', help="Disable wandb logging while running the script in 'testing' mode.") 44 | parser.add_argument("--save_model", action='store_true', help="Save the best model.") 45 | parser.add_argument("--unique_name", action='store_true', help="Unique name under which the model will be stored.") 46 | parser.add_argument("--finetune", action='store_true', help="Fine-tune a model. Note that args.unique_name needs to point to the pre-trained model.") 47 | parser.add_argument("--inference", action='store_true', help="Load a trained model and only do AML inference with it. 
args.unique name needs to point to the trained model.") 48 | 49 | return parser 50 | 51 | def set_seed(seed: int = 0) -> None: 52 | np.random.seed(seed) 53 | random.seed(seed) 54 | torch.manual_seed(seed) 55 | torch.cuda.manual_seed(seed) 56 | # When running on the CuDNN backend, two further options must be set 57 | torch.backends.cudnn.deterministic = True 58 | torch.backends.cudnn.benchmark = False 59 | # Set a fixed value for the hash seed 60 | os.environ["PYTHONHASHSEED"] = str(seed) 61 | logging.info(f"Random seed set as {seed}") --------------------------------------------------------------------------------