├── LICENSE
├── README.md
├── data_config.json
├── data_loading.py
├── data_util.py
├── env.yml
├── format_kaggle_files.py
├── inference.py
├── main.py
├── model_settings.json
├── models.py
├── train_util.py
├── training.py
└── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-GNN
2 | This repository contains all models and adaptations needed to run Multi-GNN for Anti-Money Laundering. The repository consists of four Graph Neural Network model classes ([GIN](https://arxiv.org/abs/1810.00826), [GAT](https://arxiv.org/abs/1710.10903), [PNA](https://arxiv.org/abs/2004.05718), [RGCN](https://arxiv.org/abs/1703.06103)) and the below-described model adaptations utilized for financial crime detection in [Egressy et al.](https://arxiv.org/abs/2306.11586). Note that this repository solely focuses on the Anti-Money Laundering use case. This repository has been created for experiments in [Provably Powerful Graph Neural Networks for Directed Multigraphs](https://arxiv.org/abs/2306.11586) [AAAI 2024] and [Realistic Synthetic Financial Transactions for Anti-Money Laundering Models](https://arxiv.org/abs/2306.16424) [NeurIPS 2023].
3 |
4 | ## Setup
5 | To use the repository, you first need to install the conda environment via
6 | ```
7 | conda env create -f env.yml
8 | ```
9 | Then, the data needed for the experiments can be found on [Kaggle](https://www.kaggle.com/datasets/ealtman2019/ibm-transactions-for-anti-money-laundering-aml/data). To use this data with the provided training scripts, you first need to perform a pre-processing step for the downloaded transaction files (e.g. `HI-Small_Trans.csv`):
10 | ```
11 | python format_kaggle_files.py /path/to/kaggle-files/HI-Small_Trans.csv
12 | ```
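13 | Make sure to change the file paths in the `data_config.json` file. The pre-processing step writes `formatted_transactions.csv` into the same directory as the input file. The `aml_data` path should point to the parent directory that contains one folder per dataset (e.g. `Small_HI`), each holding its own `formatted_transactions.csv`; the data loader reads `<aml_data>/<value of --data>/formatted_transactions.csv`.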
14 |
15 | ## Usage
16 | To run the experiments, run the `main.py` script and specify the arguments you want to use. There are two required arguments, `--data` and `--model`. For the `--data` argument, store each dataset in its own folder and pass the folder name, e.g. `--data Small_HI`. The `--model` argument selects one of the available model classes, i.e. one of `gin`, `gat`, `rgcn` or `pna`. Thus, to run a standard GNN, you would run, e.g.:
17 | ```
18 | python main.py --data Small_HI --model gin
19 | ```
20 | Then you can add different adaptations to the models by selecting the respective arguments from:
21 |
22 |
23 |
24 | | Argument | Description |
25 | | -------------- | ---------------------------- |
26 | | `--emlps` | Edge updates via MLPs |
27 | | `--reverse_mp` | Reverse Message Passing |
28 | | `--ego` | Ego IDs for the center nodes |
29 | | `--ports` | Port numberings for edges |
30 |
31 |
32 | Thus, to run Multi-GIN with edge updates, you would run the following command:
33 |
34 | ```
35 | python main.py --data Small_HI --model gin --emlps --reverse_mp --ego --ports
36 | ```
37 |
38 | ## Additional functionalities
39 | There are several arguments that can be set for additional functionality. Here is a list of them:
40 |
41 |
42 |
43 | | Argument | Description |
44 | | -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------|
45 | | `--tqdm` | Displays a progress bar during training and inference. |
46 | | `--save_model` | Saves the best model to the `model_to_save` path specified in the `data_config.json` file. Requires the argument `--unique_name` to be specified. |
47 | | `--finetune` | Loads a previously trained model (with name given by `--unique_name` and stored in `model_to_load` path in the `data_config.json`) to be finetuned. |
48 | | `--inference` | Loads a previously trained model (with name given by `--unique_name` and stored in `model_to_load` path in the `data_config.json`) to do inference only. |
49 |
50 |
51 |
52 | ## Licence
53 | Apache License
54 | Version 2.0, January 2004
--------------------------------------------------------------------------------
/data_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "paths": {
3 | "aml_data": "/path/to/aml_data",
4 | "model_to_load": "/path/to/model_you_want_to_load (e.g for inference or fine-tuning)",
5 | "model_to_save": "/path/to/model_save_location (where you want to store the model you are going to train)"
6 | }
7 | }
--------------------------------------------------------------------------------
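The `aml_data` entry is combined with the `--data` folder name to locate the pre-processed file (see the `transaction_file` f-string in `get_data()` in `data_loading.py`). A minimal sketch of that lookup, assuming the config layout above; the helper names `load_config` and `transaction_file` are hypothetical and not part of the repository:

```python
import json
from pathlib import Path


def load_config(config_path):
    """Read data_config.json into a dict (hypothetical helper)."""
    with open(config_path) as fh:
        return json.load(fh)


def transaction_file(config, dataset):
    """Build <aml_data>/<dataset>/formatted_transactions.csv.

    Mirrors the path built in get_data(); `dataset` is the value passed
    via --data, e.g. "Small_HI" (hypothetical helper).
    """
    return Path(config["paths"]["aml_data"]) / dataset / "formatted_transactions.csv"


if __name__ == "__main__":
    cfg = {"paths": {"aml_data": "/data/aml"}}
    print(transaction_file(cfg, "Small_HI"))
```

In other words, `aml_data` names the directory that holds the per-dataset folders, not the CSV file itself.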
/data_loading.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | import logging
5 | import itertools
6 | from data_util import GraphData, HeteroData, z_norm, create_hetero_obj
7 |
8 | def get_data(args, data_config):
9 | '''Loads the AML transaction data.
10 |
11 | 1. The data is loaded from the csv and the necessary features are chosen.
12 | 2. The data is split into training, validation and test data.
13 | 3. PyG Data objects are created with the respective data splits.
14 | '''
15 |
16 | transaction_file = f"{data_config['paths']['aml_data']}/{args.data}/formatted_transactions.csv" #replace this with your path to the respective AML data objects
17 | df_edges = pd.read_csv(transaction_file)
18 |
19 | logging.info(f'Available Edge Features: {df_edges.columns.tolist()}')
20 |
21 | df_edges['Timestamp'] = df_edges['Timestamp'] - df_edges['Timestamp'].min()
22 |
23 | max_n_id = df_edges.loc[:, ['from_id', 'to_id']].to_numpy().max() + 1
24 | df_nodes = pd.DataFrame({'NodeID': np.arange(max_n_id), 'Feature': np.ones(max_n_id)})
25 | timestamps = torch.Tensor(df_edges['Timestamp'].to_numpy())
26 | y = torch.LongTensor(df_edges['Is Laundering'].to_numpy())
27 |
28 | logging.info(f"Illicit ratio = {sum(y)} / {len(y)} = {sum(y) / len(y) * 100:.2f}%")
29 | logging.info(f"Number of nodes (holdings doing transactions) = {df_nodes.shape[0]}")
30 | logging.info(f"Number of transactions = {df_edges.shape[0]}")
31 |
32 | edge_features = ['Timestamp', 'Amount Received', 'Received Currency', 'Payment Format']
33 | node_features = ['Feature']
34 |
35 | logging.info(f'Edge features being used: {edge_features}')
36 | logging.info(f'Node features being used: {node_features} ("Feature" is a placeholder feature of all 1s)')
37 |
38 | x = torch.tensor(df_nodes.loc[:, node_features].to_numpy()).float()
39 | edge_index = torch.LongTensor(df_edges.loc[:, ['from_id', 'to_id']].to_numpy().T)
40 | edge_attr = torch.tensor(df_edges.loc[:, edge_features].to_numpy()).float()
41 |
42 | n_days = int(timestamps.max() / (3600 * 24) + 1)
43 | n_samples = y.shape[0]
44 | logging.info(f'number of days and transactions in the data: {n_days} days, {n_samples} transactions')
45 |
46 | #data splitting
47 | daily_irs, weighted_daily_irs, daily_inds, daily_trans = [], [], [], [] #irs = illicit ratios, inds = indices, trans = transactions
48 | for day in range(n_days):
49 | l = day * 24 * 3600
50 | r = (day + 1) * 24 * 3600
51 | day_inds = torch.where((timestamps >= l) & (timestamps < r))[0]
52 | daily_irs.append(y[day_inds].float().mean())
53 | weighted_daily_irs.append(y[day_inds].float().mean() * day_inds.shape[0] / n_samples)
54 | daily_inds.append(day_inds)
55 | daily_trans.append(day_inds.shape[0])
56 |
57 | split_per = [0.6, 0.2, 0.2]
58 | daily_totals = np.array(daily_trans)
59 | d_ts = daily_totals
60 | I = list(range(len(d_ts)))
61 | split_scores = dict()
62 | for i,j in itertools.combinations(I, 2):
63 | if j >= i:
64 | split_totals = [d_ts[:i].sum(), d_ts[i:j].sum(), d_ts[j:].sum()]
65 | split_totals_sum = np.sum(split_totals)
66 | split_props = [v/split_totals_sum for v in split_totals]
67 | split_error = [abs(v-t)/t for v,t in zip(split_props, split_per)]
68 | score = max(split_error) #- (split_totals_sum/total) + 1
69 | split_scores[(i,j)] = score
70 | else:
71 | continue
72 |
73 | i,j = min(split_scores, key=split_scores.get)
74 | #split contains a list for each split (train, validation and test) and each list contains the days that are part of the respective split
75 | split = [list(range(i)), list(range(i, j)), list(range(j, len(daily_totals)))]
76 | logging.info(f'Calculate split: {split}')
77 |
78 | #Now, we separate the transactions based on their indices in the timestamp array
79 | split_inds = {k: [] for k in range(3)}
80 | for i in range(3):
81 | for day in split[i]:
82 | split_inds[i].append(daily_inds[day]) #split_inds contains a list for each split (tr,val,te) which contains the indices of each day separately
83 |
84 | tr_inds = torch.cat(split_inds[0])
85 | val_inds = torch.cat(split_inds[1])
86 | te_inds = torch.cat(split_inds[2])
87 |
88 | logging.info(f"Total train samples: {tr_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
89 | f"{y[tr_inds].float().mean() * 100 :.2f}% || Train days: {split[0][:5]}")
90 | logging.info(f"Total val samples: {val_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
91 | f"{y[val_inds].float().mean() * 100:.2f}% || Val days: {split[1][:5]}")
92 | logging.info(f"Total test samples: {te_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
93 | f"{y[te_inds].float().mean() * 100:.2f}% || Test days: {split[2][:5]}")
94 |
95 | #Creating the final data objects
96 | tr_x, val_x, te_x = x, x, x
97 | e_tr = tr_inds.numpy()
98 | e_val = np.concatenate([tr_inds, val_inds])
99 |
100 | tr_edge_index, tr_edge_attr, tr_y, tr_edge_times = edge_index[:,e_tr], edge_attr[e_tr], y[e_tr], timestamps[e_tr]
101 | val_edge_index, val_edge_attr, val_y, val_edge_times = edge_index[:,e_val], edge_attr[e_val], y[e_val], timestamps[e_val]
102 | te_edge_index, te_edge_attr, te_y, te_edge_times = edge_index, edge_attr, y, timestamps
103 |
104 | tr_data = GraphData (x=tr_x, y=tr_y, edge_index=tr_edge_index, edge_attr=tr_edge_attr, timestamps=tr_edge_times )
105 | val_data = GraphData(x=val_x, y=val_y, edge_index=val_edge_index, edge_attr=val_edge_attr, timestamps=val_edge_times)
106 | te_data = GraphData (x=te_x, y=te_y, edge_index=te_edge_index, edge_attr=te_edge_attr, timestamps=te_edge_times )
107 |
108 | #Adding ports and time-deltas if applicable
109 | if args.ports:
110 | logging.info(f"Start: adding ports")
111 | tr_data.add_ports()
112 | val_data.add_ports()
113 | te_data.add_ports()
114 | logging.info(f"Done: adding ports")
115 | if args.tds:
116 | logging.info(f"Start: adding time-deltas")
117 | tr_data.add_time_deltas()
118 | val_data.add_time_deltas()
119 | te_data.add_time_deltas()
120 | logging.info(f"Done: adding time-deltas")
121 |
122 | #Normalize data
123 | tr_data.x = val_data.x = te_data.x = z_norm(tr_data.x)
124 | if not args.model == 'rgcn':
125 | tr_data.edge_attr, val_data.edge_attr, te_data.edge_attr = z_norm(tr_data.edge_attr), z_norm(val_data.edge_attr), z_norm(te_data.edge_attr)
126 | else:
127 | tr_data.edge_attr[:, :-1], val_data.edge_attr[:, :-1], te_data.edge_attr[:, :-1] = z_norm(tr_data.edge_attr[:, :-1]), z_norm(val_data.edge_attr[:, :-1]), z_norm(te_data.edge_attr[:, :-1])
128 |
129 | #Create heterogeneous graph objects if reverse MP is enabled
130 | #TODO: if I observe weird behaviour, maybe add .detach().clone() to all torch tensors, but I don't think they're attached to any computation graph just yet
131 | if args.reverse_mp:
132 | tr_data = create_hetero_obj(tr_data.x, tr_data.y, tr_data.edge_index, tr_data.edge_attr, tr_data.timestamps, args)
133 | val_data = create_hetero_obj(val_data.x, val_data.y, val_data.edge_index, val_data.edge_attr, val_data.timestamps, args)
134 | te_data = create_hetero_obj(te_data.x, te_data.y, te_data.edge_index, te_data.edge_attr, te_data.timestamps, args)
135 |
136 | logging.info(f'train data object: {tr_data}')
137 | logging.info(f'validation data object: {val_data}')
138 | logging.info(f'test data object: {te_data}')
139 |
140 | return tr_data, val_data, te_data, tr_inds, val_inds, te_inds
141 |
--------------------------------------------------------------------------------
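`get_data()` assigns whole days to the train, validation and test sets: it tries every pair of cut points `(i, j)` and keeps the one whose largest relative deviation from the 60/20/20 target is smallest. The function below is a standalone sketch of that search, assuming only a list of per-day transaction counts as input; the name `find_day_split` is hypothetical.

```python
import itertools

import numpy as np


def find_day_split(daily_counts, target=(0.6, 0.2, 0.2)):
    """Return cut points (i, j) so that days [0, i), [i, j) and [j, n)
    form the train, validation and test sets (hypothetical helper).

    The score of a candidate is the largest relative error between a
    split's share of transactions and its target share; the candidate
    with the lowest score wins (the first one found on ties).
    """
    counts = np.asarray(daily_counts, dtype=float)
    total = counts.sum()
    best, best_score = None, float("inf")
    for i, j in itertools.combinations(range(len(counts)), 2):
        shares = (counts[:i].sum() / total,
                  counts[i:j].sum() / total,
                  counts[j:].sum() / total)
        score = max(abs(s - t) / t for s, t in zip(shares, target))
        if score < best_score:
            best, best_score = (i, j), score
    return best


if __name__ == "__main__":
    i, j = find_day_split([10] * 10)
    # Lists of day indices per split, as in the `split` variable of get_data()
    split = [list(range(i)), list(range(i, j)), list(range(j, 10))]
    print((i, j), split)
```

The search checks every pair of cut points, which is inexpensive for a modest number of days. Note also that `get_data()` builds nested graphs from the result: the validation graph contains the training and validation edges, and the test graph contains every edge.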
/data_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data, HeteroData
3 | from torch_geometric.typing import OptTensor
4 | import numpy as np
5 |
6 | def to_adj_nodes_with_times(data):
7 | num_nodes = data.num_nodes
8 | timestamps = torch.zeros((data.edge_index.shape[1], 1)) if data.timestamps is None else data.timestamps.reshape((-1,1))
9 | edges = torch.cat((data.edge_index.T, timestamps), dim=1) if not isinstance(data, HeteroData) else torch.cat((data['node', 'to', 'node'].edge_index.T, timestamps), dim=1)
10 | adj_list_out = dict([(i, []) for i in range(num_nodes)])
11 | adj_list_in = dict([(i, []) for i in range(num_nodes)])
12 | for u,v,t in edges:
13 | u,v,t = int(u), int(v), int(t)
14 | adj_list_out[u] += [(v, t)]
15 | adj_list_in[v] += [(u, t)]
16 | return adj_list_in, adj_list_out
17 |
18 | def to_adj_edges_with_times(data):
19 | num_nodes = data.num_nodes
20 | timestamps = torch.zeros((data.edge_index.shape[1], 1)) if data.timestamps is None else data.timestamps.reshape((-1,1))
21 | edges = torch.cat((data.edge_index.T, timestamps), dim=1)
22 | # calculate adjacent edges with times per node
23 | adj_edges_out = dict([(i, []) for i in range(num_nodes)])
24 | adj_edges_in = dict([(i, []) for i in range(num_nodes)])
25 | for i, (u,v,t) in enumerate(edges):
26 | u,v,t = int(u), int(v), int(t)
27 | adj_edges_out[u] += [(i, v, t)]
28 | adj_edges_in[v] += [(i, u, t)]
29 | return adj_edges_in, adj_edges_out
30 |
31 | def ports(edge_index, adj_list):
32 | ports = torch.zeros(edge_index.shape[1], 1)
33 | ports_dict = {}
34 | for v, nbs in adj_list.items():
35 | if len(nbs) < 1: continue
36 | a = np.array(nbs)
37 | a = a[a[:, -1].argsort()]
38 | _, idx = np.unique(a[:,[0]],return_index=True,axis=0)
39 | nbs_unique = a[np.sort(idx)][:,0]
40 | for i, u in enumerate(nbs_unique):
41 | ports_dict[(u,v)] = i
42 | for i, e in enumerate(edge_index.T):
43 | ports[i] = ports_dict[tuple(e.numpy())]
44 | return ports
45 |
46 | def time_deltas(data, adj_edges_list):
47 | time_deltas = torch.zeros(data.edge_index.shape[1], 1)
48 | if data.timestamps is None:
49 | return time_deltas
50 | for v, edges in adj_edges_list.items():
51 | if len(edges) < 1: continue
52 | a = np.array(edges)
53 | a = a[a[:, -1].argsort()]
54 | a_tds = [0] + [a[i+1,-1] - a[i,-1] for i in range(a.shape[0]-1)]
55 | tds = np.hstack((a[:,0].reshape(-1,1), np.array(a_tds).reshape(-1,1)))
56 | for i,td in tds:
57 | time_deltas[i] = td
58 | return time_deltas
59 |
60 | class GraphData(Data):
61 | '''This is the homogeneous graph object we use for GNN training if reverse MP is not enabled'''
62 | def __init__(
63 | self, x: OptTensor = None, edge_index: OptTensor = None, edge_attr: OptTensor = None, y: OptTensor = None, pos: OptTensor = None,
64 | readout: str = 'edge',
65 | num_nodes: int = None,
66 | timestamps: OptTensor = None,
67 | node_timestamps: OptTensor = None,
68 | **kwargs
69 | ):
70 | super().__init__(x, edge_index, edge_attr, y, pos, **kwargs)
71 | self.readout = readout
72 | self.loss_fn = 'ce'
73 | self.num_nodes = int(self.x.shape[0])
74 | self.node_timestamps = node_timestamps
75 | if timestamps is not None:
76 | self.timestamps = timestamps
77 | elif edge_attr is not None:
78 | self.timestamps = edge_attr[:,0].clone()
79 | else:
80 | self.timestamps = None
81 |
82 | def add_ports(self):
83 | '''Adds port numberings to the edge features'''
84 | reverse_ports = True
85 | adj_list_in, adj_list_out = to_adj_nodes_with_times(self)
86 | in_ports = ports(self.edge_index, adj_list_in)
87 | out_ports = [ports(self.edge_index.flipud(), adj_list_out)] if reverse_ports else []
88 | self.edge_attr = torch.cat([self.edge_attr, in_ports] + out_ports, dim=1)
89 | return self
90 |
91 | def add_time_deltas(self):
92 | '''Adds time deltas (i.e. the time between subsequent transactions) to the edge features'''
93 | reverse_tds = True
94 | adj_list_in, adj_list_out = to_adj_edges_with_times(self)
95 | in_tds = time_deltas(self, adj_list_in)
96 | out_tds = [time_deltas(self, adj_list_out)] if reverse_tds else []
97 | self.edge_attr = torch.cat([self.edge_attr, in_tds] + out_tds, dim=1)
98 | return self
99 |
100 | class HeteroGraphData(HeteroData):
101 | '''This is the heterogeneous graph object we use for GNN training if reverse MP is enabled'''
102 | def __init__(
103 | self,
104 | readout: str = 'edge',
105 | **kwargs
106 | ):
107 | super().__init__(**kwargs)
108 | self.readout = readout
109 |
110 | @property
111 | def num_nodes(self):
112 | return self['node'].x.shape[0]
113 |
114 | @property
115 | def timestamps(self):
116 | return self['node', 'to', 'node'].timestamps
117 |
118 | def add_ports(self):
119 | '''Adds port numberings to the edge features'''
120 | adj_list_in, adj_list_out = to_adj_nodes_with_times(self)
121 | in_ports = ports(self['node', 'to', 'node'].edge_index, adj_list_in)
122 | out_ports = ports(self['node', 'rev_to', 'node'].edge_index, adj_list_out)
123 | self['node', 'to', 'node'].edge_attr = torch.cat([self['node', 'to', 'node'].edge_attr, in_ports], dim=1)
124 | self['node', 'rev_to', 'node'].edge_attr = torch.cat([self['node', 'rev_to', 'node'].edge_attr, out_ports], dim=1)
125 | return self
126 |
127 | def add_time_deltas(self):
128 | '''Adds time deltas (i.e. the time between subsequent transactions) to the edge features'''
129 | adj_list_in, adj_list_out = to_adj_edges_with_times(self)
130 | in_tds = time_deltas(self, adj_list_in)
131 | out_tds = time_deltas(self, adj_list_out)
132 | self['node', 'to', 'node'].edge_attr = torch.cat([self['node', 'to', 'node'].edge_attr, in_tds], dim=1)
133 | self['node', 'rev_to', 'node'].edge_attr = torch.cat([self['node', 'rev_to', 'node'].edge_attr, out_tds], dim=1)
134 | return self
135 |
136 | def z_norm(data):
137 | std = data.std(0).unsqueeze(0)
138 | std = torch.where(std == 0, torch.ones_like(std), std) #constant columns: divide by 1 instead of 0 (same dtype/device as std)
139 | return (data - data.mean(0).unsqueeze(0)) / std
140 |
141 | def create_hetero_obj(x, y, edge_index, edge_attr, timestamps, args):
142 | '''Creates a heterogeneous graph object for reverse message passing'''
143 | data = HeteroGraphData()
144 |
145 | data['node'].x = x
146 | data['node', 'to', 'node'].edge_index = edge_index
147 | data['node', 'rev_to', 'node'].edge_index = edge_index.flipud()
148 | data['node', 'to', 'node'].edge_attr = edge_attr
149 | data['node', 'rev_to', 'node'].edge_attr = edge_attr.clone() #copy, so the in-place port swap below only affects the reverse edges
150 | if args.ports:
151 | #swap the in- and outgoing port numberings for the reverse edges
152 | data['node', 'rev_to', 'node'].edge_attr[:, [-1, -2]] = data['node', 'rev_to', 'node'].edge_attr[:, [-2, -1]]
153 | data['node', 'to', 'node'].y = y
154 | data['node', 'to', 'node'].timestamps = timestamps
155 |
156 | return data
--------------------------------------------------------------------------------
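`add_ports()` and `add_time_deltas()` derive two kinds of per-edge features from each node's neighborhood. Incoming ports: for a destination node, its distinct senders are numbered 0, 1, 2, ... in the order of their first transaction, and each edge receives its sender's number. Incoming time deltas: each edge receives the time elapsed since the previous transaction arriving at the same node (0 for the first one). The outgoing variants apply the same rules to the reversed edges. The sketch below restates both rules in plain Python on a list of `(src, dst, t)` tuples; the function names are hypothetical, and ties between equal timestamps are broken by edge order here, whereas the NumPy `argsort` used in `data_util.py` does not guarantee any particular order for ties.

```python
from collections import defaultdict


def incoming_ports(edges):
    """Port number of each edge at its destination node (hypothetical helper).

    edges: list of (src, dst, t). The distinct senders of every destination
    are ranked by the timestamp of their earliest edge to it.
    """
    first_seen = defaultdict(dict)  # dst -> {src: earliest timestamp}
    for src, dst, t in edges:
        if src not in first_seen[dst] or t < first_seen[dst][src]:
            first_seen[dst][src] = t
    port_of = {}
    for dst, senders in first_seen.items():
        # sorted() is stable, so equal timestamps keep first-appearance order
        for rank, src in enumerate(sorted(senders, key=senders.get)):
            port_of[(src, dst)] = rank
    return [port_of[(src, dst)] for src, dst, _ in edges]


def incoming_time_deltas(edges):
    """Time since the previous edge arriving at the same node (hypothetical helper)."""
    arrivals = defaultdict(list)  # dst -> [(t, edge index), ...]
    for idx, (_, dst, t) in enumerate(edges):
        arrivals[dst].append((t, idx))
    deltas = [0] * len(edges)  # the first arrival at a node keeps delta 0
    for items in arrivals.values():
        items.sort()  # by time, ties broken by edge index
        for (t_prev, _), (t_cur, idx) in zip(items, items[1:]):
            deltas[idx] = t_cur - t_prev
    return deltas


def reverse(edges):
    """Flip every edge: outgoing features are the incoming ones of the reversed graph."""
    return [(dst, src, t) for src, dst, t in edges]


if __name__ == "__main__":
    tx = [(1, 0, 5), (2, 0, 3), (1, 0, 7), (3, 0, 9), (0, 1, 1)]
    print(incoming_ports(tx), incoming_ports(reverse(tx)))
    print(incoming_time_deltas(tx), incoming_time_deltas(reverse(tx)))
```

Reversing the edge list plays the role of `edge_index.flipud()` in `add_ports()`, so one function serves both directions.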
/env.yml:
--------------------------------------------------------------------------------
1 | name: multignn
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - pyg
6 | - conda-forge
7 | dependencies:
8 | - python=3.9
9 | - pip
10 | - pytorch::pytorch
11 | - pytorch::torchvision
12 | - pytorch::pytorch-cuda=11.8
13 | - cudatoolkit=11.8
14 | - pyg::pyg
15 | - pyg::pytorch-scatter
16 | - bzip2=1.0.8
17 | - ipykernel
18 | - ipython
19 | - ipywidgets
20 | - lz4
21 | - matplotlib
22 | - munch
23 | - numpy>=1.25.2
24 | - pandas=2.0.3
25 | - scikit-learn=1.3.0
26 | - scipy
27 | - tqdm
28 | - wandb
29 | - zstandard
30 | - zstd
31 | - datatable
32 | - pyg::pytorch-sparse
33 | - tabulate=0.9.0
--------------------------------------------------------------------------------
/format_kaggle_files.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import datatable as dt
3 | from datetime import datetime
4 | from datatable import f,join,sort
5 | import sys
6 | import os
7 |
8 | n = len(sys.argv)
9 |
10 | if n == 1:
11 | print("No input path given. Usage: python format_kaggle_files.py /path/to/kaggle-files/HI-Small_Trans.csv")
12 | sys.exit(1)
13 |
14 | inPath = sys.argv[1]
15 | outPath = os.path.join(os.path.dirname(inPath), "formatted_transactions.csv")
16 |
17 | raw = dt.fread(inPath, columns = dt.str32)
18 |
19 | currency = dict()
20 | paymentFormat = dict()
21 | bankAcc = dict()
22 | account = dict()
23 |
24 | def get_dict_val(name, collection):
25 | if name in collection:
26 | val = collection[name]
27 | else:
28 | val = len(collection)
29 | collection[name] = val
30 | return val
31 |
32 | header = "EdgeID,from_id,to_id,Timestamp,\
33 | Amount Sent,Sent Currency,Amount Received,Received Currency,\
34 | Payment Format,Is Laundering\n"
35 |
36 | firstTs = -1
37 |
38 | with open(outPath, 'w') as writer:
39 | writer.write(header)
40 | for i in range(raw.nrows):
41 | datetime_object = datetime.strptime(raw[i,"Timestamp"], '%Y/%m/%d %H:%M')
42 | ts = datetime_object.timestamp()
43 | day = datetime_object.day
44 | month = datetime_object.month
45 | year = datetime_object.year
46 | hour = datetime_object.hour
47 | minute = datetime_object.minute
48 |
49 | if firstTs == -1:
50 | startTime = datetime(year, month, day)
51 | firstTs = startTime.timestamp() - 10
52 |
53 | ts = ts - firstTs
54 |
55 | cur1 = get_dict_val(raw[i,"Receiving Currency"], currency)
56 | cur2 = get_dict_val(raw[i,"Payment Currency"], currency)
57 |
58 | fmt = get_dict_val(raw[i,"Payment Format"], paymentFormat)
59 |
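60 | fromAccIdStr = raw[i,"From Bank"] + raw[i,2] #column 2 is the sender's "Account" (both account columns share that name, hence positional access)
61 | fromId = get_dict_val(fromAccIdStr, account)
62 |
63 | toAccIdStr = raw[i,"To Bank"] + raw[i,4] #column 4 is the receiver's "Account"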
64 | toId = get_dict_val(toAccIdStr, account)
65 |
66 | amountReceivedOrig = float(raw[i,"Amount Received"])
67 | amountPaidOrig = float(raw[i,"Amount Paid"])
68 |
69 | isl = int(raw[i,"Is Laundering"])
70 |
71 | line = '%d,%d,%d,%d,%f,%d,%f,%d,%d,%d\n' % \
72 | (i,fromId,toId,ts,amountPaidOrig,cur2, amountReceivedOrig,cur1,fmt,isl)
73 |
74 | writer.write(line)
75 |
76 | formatted = dt.fread(outPath)
77 | formatted = formatted[:,:,sort(3)]
78 |
79 | formatted.to_csv(outPath)
--------------------------------------------------------------------------------
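`format_kaggle_files.py` does two things per row. It maps categorical strings (currencies, payment formats, bank plus account IDs) to dense integers in first-seen order. It also shifts timestamps so that the first day's midnight lies 10 seconds before the origin. The sketch below reproduces both steps on toy rows. `encode` and `relative_seconds` are hypothetical names. One deliberate difference: the sketch subtracts naive datetimes directly instead of calling `.timestamp()`, which sidesteps local-time (DST) effects. That is an assumption about the intended behaviour, not what the script does.

```python
from datetime import datetime
from typing import Dict, List


def encode(value: str, mapping: Dict[str, int]) -> int:
    """Dense integer id in first-seen order (same idea as get_dict_val)."""
    if value not in mapping:
        mapping[value] = len(mapping)
    return mapping[value]


def relative_seconds(stamps: List[str], fmt: str = "%Y/%m/%d %H:%M") -> List[int]:
    """Seconds elapsed since (midnight of the first row's day minus 10 s)."""
    parsed = [datetime.strptime(s, fmt) for s in stamps]
    first = parsed[0]
    origin = datetime(first.year, first.month, first.day)
    return [int((p - origin).total_seconds()) + 10 for p in parsed]


currencies: Dict[str, int] = {}
print([encode(c, currencies) for c in ["US Dollar", "Euro", "US Dollar"]])  # [0, 1, 0]
print(relative_seconds(["2022/09/01 00:20", "2022/09/01 01:00"]))  # [1210, 3610]
```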
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | from train_util import AddEgoIds, extract_param, add_arange_ids, get_loaders, evaluate_homo, evaluate_hetero
4 | from training import get_model
5 | from torch_geometric.nn import to_hetero, summary
6 | import wandb
7 | import logging
8 | import os
9 | import sys
10 | import time
11 |
12 | script_start = time.time()
13 |
14 | def infer_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config):
15 | #set device
16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17 |
18 | #define a model config dictionary and wandb logging at the same time
19 | wandb.init(
20 | mode="disabled" if args.testing else "online",
21 | project="your_proj_name",
22 |
23 | config={
24 | "epochs": args.n_epochs,
25 | "batch_size": args.batch_size,
26 | "model": args.model,
27 | "data": args.data,
28 | "num_neighbors": args.num_neighs,
29 | "lr": extract_param("lr", args),
30 | "n_hidden": extract_param("n_hidden", args),
31 | "n_gnn_layers": extract_param("n_gnn_layers", args),
32 | "loss": "ce",
33 | "w_ce1": extract_param("w_ce1", args),
34 | "w_ce2": extract_param("w_ce2", args),
35 | "dropout": extract_param("dropout", args),
36 | "final_dropout": extract_param("final_dropout", args),
37 | "n_heads": extract_param("n_heads", args) if args.model == 'gat' else None
38 | }
39 | )
40 |
41 | config = wandb.config
42 |
43 | #set the transform if ego ids should be used
44 | if args.ego:
45 | transform = AddEgoIds()
46 | else:
47 | transform = None
48 |
49 | #add the unique ids to later find the seed edges
50 | add_arange_ids([tr_data, val_data, te_data])
51 |
52 | tr_loader, val_loader, te_loader = get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args)
53 |
54 | #get the model
55 | sample_batch = next(iter(tr_loader))
56 | model = get_model(sample_batch, config, args)
57 |
58 | if args.reverse_mp:
59 | model = to_hetero(model, te_data.metadata(), aggr='mean')
60 |
61 | if not (args.avg_tps or args.finetune):
62 | command = " ".join(sys.argv)
63 | name = ""
64 | name = '-'.join(name.split('-')[3:])
65 | args.unique_name = name
66 |
67 | logging.info("=> loading model checkpoint")
68 | checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_{args.unique_name}.tar', map_location=device)
69 | start_epoch = checkpoint['epoch']
70 | model.load_state_dict(checkpoint['model_state_dict'])
71 | model.to(device)
72 |
73 | logging.info("=> loaded checkpoint (epoch {})".format(start_epoch))
74 |
75 | if not args.reverse_mp:
76 | te_f1 = evaluate_homo(te_loader, te_inds, model, te_data, device, args)
77 | else:
78 | te_f1 = evaluate_hetero(te_loader, te_inds, model, te_data, device, args)
79 | logging.info(f'Test F1: {te_f1:.4f}')
80 | wandb.finish()
--------------------------------------------------------------------------------
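Checkpoints are matched by file name. `save_model` in `train_util.py` writes `checkpoint_{unique_name}.tar`, or `checkpoint_{unique_name}_finetuned.tar` when fine-tuning, under `paths.model_to_save`. Inference and `load_model` read `checkpoint_{unique_name}.tar` from `paths.model_to_load`. The sketch below mirrors that naming as a pure function, so you can check which file a run will look for. `checkpoint_path` and the `"models"` directory are hypothetical.

```python
def checkpoint_path(directory: str, unique_name: str, finetuned: bool = False) -> str:
    """Checkpoint file name as built by save_model / load_model (hypothetical helper)."""
    suffix = "_finetuned" if finetuned else ""
    return f"{directory}/checkpoint_{unique_name}{suffix}.tar"


print(checkpoint_path("models", "gin_run1"))                  # models/checkpoint_gin_run1.tar
print(checkpoint_path("models", "gin_run1", finetuned=True))  # models/checkpoint_gin_run1_finetuned.tar
```

Two points follow from the code as written. The load path never includes the `_finetuned` suffix. Also, when neither `avg_tps` nor `finetune` is set, `inference.py` derives `unique_name` from an empty string, so it resolves to `''` and the file looked up is `checkpoint_.tar`.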
/main.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | from util import create_parser, set_seed, logger_setup
4 | from data_loading import get_data
5 | from training import train_gnn
6 | from inference import infer_gnn
7 | import json
8 |
9 | def main():
10 | parser = create_parser()
11 | args = parser.parse_args()
12 |
13 | with open('data_config.json', 'r') as config_file:
14 | data_config = json.load(config_file)
15 |
16 | # Setup logging
17 | logger_setup()
18 |
19 | #set seed
20 | set_seed(args.seed)
21 |
22 | #get data
23 | logging.info("Retrieving data")
24 | t1 = time.perf_counter()
25 |
26 | tr_data, val_data, te_data, tr_inds, val_inds, te_inds = get_data(args, data_config)
27 |
28 | t2 = time.perf_counter()
29 | logging.info(f"Retrieved data in {t2-t1:.2f}s")
30 |
31 | if not args.inference:
32 | #Training
33 | logging.info(f"Running Training")
34 | train_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config)
35 | else:
36 | #Inference
37 | logging.info(f"Running Inference")
38 | infer_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config)
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
/model_settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "gin": {
3 | "params": {
4 | "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce",
5 | "w_ce1": 1.0000182882773443, "w_ce2": 6.275014431494497, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304
6 | },
7 | "bayes_opt_params": {
8 | "lr": [0.002, 0.007], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1],
9 | "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], "final_dropout": [0, 0.2]
10 | },
11 | "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
12 | },
13 | "pna": {
14 | "params": {
15 | "lr": 0.0006116418195373612, "n_hidden": 20, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0003967674742307,
16 | "w_ce2": 7.077633468006714, "norm_method": "z_normalize", "dropout": 0.08340440094051481, "final_dropout": 0.28812979737686323
17 | },
18 | "bayes_opt_params": {
19 | "lr": [0.0001, 0.001], "n_hidden": [16, 64], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.00, 2.01], "loss": [0.0, 0.1],
20 | "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.1], "dropout": [0.0, 0.2], "final_dropout": [0.0, 0.4]
21 | },
22 | "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
23 | },
24 | "gat": {
25 | "params": {
26 | "lr": 0.006, "n_hidden": 64, "n_heads": 4, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1, "w_ce2": 6,
27 | "norm_method": "z_normalize", "dropout": 0.009, "final_dropout": 0.1
28 | },
29 | "bayes_opt_params": {
30 | "lr": [0.01, 0.04], "n_hidden": [4, 24], "n_heads": [1.5, 4.5], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [3, 7],
31 | "loss": [0, 0.1], "w_ce1": [1, 1.001], "w_ce2": [1, 10], "norm_method": [0, 0.1], "dropout": [0, 0.5], "final_dropout": [0, 0.8]
32 | },
33 | "header": "run,tb,lr,n_hidden,n_heads,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
34 | },
35 | "mlp": {
36 | "params": {
37 | "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0000182882773443,
38 | "w_ce2": 9.23, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304
39 | },
40 | "bayes_opt_params": {
41 | "lr": [0.006, 0.0064], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1],
42 | "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], "final_dropout": [0, 0.2]
43 | },
44 | "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
45 | },
46 | "rgcn": {
47 | "params": {
48 | "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0000182882773443,
49 | "w_ce2": 9.23, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304
50 | },
51 | "bayes_opt_params": {
52 | "lr": [0.006, 0.0064], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1],
53 | "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], "final_dropout": [0, 0.2]
54 | },
55 | "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
56 | }
57 | }
--------------------------------------------------------------------------------
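`extract_param` in `train_util.py` looks up `args.model`, then `"params"`, then the parameter name in this file. It returns `None` when either the model or the parameter is absent. Some tuned values are stored as floats (for example `"n_hidden": 66.003...`), presumably because they come from the continuous `bayes_opt_params` ranges. The sketch below runs the same lookup on an inline, trimmed copy of the settings; the values are taken from the file above. `get_param` is a hypothetical name. Rounding `n_hidden` to an integer is an assumption about how a caller would use it as a layer width.

```python
import json
from typing import Any, Optional

SETTINGS = json.loads("""
{"gin": {"params": {"lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "loss": "ce"}},
 "gat": {"params": {"lr": 0.006, "n_hidden": 64, "n_heads": 4}}}
""")


def get_param(settings: dict, model: str, name: str) -> Optional[Any]:
    """Same nested lookup as extract_param: model -> "params" -> name, else None."""
    return settings.get(model, {}).get("params", {}).get(name, None)


print(get_param(SETTINGS, "gat", "n_heads"))          # 4
print(get_param(SETTINGS, "pna", "lr"))               # None (model not in this trimmed copy)
print(round(get_param(SETTINGS, "gin", "n_hidden")))  # 66
```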
/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch_geometric.nn import GINEConv, BatchNorm, Linear, GATConv, PNAConv, RGCNConv
3 | import torch.nn.functional as F
4 | import torch
5 | import logging
6 |
7 | class GINe(torch.nn.Module):
8 | def __init__(self, num_features, num_gnn_layers, n_classes=2,
9 | n_hidden=100, edge_updates=False, residual=True,
10 | edge_dim=None, dropout=0.0, final_dropout=0.5):
11 | super().__init__()
12 | self.n_hidden = n_hidden
13 | self.num_gnn_layers = num_gnn_layers
14 | self.edge_updates = edge_updates
15 | self.final_dropout = final_dropout
16 |
17 | self.node_emb = nn.Linear(num_features, n_hidden)
18 | self.edge_emb = nn.Linear(edge_dim, n_hidden)
19 |
20 | self.convs = nn.ModuleList()
21 | self.emlps = nn.ModuleList()
22 | self.batch_norms = nn.ModuleList()
23 | for _ in range(self.num_gnn_layers):
24 | conv = GINEConv(nn.Sequential(
25 | nn.Linear(self.n_hidden, self.n_hidden),
26 | nn.ReLU(),
27 | nn.Linear(self.n_hidden, self.n_hidden)
28 | ), edge_dim=self.n_hidden)
29 | if self.edge_updates: self.emlps.append(nn.Sequential(
30 | nn.Linear(3 * self.n_hidden, self.n_hidden),
31 | nn.ReLU(),
32 | nn.Linear(self.n_hidden, self.n_hidden),
33 | ))
34 | self.convs.append(conv)
35 | self.batch_norms.append(BatchNorm(n_hidden))
36 |
37 | self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
38 | Linear(25, n_classes))
39 |
40 | def forward(self, x, edge_index, edge_attr):
41 | src, dst = edge_index
42 |
43 | x = self.node_emb(x)
44 | edge_attr = self.edge_emb(edge_attr)
45 |
46 | for i in range(self.num_gnn_layers):
47 | x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
48 | if self.edge_updates:
49 | edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2
50 |
51 | x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
52 | x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
53 | out = x
54 |
55 | return self.mlp(out)
56 |
57 | class GATe(torch.nn.Module):
58 | def __init__(self, num_features, num_gnn_layers, n_classes=2, n_hidden=100, n_heads=4, edge_updates=False, edge_dim=None, dropout=0.0, final_dropout=0.5):
59 | super().__init__()
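60 | # GAT specific code: round n_hidden down to a multiple of n_heads, so the concatenated head outputs (concat=True) have exactly n_hidden channels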
61 | tmp_out = n_hidden // n_heads
62 | n_hidden = tmp_out * n_heads
63 |
64 | self.n_hidden = n_hidden
65 | self.n_heads = n_heads
66 | self.num_gnn_layers = num_gnn_layers
67 | self.edge_updates = edge_updates
68 | self.dropout = dropout
69 | self.final_dropout = final_dropout
70 |
71 | self.node_emb = nn.Linear(num_features, n_hidden)
72 | self.edge_emb = nn.Linear(edge_dim, n_hidden)
73 |
74 | self.convs = nn.ModuleList()
75 | self.emlps = nn.ModuleList()
76 | self.batch_norms = nn.ModuleList()
77 |
78 | for _ in range(self.num_gnn_layers):
79 | conv = GATConv(self.n_hidden, tmp_out, self.n_heads, concat = True, dropout = self.dropout, add_self_loops = True, edge_dim=self.n_hidden)
80 | if self.edge_updates: self.emlps.append(nn.Sequential(nn.Linear(3 * self.n_hidden, self.n_hidden),nn.ReLU(),nn.Linear(self.n_hidden, self.n_hidden),))
81 | self.convs.append(conv)
82 | self.batch_norms.append(BatchNorm(n_hidden))
83 |
84 | self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(25, n_classes))
85 |
86 | def forward(self, x, edge_index, edge_attr):
87 | src, dst = edge_index
88 |
89 | x = self.node_emb(x)
90 | edge_attr = self.edge_emb(edge_attr)
91 |
92 | for i in range(self.num_gnn_layers):
93 | x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
94 | if self.edge_updates:
95 | edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2
96 |
97 | logging.debug(f"x.shape = {x.shape}, x[edge_index.T].shape = {x[edge_index.T].shape}")
98 | x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
99 | logging.debug(f"x.shape = {x.shape}")
100 | x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
101 | logging.debug(f"x.shape = {x.shape}")
102 | out = x
103 |
104 | return self.mlp(out)
105 |
106 | class PNA(torch.nn.Module):
107 | def __init__(self, num_features, num_gnn_layers, n_classes=2,
108 | n_hidden=100, edge_updates=True,
109 | edge_dim=None, dropout=0.0, final_dropout=0.5, deg=None):
110 | super().__init__()
111 | n_hidden = int((n_hidden // 5) * 5) #PNAConv requires out_channels to be divisible by towers (5)
112 | self.n_hidden = n_hidden
113 | self.num_gnn_layers = num_gnn_layers
114 | self.edge_updates = edge_updates
115 | self.final_dropout = final_dropout
116 |
117 | aggregators = ['mean', 'min', 'max', 'std']
118 | scalers = ['identity', 'amplification', 'attenuation']
119 |
120 | self.node_emb = nn.Linear(num_features, n_hidden)
121 | self.edge_emb = nn.Linear(edge_dim, n_hidden)
122 |
123 | self.convs = nn.ModuleList()
124 | self.emlps = nn.ModuleList()
125 | self.batch_norms = nn.ModuleList()
126 | for _ in range(self.num_gnn_layers):
127 | conv = PNAConv(in_channels=n_hidden, out_channels=n_hidden,
128 | aggregators=aggregators, scalers=scalers, deg=deg,
129 | edge_dim=n_hidden, towers=5, pre_layers=1, post_layers=1,
130 | divide_input=False)
131 | if self.edge_updates: self.emlps.append(nn.Sequential(
132 | nn.Linear(3 * self.n_hidden, self.n_hidden),
133 | nn.ReLU(),
134 | nn.Linear(self.n_hidden, self.n_hidden),
135 | ))
136 | self.convs.append(conv)
137 | self.batch_norms.append(BatchNorm(n_hidden))
138 |
139 | self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
140 | Linear(25, n_classes))
141 |
142 | def forward(self, x, edge_index, edge_attr):
143 | src, dst = edge_index
144 |
145 | x = self.node_emb(x)
146 | edge_attr = self.edge_emb(edge_attr)
147 |
148 | for i in range(self.num_gnn_layers):
149 | x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
150 | if self.edge_updates:
151 | edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2
152 |
153 | logging.debug(f"x.shape = {x.shape}, x[edge_index.T].shape = {x[edge_index.T].shape}")
154 | x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
155 | logging.debug(f"x.shape = {x.shape}")
156 | x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
157 | logging.debug(f"x.shape = {x.shape}")
158 | out = x
159 | return self.mlp(out)
160 |
161 | class RGCN(nn.Module):
162 | def __init__(self, num_features, edge_dim, num_relations, num_gnn_layers, n_classes=2,
163 | n_hidden=100, edge_update=False,
164 | residual=True,
165 | dropout=0.0, final_dropout=0.5, n_bases=-1):
166 | super(RGCN, self).__init__()
167 |
168 | self.num_features = num_features
169 | self.num_gnn_layers = num_gnn_layers
170 | self.n_hidden = n_hidden
171 | self.residual = residual
172 | self.dropout = dropout
173 | self.final_dropout = final_dropout
174 | self.n_classes = n_classes
175 | self.edge_update = edge_update
176 | self.num_relations = num_relations
177 | self.n_bases = n_bases
178 |
179 | self.node_emb = nn.Linear(num_features, n_hidden)
180 | self.edge_emb = nn.Linear(edge_dim, n_hidden)
181 |
182 | self.convs = nn.ModuleList()
183 | self.bns = nn.ModuleList()
184 | self.mlp = nn.ModuleList()
185 |
186 | if self.edge_update:
187 | self.emlps = nn.ModuleList()
188 | self.emlps.append(nn.Sequential(
189 | nn.Linear(3 * self.n_hidden, self.n_hidden),
190 | nn.ReLU(),
191 | nn.Linear(self.n_hidden, self.n_hidden),
192 | ))
193 |
194 | for _ in range(self.num_gnn_layers):
195 | conv = RGCNConv(self.n_hidden, self.n_hidden, num_relations, num_bases=self.n_bases)
196 | self.convs.append(conv)
197 | self.bns.append(nn.BatchNorm1d(self.n_hidden))
198 |
199 | if self.edge_update:
200 | self.emlps.append(nn.Sequential(
201 | nn.Linear(3 * self.n_hidden, self.n_hidden),
202 | nn.ReLU(),
203 | nn.Linear(self.n_hidden, self.n_hidden),
204 | ))
205 |
206 | self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout), Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
207 | Linear(25, n_classes))
208 |
209 | def reset_parameters(self):
210 | for m in self.modules():
211 | if isinstance(m, nn.Linear):
212 | m.reset_parameters()
213 | elif isinstance(m, RGCNConv):
214 | m.reset_parameters()
215 | elif isinstance(m, nn.BatchNorm1d):
216 | m.reset_parameters()
217 |
218 | def forward(self, x, edge_index, edge_attr):
219 | edge_type = edge_attr[:, -1].long()
220 | #edge_attr = edge_attr[:, :-1]
221 | src, dst = edge_index
222 |
223 | x = self.node_emb(x)
224 | edge_attr = self.edge_emb(edge_attr)
225 |
226 | for i in range(self.num_gnn_layers):
227 | x = (x + F.relu(self.bns[i](self.convs[i](x, edge_index, edge_type)))) / 2
228 | if self.edge_update:
229 | edge_attr = (edge_attr + F.relu(self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)))) / 2
230 |
231 | x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
232 | x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
233 | x = self.mlp(x)
234 | out = x
235 |
236 | return x
--------------------------------------------------------------------------------
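All four models finish the same way, as edge classifiers. `x[edge_index.T].reshape(-1, 2 * n_hidden)` places the source and destination node embeddings of each edge side by side. A ReLU is applied to that node part, then the (possibly updated) edge embedding is appended. The resulting `3 * n_hidden` vector is what the final MLP sees, which is why its first layer is `Linear(n_hidden*3, 50)`. The sketch below is a plain-list version of that readout. `edge_readout` is a hypothetical name, and the toy values are made up.

```python
from typing import List, Tuple

Vector = List[float]


def edge_readout(x: List[Vector], edge_index: Tuple[List[int], List[int]],
                 edge_attr: List[Vector]) -> List[Vector]:
    """Per-edge MLP input: [relu(x[src]) | relu(x[dst]) | edge_attr]."""
    src, dst = edge_index
    out = []
    for s, d, e in zip(src, dst, edge_attr):
        nodes = [max(v, 0.0) for v in x[s] + x[d]]  # ReLU on the node part only
        out.append(nodes + list(e))                 # total width: 3 * n_hidden
    return out


x = [[1.0, -2.0], [0.5, 3.0], [-1.0, 4.0]]  # 3 nodes, n_hidden = 2
edge_index = ([0, 2], [1, 0])               # edges 0->1 and 2->0
edge_attr = [[0.1, -0.2], [0.3, 0.4]]
print(edge_readout(x, edge_index, edge_attr))
# [[1.0, 0.0, 0.5, 3.0, 0.1, -0.2], [0.0, 4.0, 1.0, 0.0, 0.3, 0.4]]
```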
/train_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | from torch_geometric.transforms import BaseTransform
4 | from typing import Union
5 | from torch_geometric.data import Data, HeteroData
6 | from torch_geometric.loader import LinkNeighborLoader
7 | from sklearn.metrics import f1_score
8 | import json
9 |
10 | class AddEgoIds(BaseTransform):
11 | r"""Append a binary ego-ID column to the node features: 1 for the endpoints of the batch's seed edges, 0 for all other nodes.
12 | """
13 | def __init__(self):
14 | pass
15 |
16 | def __call__(self, data: Union[Data, HeteroData]):
17 | x = data.x if not isinstance(data, HeteroData) else data['node'].x
18 | device = x.device
19 | ids = torch.zeros((x.shape[0], 1), device=device)
20 | if not isinstance(data, HeteroData):
21 | nodes = torch.unique(data.edge_label_index.view(-1)).to(device)
22 | else:
23 | nodes = torch.unique(data['node', 'to', 'node'].edge_label_index.view(-1)).to(device)
24 | ids[nodes] = 1
25 | if not isinstance(data, HeteroData):
26 | data.x = torch.cat([x, ids], dim=1)
27 | else:
28 | data['node'].x = torch.cat([x, ids], dim=1)
29 |
30 | return data
31 |
32 | def extract_param(parameter_name: str, args) -> float:
33 | """
34 | Extract the value of the specified parameter for the given model.
35 |
36 | Args:
37 | - parameter_name (str): Name of the parameter (e.g., "lr").
38 | - args (argparser): Arguments given to this specific run.
39 |
40 | Returns:
41 | - Value of the specified parameter (a float, int or string such as "ce"), or None if it is not defined for args.model.
42 | """
43 | file_path = './model_settings.json'
44 | with open(file_path, "r") as file:
45 | data = json.load(file)
46 |
47 | return data.get(args.model, {}).get("params", {}).get(parameter_name, None)
48 |
49 | def add_arange_ids(data_list):
50 | '''
51 | Prepend each edge's row index as a unique ID column to the edge features, so that the seed edges can be identified in the sampled batches during training, validation and testing.
52 |
53 | Args:
54 | - data_list (list): List containing tr_data, val_data and te_data.
55 | '''
56 | for data in data_list:
57 | if isinstance(data, HeteroData):
58 | data['node', 'to', 'node'].edge_attr = torch.cat([torch.arange(data['node', 'to', 'node'].edge_attr.shape[0]).view(-1, 1), data['node', 'to', 'node'].edge_attr], dim=1)
59 | offset = data['node', 'to', 'node'].edge_attr.shape[0]
60 | data['node', 'rev_to', 'node'].edge_attr = torch.cat([torch.arange(offset, data['node', 'rev_to', 'node'].edge_attr.shape[0] + offset).view(-1, 1), data['node', 'rev_to', 'node'].edge_attr], dim=1)
61 | else:
62 | data.edge_attr = torch.cat([torch.arange(data.edge_attr.shape[0]).view(-1, 1), data.edge_attr], dim=1)
63 |
64 | def get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args):
65 | if isinstance(tr_data, HeteroData):
66 | tr_edge_label_index = tr_data['node', 'to', 'node'].edge_index
67 | tr_edge_label = tr_data['node', 'to', 'node'].y
68 |
69 |
70 | tr_loader = LinkNeighborLoader(tr_data, num_neighbors=args.num_neighs,
71 | edge_label_index=(('node', 'to', 'node'), tr_edge_label_index),
72 | edge_label=tr_edge_label, batch_size=args.batch_size, shuffle=True, transform=transform)
73 |
74 | val_edge_label_index = val_data['node', 'to', 'node'].edge_index[:,val_inds]
75 | val_edge_label = val_data['node', 'to', 'node'].y[val_inds]
76 |
77 |
78 | val_loader = LinkNeighborLoader(val_data, num_neighbors=args.num_neighs,
79 | edge_label_index=(('node', 'to', 'node'), val_edge_label_index),
80 | edge_label=val_edge_label, batch_size=args.batch_size, shuffle=False, transform=transform)
81 |
82 | te_edge_label_index = te_data['node', 'to', 'node'].edge_index[:,te_inds]
83 | te_edge_label = te_data['node', 'to', 'node'].y[te_inds]
84 |
85 |
86 | te_loader = LinkNeighborLoader(te_data, num_neighbors=args.num_neighs,
87 | edge_label_index=(('node', 'to', 'node'), te_edge_label_index),
88 | edge_label=te_edge_label, batch_size=args.batch_size, shuffle=False, transform=transform)
89 | else:
90 | tr_loader = LinkNeighborLoader(tr_data, num_neighbors=args.num_neighs, batch_size=args.batch_size, shuffle=True, transform=transform)
91 | val_loader = LinkNeighborLoader(val_data,num_neighbors=args.num_neighs, edge_label_index=val_data.edge_index[:, val_inds],
92 | edge_label=val_data.y[val_inds], batch_size=args.batch_size, shuffle=False, transform=transform)
93 | te_loader = LinkNeighborLoader(te_data,num_neighbors=args.num_neighs, edge_label_index=te_data.edge_index[:, te_inds],
94 | edge_label=te_data.y[te_inds], batch_size=args.batch_size, shuffle=False, transform=transform)
95 |
96 | return tr_loader, val_loader, te_loader
97 |
98 | @torch.no_grad()
99 | def evaluate_homo(loader, inds, model, data, device, args):
100 | '''Evaluates the model performance on homogeneous graph data.'''
101 | preds = []
102 | ground_truths = []
103 | for batch in tqdm.tqdm(loader, disable=not args.tqdm):
104 | #select the seed edges from which the batch was created
105 | inds = inds.detach().cpu()
106 | batch_edge_inds = inds[batch.input_id.detach().cpu()]
107 | batch_edge_ids = loader.data.edge_attr.detach().cpu()[batch_edge_inds, 0]
108 | mask = torch.isin(batch.edge_attr[:, 0].detach().cpu(), batch_edge_ids)
109 |
110 | #add the seed edges that have not been sampled to the batch
111 | missing = ~torch.isin(batch_edge_ids, batch.edge_attr[:, 0].detach().cpu())
112 |
113 | if missing.sum() != 0 and (args.data == 'Small_J' or args.data == 'Small_Q'):
114 | missing_ids = batch_edge_ids[missing].int()
115 | n_ids = batch.n_id
116 | add_edge_index = data.edge_index[:, missing_ids].detach().clone()
117 | node_mapping = {value.item(): idx for idx, value in enumerate(n_ids)}
118 | add_edge_index = torch.tensor([[node_mapping[val.item()] for val in row] for row in add_edge_index])
119 | add_edge_attr = data.edge_attr[missing_ids, :].detach().clone()
120 | add_y = data.y[missing_ids].detach().clone()
121 |
122 | batch.edge_index = torch.cat((batch.edge_index, add_edge_index), 1)
123 | batch.edge_attr = torch.cat((batch.edge_attr, add_edge_attr), 0)
124 | batch.y = torch.cat((batch.y, add_y), 0)
125 |
126 | mask = torch.cat((mask, torch.ones(add_y.shape[0], dtype=torch.bool)))
127 |
128 | #remove the unique edge id from the edge features, as it's no longer needed
129 | batch.edge_attr = batch.edge_attr[:, 1:]
130 |
131 | with torch.no_grad():
132 | batch.to(device)
133 | out = model(batch.x, batch.edge_index, batch.edge_attr)
134 | out = out[mask]
135 | pred = out.argmax(dim=-1)
136 | preds.append(pred)
137 | ground_truths.append(batch.y[mask])
138 | pred = torch.cat(preds, dim=0).cpu().numpy()
139 | ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
140 | f1 = f1_score(ground_truth, pred)
141 |
142 | return f1
143 |
144 | @torch.no_grad()
145 | def evaluate_hetero(loader, inds, model, data, device, args):
146 | '''Evaluates the model performance on heterogeneous graph data.'''
147 | preds = []
148 | ground_truths = []
149 | for batch in tqdm.tqdm(loader, disable=not args.tqdm):
150 | #select the seed edges from which the batch was created
151 | inds = inds.detach().cpu()
152 | batch_edge_inds = inds[batch['node', 'to', 'node'].input_id.detach().cpu()]
153 | batch_edge_ids = loader.data['node', 'to', 'node'].edge_attr.detach().cpu()[batch_edge_inds, 0]
154 | mask = torch.isin(batch['node', 'to', 'node'].edge_attr[:, 0].detach().cpu(), batch_edge_ids)
155 |
156 | #add the seed edges that have not been sampled to the batch
157 | missing = ~torch.isin(batch_edge_ids, batch['node', 'to', 'node'].edge_attr[:, 0].detach().cpu())
158 |
159 | if missing.sum() != 0 and (args.data == 'Small_J' or args.data == 'Small_Q'):
160 | missing_ids = batch_edge_ids[missing].int()
161 | n_ids = batch['node'].n_id
162 | add_edge_index = data['node', 'to', 'node'].edge_index[:, missing_ids].detach().clone()
163 | node_mapping = {value.item(): idx for idx, value in enumerate(n_ids)}
164 | add_edge_index = torch.tensor([[node_mapping[val.item()] for val in row] for row in add_edge_index])
165 | add_edge_attr = data['node', 'to', 'node'].edge_attr[missing_ids, :].detach().clone()
166 | add_y = data['node', 'to', 'node'].y[missing_ids].detach().clone()
167 |
168 | batch['node', 'to', 'node'].edge_index = torch.cat((batch['node', 'to', 'node'].edge_index, add_edge_index), 1)
169 | batch['node', 'to', 'node'].edge_attr = torch.cat((batch['node', 'to', 'node'].edge_attr, add_edge_attr), 0)
170 | batch['node', 'to', 'node'].y = torch.cat((batch['node', 'to', 'node'].y, add_y), 0)
171 |
172 | mask = torch.cat((mask, torch.ones(add_y.shape[0], dtype=torch.bool)))
173 |
174 | #remove the unique edge id from the edge features, as it's no longer needed
175 | batch['node', 'to', 'node'].edge_attr = batch['node', 'to', 'node'].edge_attr[:, 1:]
176 | batch['node', 'rev_to', 'node'].edge_attr = batch['node', 'rev_to', 'node'].edge_attr[:, 1:]
177 |
178 | with torch.no_grad():
179 | batch.to(device)
180 | out = model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict)
181 | out = out[('node', 'to', 'node')]
182 | out = out[mask]
183 | pred = out.argmax(dim=-1)
184 | preds.append(pred)
185 | ground_truths.append(batch['node', 'to', 'node'].y[mask])
186 | pred = torch.cat(preds, dim=0).cpu().numpy()
187 | ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
188 | f1 = f1_score(ground_truth, pred)
189 |
190 | return f1
191 |
192 | def save_model(model, optimizer, epoch, args, data_config):
193 | # Save the model in a dictionary
194 | torch.save({
195 | 'epoch': epoch + 1,
196 | 'model_state_dict': model.state_dict(),
197 | 'optimizer_state_dict': optimizer.state_dict()
198 | }, f'{data_config["paths"]["model_to_save"]}/checkpoint_{args.unique_name}{"" if not args.finetune else "_finetuned"}.tar')
199 |
200 | def load_model(model, device, args, config, data_config):
201 | checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_{args.unique_name}.tar', map_location=device)
202 | model.load_state_dict(checkpoint['model_state_dict'])
203 | model.to(device)
204 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
205 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
206 |
207 | return model, optimizer
--------------------------------------------------------------------------------
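The batch bookkeeping in `train_util.py` works in four steps:

* `add_arange_ids` prepends a unique ID column to the edge features.
* For each sampled batch, the loops map `input_id` through `inds` to the seed edges' IDs.
* `mask` marks the batch edges whose ID is a seed ID. During evaluation, `missing` lists seed IDs the sampler did not include.
* `AddEgoIds` flags the endpoints of the seed edges.

The sketch below replaces the tensors with plain lists. `seed_mask` and `ego_flags` are hypothetical names, and the IDs are toy values.

```python
from typing import List, Set, Tuple


def seed_mask(batch_edge_ids: List[int], seed_ids: List[int]) -> Tuple[List[bool], List[int]]:
    """mask[i] is True if batch edge i is a seed edge; also return seed ids absent from the batch."""
    seeds: Set[int] = set(seed_ids)
    present: Set[int] = set(batch_edge_ids)
    mask = [eid in seeds for eid in batch_edge_ids]
    missing = [sid for sid in seed_ids if sid not in present]
    return mask, missing


def ego_flags(num_nodes: int, seed_edges: List[Tuple[int, int]]) -> List[int]:
    """1 for every node that is an endpoint of a seed edge, else 0 (like AddEgoIds)."""
    flags = [0] * num_nodes
    for s, d in seed_edges:
        flags[s] = flags[d] = 1
    return flags


mask, missing = seed_mask(batch_edge_ids=[10, 42, 7, 99], seed_ids=[42, 99, 5])
print(mask)     # [False, True, False, True]
print(missing)  # [5]
print(ego_flags(5, [(0, 3), (3, 4)]))  # [1, 0, 0, 1, 1]
```

In the repository, the missing seed edges are appended to the batch, and `True` entries are appended to the mask for them, before the model is run.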
/training.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | from sklearn.metrics import f1_score
4 | from train_util import AddEgoIds, extract_param, add_arange_ids, get_loaders, evaluate_homo, evaluate_hetero, save_model, load_model
5 | from models import GINe, PNA, GATe, RGCN
6 | from torch_geometric.data import Data, HeteroData
7 | from torch_geometric.nn import to_hetero, summary
8 | from torch_geometric.utils import degree
9 | import wandb
10 | import logging
11 |
12 | def train_homo(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config):
13 | #training
14 | best_val_f1 = 0
15 | for epoch in range(config.epochs):
16 | total_loss = total_examples = 0
17 | preds = []
18 | ground_truths = []
19 | for batch in tqdm.tqdm(tr_loader, disable=not args.tqdm):
20 | optimizer.zero_grad()
21 | #select the seed edges from which the batch was created
22 | inds = tr_inds.detach().cpu()
23 | batch_edge_inds = inds[batch.input_id.detach().cpu()]
24 | batch_edge_ids = tr_loader.data.edge_attr.detach().cpu()[batch_edge_inds, 0]
25 | mask = torch.isin(batch.edge_attr[:, 0].detach().cpu(), batch_edge_ids)
26 |
27 | #remove the unique edge id from the edge features, as it's no longer needed
28 | batch.edge_attr = batch.edge_attr[:, 1:]
29 |
30 | batch.to(device)
31 | out = model(batch.x, batch.edge_index, batch.edge_attr)
32 | pred = out[mask]
33 | ground_truth = batch.y[mask]
34 | preds.append(pred.argmax(dim=-1))
35 | ground_truths.append(ground_truth)
36 | loss = loss_fn(pred, ground_truth)
37 |
38 | loss.backward()
39 | optimizer.step()
40 |
41 | total_loss += float(loss) * pred.numel()
42 | total_examples += pred.numel()
43 |
44 | pred = torch.cat(preds, dim=0).detach().cpu().numpy()
45 | ground_truth = torch.cat(ground_truths, dim=0).detach().cpu().numpy()
46 | f1 = f1_score(ground_truth, pred)
47 | wandb.log({"f1/train": f1}, step=epoch)
48 | logging.info(f'Train F1: {f1:.4f}')
49 |
50 | #evaluate
51 | val_f1 = evaluate_homo(val_loader, val_inds, model, val_data, device, args)
52 | te_f1 = evaluate_homo(te_loader, te_inds, model, te_data, device, args)
53 |
54 | wandb.log({"f1/validation": val_f1}, step=epoch)
55 | wandb.log({"f1/test": te_f1}, step=epoch)
56 | logging.info(f'Validation F1: {val_f1:.4f}')
57 | logging.info(f'Test F1: {te_f1:.4f}')
58 |
59 | #epoch 0 always sets the initial best validation F1 (and saves the model);
60 | #later epochs only replace it if they improve on it
61 | if epoch == 0 or val_f1 > best_val_f1:
62 | best_val_f1 = val_f1
63 | wandb.log({"best_test_f1": te_f1}, step=epoch)
64 | if args.save_model:
65 | save_model(model, optimizer, epoch, args, data_config)
66 |
67 | return model
68 |
69 | def train_hetero(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config):
70 | #training
71 | best_val_f1 = 0
72 | for epoch in range(config.epochs):
73 | total_loss = total_examples = 0
74 | preds = []
75 | ground_truths = []
76 | for batch in tqdm.tqdm(tr_loader, disable=not args.tqdm):
77 | optimizer.zero_grad()
78 | #select the seed edges from which the batch was created
79 | inds = tr_inds.detach().cpu()
80 | batch_edge_inds = inds[batch['node', 'to', 'node'].input_id.detach().cpu()]
81 | batch_edge_ids = tr_loader.data['node', 'to', 'node'].edge_attr.detach().cpu()[batch_edge_inds, 0]
82 | mask = torch.isin(batch['node', 'to', 'node'].edge_attr[:, 0].detach().cpu(), batch_edge_ids)
83 |
84 | #remove the unique edge id from the edge features, as it's no longer needed
85 | batch['node', 'to', 'node'].edge_attr = batch['node', 'to', 'node'].edge_attr[:, 1:]
86 | batch['node', 'rev_to', 'node'].edge_attr = batch['node', 'rev_to', 'node'].edge_attr[:, 1:]
87 |
88 | batch.to(device)
89 | out = model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict)
90 | out = out[('node', 'to', 'node')]
91 | pred = out[mask]
92 | ground_truth = batch['node', 'to', 'node'].y[mask]
93 | preds.append(pred.argmax(dim=-1))
94 | ground_truths.append(ground_truth)
95 | loss = loss_fn(pred, ground_truth)
96 |
97 | loss.backward()
98 | optimizer.step()
99 |
100 | total_loss += float(loss) * pred.numel()
101 | total_examples += pred.numel()
102 |
103 | pred = torch.cat(preds, dim=0).detach().cpu().numpy()
104 | ground_truth = torch.cat(ground_truths, dim=0).detach().cpu().numpy()
105 | f1 = f1_score(ground_truth, pred)
106 | wandb.log({"f1/train": f1}, step=epoch)
107 | logging.info(f'Train F1: {f1:.4f}')
108 |
109 | #evaluate
110 | val_f1 = evaluate_hetero(val_loader, val_inds, model, val_data, device, args)
111 | te_f1 = evaluate_hetero(te_loader, te_inds, model, te_data, device, args)
112 |
113 | wandb.log({"f1/validation": val_f1}, step=epoch)
114 | wandb.log({"f1/test": te_f1}, step=epoch)
115 | logging.info(f'Validation F1: {val_f1:.4f}')
116 | logging.info(f'Test F1: {te_f1:.4f}')
117 |
118 | #epoch 0 always sets the initial best validation F1 (and saves the model);
119 | #later epochs only replace it if they improve on it
120 | if epoch == 0 or val_f1 > best_val_f1:
121 | best_val_f1 = val_f1
122 | wandb.log({"best_test_f1": te_f1}, step=epoch)
123 | if args.save_model:
124 | save_model(model, optimizer, epoch, args, data_config)
125 |
126 | return model
127 |
128 | def get_model(sample_batch, config, args):
129 | n_feats = sample_batch.x.shape[1] if not isinstance(sample_batch, HeteroData) else sample_batch['node'].x.shape[1]
130 | e_dim = (sample_batch.edge_attr.shape[1] - 1) if not isinstance(sample_batch, HeteroData) else (sample_batch['node', 'to', 'node'].edge_attr.shape[1] - 1)
131 |
132 | if args.model == "gin":
133 | model = GINe(
134 | num_features=n_feats, num_gnn_layers=config.n_gnn_layers, n_classes=2,
135 | n_hidden=round(config.n_hidden), residual=False, edge_updates=args.emlps, edge_dim=e_dim,
136 | dropout=config.dropout, final_dropout=config.final_dropout
137 | )
138 | elif args.model == "gat":
139 | model = GATe(
140 | num_features=n_feats, num_gnn_layers=config.n_gnn_layers, n_classes=2,
141 | n_hidden=round(config.n_hidden), n_heads=round(config.n_heads),
142 | edge_updates=args.emlps, edge_dim=e_dim,
143 | dropout=config.dropout, final_dropout=config.final_dropout
144 | )
145 | elif args.model == "pna":
146 | if not isinstance(sample_batch, HeteroData):
147 | d = degree(sample_batch.edge_index[1], dtype=torch.long)
148 | else:
149 | index = torch.cat((sample_batch['node', 'to', 'node'].edge_index[1], sample_batch['node', 'rev_to', 'node'].edge_index[1]), 0)
150 | d = degree(index, dtype=torch.long)
151 | deg = torch.bincount(d, minlength=1)
152 | model = PNA(
153 | num_features=n_feats, num_gnn_layers=config.n_gnn_layers, n_classes=2,
154 | n_hidden=round(config.n_hidden), edge_updates=args.emlps, edge_dim=e_dim,
155 | dropout=config.dropout, deg=deg, final_dropout=config.final_dropout
156 | )
157 | elif args.model == "rgcn":
158 | model = RGCN(
159 | num_features=n_feats, edge_dim=e_dim, num_relations=8, num_gnn_layers=round(config.n_gnn_layers), n_classes=2,
160 | n_hidden=round(config.n_hidden), edge_update=args.emlps, dropout=config.dropout, final_dropout=config.final_dropout, n_bases=None #(maybe)
161 | )
162 | else:
163 | raise ValueError(f"Unknown model '{args.model}'. Needs to be one of [gin, gat, rgcn, pna]")
164 | return model
165 |
166 | def train_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config):
167 | #set device
168 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
169 |
170 | #define a model config dictionary and wandb logging at the same time
171 | wandb.init(
172 | mode="disabled" if args.testing else "online",
173 | project="your_proj_name", #replace this with your wandb project name if you want to use wandb logging
174 |
175 | config={
176 | "epochs": args.n_epochs,
177 | "batch_size": args.batch_size,
178 | "model": args.model,
179 | "data": args.data,
180 | "num_neighbors": args.num_neighs,
181 | "lr": extract_param("lr", args),
182 | "n_hidden": extract_param("n_hidden", args),
183 | "n_gnn_layers": extract_param("n_gnn_layers", args),
184 | "loss": "ce",
185 | "w_ce1": extract_param("w_ce1", args),
186 | "w_ce2": extract_param("w_ce2", args),
187 | "dropout": extract_param("dropout", args),
188 | "final_dropout": extract_param("final_dropout", args),
189 | "n_heads": extract_param("n_heads", args) if args.model == 'gat' else None
190 | }
191 | )
192 |
193 | config = wandb.config
194 |
195 | #set the transform if ego ids should be used
196 | if args.ego:
197 | transform = AddEgoIds()
198 | else:
199 | transform = None
200 |
201 | #add the unique ids to later find the seed edges
202 | add_arange_ids([tr_data, val_data, te_data])
203 |
204 | tr_loader, val_loader, te_loader = get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args)
205 |
206 | #get the model
207 | sample_batch = next(iter(tr_loader))
208 | model = get_model(sample_batch, config, args)
209 |
210 | if args.reverse_mp:
211 | model = to_hetero(model, te_data.metadata(), aggr='mean')
212 |
213 | if args.finetune:
214 | model, optimizer = load_model(model, device, args, config, data_config)
215 | else:
216 | model.to(device)
217 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
218 |
219 | sample_batch.to(device)
220 | sample_x = sample_batch.x if not isinstance(sample_batch, HeteroData) else sample_batch.x_dict
221 | sample_edge_index = sample_batch.edge_index if not isinstance(sample_batch, HeteroData) else sample_batch.edge_index_dict
222 | if isinstance(sample_batch, HeteroData):
223 | sample_batch['node', 'to', 'node'].edge_attr = sample_batch['node', 'to', 'node'].edge_attr[:, 1:]
224 | sample_batch['node', 'rev_to', 'node'].edge_attr = sample_batch['node', 'rev_to', 'node'].edge_attr[:, 1:]
225 | else:
226 | sample_batch.edge_attr = sample_batch.edge_attr[:, 1:]
227 | sample_edge_attr = sample_batch.edge_attr if not isinstance(sample_batch, HeteroData) else sample_batch.edge_attr_dict
228 | logging.info(summary(model, sample_x, sample_edge_index, sample_edge_attr))
229 |
230 | loss_fn = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([config.w_ce1, config.w_ce2]).to(device))
231 |
232 | if args.reverse_mp:
233 | model = train_hetero(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config)
234 | else:
235 | model = train_homo(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config)
236 |
237 | wandb.finish()
--------------------------------------------------------------------------------
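The seed-edge selection at the top of the batch loops in `train_homo` and `train_hetero` is the least obvious step in `training.py`. Below is a minimal pure-Python sketch of the same lookup, assuming (as the code and its comments suggest) that column 0 of `edge_attr` holds the unique edge id added by `add_arange_ids`, and that `batch.input_id` indexes into the split's index tensor (`tr_inds`), exactly as the code uses it. The function `seed_edge_mask` and its parameter names are hypothetical and only illustrate the logic; the repository itself does this with tensors and `torch.isin`.

```python
from typing import List, Sequence


def seed_edge_mask(
    split_inds: Sequence[int],      # stands in for tr_inds
    input_id: Sequence[int],        # stands in for batch.input_id
    full_edge_ids: Sequence[int],   # stands in for tr_loader.data.edge_attr[:, 0]
    batch_edge_ids: Sequence[int],  # stands in for batch.edge_attr[:, 0]
) -> List[bool]:
    """Hypothetical helper: flag the sampled edges that are seed edges."""
    # input_id points into the split's index list; map it to positions in the full edge list
    seed_positions = [split_inds[i] for i in input_id]
    # read the unique ids of the seed edges from the full graph
    seed_ids = {full_edge_ids[p] for p in seed_positions}
    # membership test per sampled edge, analogous to torch.isin
    return [edge_id in seed_ids for edge_id in batch_edge_ids]


# Toy example: 6 edges with ids 0..5, training split = edges 1, 3 and 5.
# The batch was seeded from the 1st and 3rd training edges (ids 1 and 5),
# and neighborhood sampling also pulled in edges 2 and 4.
mask = seed_edge_mask(
    split_inds=[1, 3, 5],
    input_id=[0, 2],
    full_edge_ids=[0, 1, 2, 3, 4, 5],
    batch_edge_ids=[5, 2, 1, 4],
)
print(mask)  # [True, False, True, False]
```

The mask is what restricts `pred`, `ground_truth` and therefore the loss and the train F1 to the seed edges, rather than to every edge that neighborhood sampling added to the batch.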
/util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import random
5 | import logging
6 | import os
7 | import sys
8 |
9 | def logger_setup():
10 | # Setup logging
11 | log_directory = "logs"
12 | if not os.path.exists(log_directory):
13 | os.makedirs(log_directory)
14 | logging.basicConfig(
15 | level=logging.INFO,
16 | format="%(asctime)s [%(levelname)-5.5s] %(message)s",
17 | handlers=[
18 | logging.FileHandler(os.path.join(log_directory, "logs.log")), ## log to local log file
19 | logging.StreamHandler(sys.stdout) ## log also to stdout (i.e., print to screen)
20 | ]
21 | )
22 |
23 | def create_parser():
24 | parser = argparse.ArgumentParser()
25 |
26 | #Adaptations
27 | parser.add_argument("--emlps", action='store_true', help="Use edge MLPs (edge updates) in GNN training")
28 | parser.add_argument("--reverse_mp", action='store_true', help="Use reverse message passing (MP) in GNN training")
29 | parser.add_argument("--ports", action='store_true', help="Use port numberings in GNN training")
30 | parser.add_argument("--tds", action='store_true', help="Use time deltas (i.e. the time between subsequent transactions) in GNN training")
31 | parser.add_argument("--ego", action='store_true', help="Use ego IDs in GNN training")
32 |
33 | #Model parameters
34 | parser.add_argument("--batch_size", default=8192, type=int, help="Select the batch size for GNN training")
35 | parser.add_argument("--n_epochs", default=100, type=int, help="Select the number of epochs for GNN training")
36 | parser.add_argument('--num_neighs', nargs='+', default=[100,100], type=int, help='Pass the number of neighbors to be sampled in each hop (descending).')
37 |
38 | #Misc
39 | parser.add_argument("--seed", default=1, type=int, help="Select the random seed for reproducibility")
40 | parser.add_argument("--tqdm", action='store_true', help="Use tqdm logging (when running interactively in terminal)")
41 | parser.add_argument("--data", default=None, type=str, help="Select the AML dataset. Needs to be either small or medium.", required=True)
42 | parser.add_argument("--model", default=None, type=str, help="Select the model architecture. Needs to be one of [gin, gat, rgcn, pna]", required=True)
43 | parser.add_argument("--testing", action='store_true', help="Disable wandb logging while running the script in 'testing' mode.")
44 | parser.add_argument("--save_model", action='store_true', help="Save the best model.")
45 | parser.add_argument("--unique_name", default=None, type=str, help="Unique name under which the model will be stored.")
46 | parser.add_argument("--finetune", action='store_true', help="Fine-tune a model. Note that args.unique_name needs to point to the pre-trained model.")
47 | parser.add_argument("--inference", action='store_true', help="Load a trained model and only do AML inference with it. args.unique_name needs to point to the trained model.")
48 |
49 | return parser
50 |
51 | def set_seed(seed: int = 0) -> None:
52 | np.random.seed(seed)
53 | random.seed(seed)
54 | torch.manual_seed(seed)
55 | torch.cuda.manual_seed(seed)
56 | # When running on the CuDNN backend, two further options must be set
57 | torch.backends.cudnn.deterministic = True
58 | torch.backends.cudnn.benchmark = False
59 | # Set a fixed hash seed for child processes (changing it at runtime does not affect hashing in the already running interpreter)
60 | os.environ["PYTHONHASHSEED"] = str(seed)
61 | logging.info(f"Random seed set as {seed}")
--------------------------------------------------------------------------------
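`create_parser` declares `--num_neighs` with `nargs='+'`. Whether the per-hop neighbor counts then arrive as integers depends on the `type` argument: without it, argparse keeps every value given on the command line as a string, while a non-string default such as `[100, 100]` is used unchanged (argparse only parses string defaults). The sketch below shows the difference on a stand-alone parser; `parse_num_neighs` is a hypothetical helper, not part of this repository.

```python
import argparse
from typing import Any, Dict, List


def parse_num_neighs(argv: List[str], convert: bool) -> List[Any]:
    """Hypothetical helper: parse --num_neighs with or without type=int."""
    parser = argparse.ArgumentParser()
    # nargs='+' collects one or more values into a list
    kwargs: Dict[str, Any] = {"nargs": "+", "default": [100, 100]}
    if convert:
        # type is applied to each value passed on the command line
        kwargs["type"] = int
    parser.add_argument("--num_neighs", **kwargs)
    return parser.parse_args(argv).num_neighs


print(parse_num_neighs(["--num_neighs", "50", "20"], convert=False))  # ['50', '20']
print(parse_num_neighs(["--num_neighs", "50", "20"], convert=True))   # [50, 20]
print(parse_num_neighs([], convert=False))                            # [100, 100]
```

Without `type=int`, running with `--num_neighs 50 20` would hand the neighbor sampler the strings `['50', '20']` instead of integers, while a run relying on the default would not show the problem.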