├── LICENSE ├── README.md ├── dataloaders ├── delaunay_loader.py └── jets_loader.py ├── download_jets_data.py ├── main_scripts ├── main_delaunay.py └── main_jets.py ├── models ├── deep_sets.py ├── layers.py ├── mlp.py ├── set_partition_gnn.py ├── set_partition_mlp.py ├── set_to_graph.py ├── set_to_graph_gnn.py ├── set_to_graph_mlp.py ├── set_to_graph_siam.py ├── set_to_graph_triplets.py └── triplets_model.py └── performance_eval └── eval_test_jets.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SetToGraphPaper 2 | This repository holds the code for the paper: https://arxiv.org/abs/2002.08772. 3 | 4 | ## Data 5 | Before running the code for the jets experiments, the data should be downloaded using the following commands: 6 | 7 | ``` 8 | cd SetToGraphPaper 9 | python download_jets_data.py 10 | ``` 11 | 12 | This script will download all the data from Zenodo links. 
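To check that the download completed correctly, the files can be opened with `uproot` (installed in the prerequisites below), the same way `dataloaders/jets_loader.py` reads them. A minimal sanity check, assuming the same uproot API used by the data loader:

```python
import uproot

# Count the jets in the training file (entries of its 'tree' TTree).
with uproot.open('data/train/training_data.root') as f:
    tree = f['tree']
    print('Training jets:', tree.numentries)
```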
13 | 14 | 15 | 16 | ## Code 17 | 18 | ### Prerequisites 19 | 20 | You can use the following commands to install a compatible environment (using Anaconda); make sure to change the CUDA toolkit version to the one that matches your system. 21 | ``` 22 | conda create -n s2g_env -c pytorch pytorch=1.5 cudatoolkit=10.2 torchvision 23 | conda activate s2g_env 24 | CUDA=cu102 # for cuda 10.2 25 | pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.5.0.html 26 | pip install torch-sparse==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.5.0.html 27 | pip install torch-cluster==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.5.0.html 28 | pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.5.0.html 29 | pip install torch-geometric 30 | conda install -c conda-forge -c anaconda tqdm easydict rdkit uproot 31 | ``` 32 | 33 | 34 | ### Running the tests 35 | 36 | The folder main_scripts contains scripts that run different experiments: 37 | 1. To run the main particle-physics jets experiment with our chosen hyper-parameters, run one of the following: 38 | ``` 39 | python main_scripts/main_jets.py --method=lin2 # for S2G 40 | ``` 41 | or 42 | ``` 43 | python main_scripts/main_jets.py --method=lin5 # for S2G+ 44 | ``` 45 | or replace `--method=...` with `--baseline=siam/siam3/gnn/mlp` to run a baseline. 46 | 47 | 2. To run the Delaunay triangulation experiment with our hyper-parameters, run the main_delaunay.py script with `--one_size` for n=50 or `--many_sizes` for n\in{20,...,80}. 48 | Example: 49 | ``` 50 | python main_scripts/main_delaunay.py --one_size --method=lin2 51 | ``` 52 | -------------------------------------------------------------------------------- /dataloaders/delaunay_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import Delaunay 3 | import torch 4 | from torch.utils.data import TensorDataset, DataLoader, Sampler, Dataset 5 | from scipy.sparse import coo_matrix 6 | import random 7 | 8 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 9 | 10 | 11 | def generate_Delaunay_dataset(n_examples, n_points): 12 | Points = np.random.rand(n_examples, n_points, 2) 13 | Edges = np.zeros((n_examples, n_points, n_points)) 14 | for ii in range(n_examples): 15 | points = Points[ii] 16 | tri = Delaunay(points) 17 | edges = [] 18 | for i in range(n_points): 19 | neigh = tri.vertex_neighbor_vertices[1][ 20 | tri.vertex_neighbor_vertices[0][i]:tri.vertex_neighbor_vertices[0][i + 1]] 21 | for j in range(len(neigh)): 22 | edges.append([i, neigh[j]]) 23 | edges = np.array(edges) 24 | Edges[ii] = coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), shape=(points.shape[0], 25 | points.shape[0])).toarray() 26 | return torch.from_numpy(Points).float().to(DEVICE), torch.from_numpy(Edges).float().to(DEVICE) 27 | 28 | 29 | def get_delaunay_loader(config, train=True): 30 | n_examples, n_points = config.n_examples, 50 31 | if not train: 32 | n_examples, n_points = config.n_examples_test, 50 33 | points, edges = generate_Delaunay_dataset(n_examples, n_points) 34 | dataset = TensorDataset(points, edges) 35 | if train: 36 | return DataLoader(dataset=dataset, batch_size=config.bs, shuffle=True) 37 | return DataLoader(dataset=dataset, batch_size=config.bs, shuffle=False) 38 | 39 | 40 | def generate_Delaunay_dataset_different_sizes(n_examples): 41 | point_numbers = np.linspace(20, 80, 61).astype(int) 42 | Points = [] 43 | Edges = [] 44 
| for ii in range(n_examples): 45 | n_points = random.sample(set(point_numbers), 1)[0] 46 | 47 | points = np.random.rand(n_points, 2)#@rot1@np.array([[s1, 0], [0, s2]])@rot2 48 | Points.append(points) 49 | tri = Delaunay(points) 50 | edges = [] 51 | for i in range(n_points): 52 | neigh = tri.vertex_neighbor_vertices[1][ 53 | tri.vertex_neighbor_vertices[0][i]:tri.vertex_neighbor_vertices[0][i + 1]] 54 | for j in range(len(neigh)): 55 | edges.append([i, neigh[j]]) 56 | edges = np.array(edges) 57 | Edges.append(coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 58 | shape=(points.shape[0], points.shape[0])).toarray()) 59 | return Points, Edges 60 | 61 | 62 | class DelaunayManySizes(Dataset): 63 | def __init__(self, n_example): 64 | self.Points, self.Edges = generate_Delaunay_dataset_different_sizes(n_example) 65 | # Points and Edges are the data numpy arrays in different sizes 66 | self.n_nodes = np.array([points.shape[0] for points in self.Points]) 67 | 68 | self.Points, self.Edges = lst_of_np_to_torch(self.Points, self.Edges) 69 | 70 | def __len__(self): 71 | """Returns the length of the dataset""" 72 | return len(self.Points) 73 | 74 | def __getitem__(self, idx): 75 | """Generates a single instance of data""" 76 | return self.Points[idx], self.Edges[idx] 77 | 78 | 79 | def lst_of_np_to_torch(Points, Edges): 80 | r_points, r_edges = [], [] 81 | for i in range(len(Points)): 82 | r_points.append(torch.from_numpy(Points[i]).float().to(DEVICE)) 83 | r_edges.append(torch.from_numpy(Edges[i]).float().to(DEVICE)) 84 | return r_points, r_edges 85 | 86 | 87 | class DelaunaySampler(Sampler): 88 | def __init__(self, n_nodes_array, batch_size): 89 | super().__init__(n_nodes_array.size) 90 | 91 | self.dataset_size = n_nodes_array.size 92 | self.batch_size = batch_size 93 | 94 | self.index_to_batch = {} 95 | self.node_size_idx = {} 96 | running_idx = -1 97 | 98 | for n_nodes_i in set(n_nodes_array): 99 | 100 | if n_nodes_i <= 1: 101 | continue 102 | self.node_size_idx[n_nodes_i] = np.where(n_nodes_array == n_nodes_i)[0] 103 | 104 | n_of_size = len(self.node_size_idx[n_nodes_i]) 105 | n_batches = max(n_of_size / self.batch_size, 1) 106 | 107 | self.node_size_idx[n_nodes_i] = np.array_split(np.random.permutation(self.node_size_idx[n_nodes_i]), 108 | n_batches) 109 | for batch in self.node_size_idx[n_nodes_i]: 110 | running_idx += 1 111 | self.index_to_batch[running_idx] = batch 112 | 113 | self.n_batches = running_idx + 1 114 | 115 | def __len__(self): 116 | return self.n_batches 117 | 118 | def __iter__(self): 119 | batch_order = np.random.permutation(np.arange(self.n_batches)) 120 | for i in batch_order: 121 | yield self.index_to_batch[i] 122 | 123 | 124 | def get_delaunay_loader_many_sizes(config, train): 125 | n_example = config.n_examples 126 | if not train: 127 | n_example = config.n_examples_test 128 | batch_size = config.bs 129 | Delaunay_data = DelaunayManySizes(n_example) 130 | batch_sampler = DelaunaySampler(Delaunay_data.n_nodes, batch_size) 131 | data_loader = DataLoader(Delaunay_data, batch_sampler=batch_sampler) 132 | 133 | return data_loader 134 | -------------------------------------------------------------------------------- /dataloaders/jets_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uproot 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.utils.data import Dataset, DataLoader, Sampler 7 | from datetime import datetime 8 | 9 | data_dir = 'data/' 10 | node_features_list 
= ['trk_d0', 'trk_z0', 'trk_phi', 'trk_ctgtheta', 'trk_pt', 'trk_charge'] 11 | jet_features_list = ['jet_pt', 'jet_eta', 'jet_phi', 'jet_M'] 12 | 13 | 14 | def get_data_loader(which_set, batch_size, debug_load=False): 15 | jets_data = JetGraphDataset(which_set, debug_load) 16 | batch_sampler = JetsBatchSampler(jets_data.n_nodes, batch_size) 17 | data_loader = DataLoader(jets_data, batch_sampler=batch_sampler) 18 | 19 | return data_loader 20 | 21 | 22 | def transform_features(transform_list, arr): 23 | new_arr = np.zeros_like(arr) 24 | for col_i, (mean, std) in enumerate(transform_list): 25 | new_arr[col_i, :] = (arr[col_i, :] - mean) / std 26 | return new_arr 27 | 28 | 29 | class JetGraphDataset(Dataset): 30 | def __init__(self, which_set, debug_load=False, random_permutation=True): 31 | """ 32 | Initialization 33 | :param which_set: either "train", "validation" or "test" 34 | :param debug_load: if True, will load only a small subset 35 | :param random_permutation: if True, apply random permutation to the order of the nodes/vertices. 36 | """ 37 | assert which_set in ['train', 'validation', 'test'] 38 | fname = {'train': 'training', 'validation': 'valid', 'test': 'test'} 39 | 40 | self.random_permutation = random_permutation 41 | self.filename = os.path.join(data_dir, which_set, fname[which_set]+'_data.root') 42 | with uproot.open(self.filename) as f: 43 | tree = f['tree'] 44 | self.n_jets = int(tree.numentries) 45 | self.n_nodes = np.array([len(x) for x in tree.array(b'trk_vtx_index')]) 46 | 47 | self.jet_arrays = tree.arrays(jet_features_list + node_features_list + ['trk_vtx_index']) 48 | self.sets, self.partitions, self.partitions_as_graphs = [], [], [] 49 | 50 | if debug_load: 51 | self.n_jets = 100 52 | self.n_nodes = self.n_nodes[:100] 53 | 54 | start_load = datetime.now() 55 | 56 | for set_, partition, partition_as_graph in self.get_all_items(): 57 | if torch.cuda.is_available(): 58 | set_ = torch.tensor(set_, dtype=torch.float, device='cuda') 59 | partition = torch.tensor(partition, dtype=torch.long, device='cuda') 60 | partition_as_graph = torch.tensor(partition_as_graph, dtype=torch.float, device='cuda') 61 | self.sets.append(set_) 62 | self.partitions.append(partition) 63 | self.partitions_as_graphs.append(partition_as_graph) 64 | 65 | if not torch.cuda.is_available(): 66 | self.sets = np.array(self.sets) 67 | self.partitions = np.array(self.partitions) 68 | self.partitions_as_graphs = np.array(self.partitions_as_graphs) 69 | 70 | print(f' {str(datetime.now() - start_load).split(".")[0]}', flush=True) 71 | 72 | def __len__(self): 73 | """Returns the length of the dataset""" 74 | return self.n_jets 75 | 76 | def get_all_items(self): 77 | node_feats = np.array([np.asarray(self.jet_arrays[str.encode(x)]) for x in node_features_list]) 78 | jet_feats = np.array([np.asarray(self.jet_arrays[str.encode(x)]) for x in jet_features_list]) 79 | n_labels = np.array(self.jet_arrays[b'trk_vtx_index']) 80 | 81 | for i in range(self.n_jets): 82 | n_nodes = self.n_nodes[i] 83 | node_feats_i = np.stack(node_feats[:, i], axis=0) # shape (6, n_nodes) 84 | jet_feats_i = jet_feats[:, i] # shape (4, ) 85 | jet_feats_i = jet_feats_i[:, np.newaxis] # shape (4, 1) 86 | 87 | node_feats_i = transform_features(FeatureTransform.node_feature_transform_list, node_feats_i) 88 | jet_feats_i = transform_features(FeatureTransform.jet_features_transform_list, jet_feats_i) 89 | 90 | jet_feats_i = np.repeat(jet_feats_i, n_nodes, axis=1) # change shape to (4, n_nodes) 91 | set_i = np.concatenate([node_feats_i, 
jet_feats_i]).T # shape (n_nodes, 10) 92 | 93 | partition_i = n_labels[i] 94 | 95 | if self.random_permutation: 96 | perm = np.random.permutation(n_nodes) 97 | set_i = set_i[perm] # random permutation 98 | partition_i = partition_i[perm] # random permuatation 99 | 100 | tile = np.tile(partition_i, (self.n_nodes[i], 1)) 101 | partition_as_graph_i = np.where((tile - tile.T), 0, 1) 102 | 103 | yield set_i, partition_i, partition_as_graph_i 104 | 105 | def __getitem__(self, idx): 106 | """Generates a single instance of data""" 107 | return self.sets[idx], self.partitions[idx], self.partitions_as_graphs[idx] 108 | 109 | 110 | class JetsBatchSampler(Sampler): 111 | def __init__(self, n_nodes_array, batch_size): 112 | """ 113 | Initialization 114 | :param n_nodes_array: array of sizes of the jets 115 | :param batch_size: batch size 116 | """ 117 | super().__init__(n_nodes_array.size) 118 | 119 | self.dataset_size = n_nodes_array.size 120 | self.batch_size = batch_size 121 | 122 | self.index_to_batch = {} 123 | self.node_size_idx = {} 124 | running_idx = -1 125 | 126 | for n_nodes_i in set(n_nodes_array): 127 | 128 | if n_nodes_i <= 1: 129 | continue 130 | self.node_size_idx[n_nodes_i] = np.where(n_nodes_array == n_nodes_i)[0] 131 | 132 | n_of_size = len(self.node_size_idx[n_nodes_i]) 133 | n_batches = max(n_of_size / self.batch_size, 1) 134 | 135 | self.node_size_idx[n_nodes_i] = np.array_split(np.random.permutation(self.node_size_idx[n_nodes_i]), 136 | n_batches) 137 | for batch in self.node_size_idx[n_nodes_i]: 138 | running_idx += 1 139 | self.index_to_batch[running_idx] = batch 140 | 141 | self.n_batches = running_idx + 1 142 | 143 | def __len__(self): 144 | return self.n_batches 145 | 146 | def __iter__(self): 147 | batch_order = np.random.permutation(np.arange(self.n_batches)) 148 | for i in batch_order: 149 | yield self.index_to_batch[i] 150 | 151 | 152 | class FeatureTransform(object): 153 | # Based on mean and std values of TRAINING set only 154 | node_feature_transform_list = [ 155 | (0.0006078152, 14.128961), 156 | (0.0038490593, 10.688491), 157 | (-0.0026713554, 1.8167108), 158 | (0.0047640945, 1.889725), 159 | (5.237357, 7.4841413), 160 | (-0.00015662189, 1.0)] 161 | 162 | jet_features_transform_list = [ 163 | (75.95093, 49.134453), 164 | (0.0022607117, 1.2152709), 165 | (-0.0023569583, 1.8164033), 166 | (9.437994, 6.765137)] 167 | -------------------------------------------------------------------------------- /download_jets_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | 4 | 5 | def mkdir_if_not_exists(dir_name): 6 | if not (os.path.exists(dir_name) and os.path.isdir(dir_name)): 7 | os.mkdir(dir_name) 8 | 9 | 10 | if __name__ == '__main__': 11 | train_link = 'https://zenodo.org/record/4044628/files/training_data.root?download=1' 12 | val_link = 'https://zenodo.org/record/4044628/files/valid_data.root?download=1' 13 | test_link = 'https://zenodo.org/record/4044628/files/test_data.root?download=1' 14 | 15 | print('Creating data directories...') 16 | mkdir_if_not_exists('data') 17 | mkdir_if_not_exists('data/train') 18 | mkdir_if_not_exists('data/validation') 19 | mkdir_if_not_exists('data/test') 20 | 21 | print('Downloading training data to data/train/training_data.root...', flush=True) 22 | urllib.request.urlretrieve(train_link, 'data/train/training_data.root') 23 | print('Downloading validation data to data/validation/valid_data.root...', flush=True) 24 | urllib.request.urlretrieve(val_link, 
'data/validation/valid_data.root') 25 | print('Downloading test data data/test/test_data.root...', flush=True) 26 | urllib.request.urlretrieve(test_link, 'data/test/test_data.root') 27 | 28 | print('Done!') 29 | 30 | -------------------------------------------------------------------------------- /main_scripts/main_delaunay.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import sys 4 | import argparse 5 | import shutil 6 | import json 7 | from pprint import pprint 8 | from datetime import datetime 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import matplotlib.pyplot as plt 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | # Change working directory to project's main directory, and add it to path - for library and config usages 18 | project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) 19 | sys.path.append(project_dir) 20 | os.chdir(project_dir) 21 | 22 | # Project dependencies 23 | from models.set_to_graph import SetToGraph 24 | from models.set_to_graph_mlp import SetToGraphMLP 25 | from models.set_to_graph_gnn import SetToGraphGNN 26 | from models.set_to_graph_triplets import SetToGraphTri 27 | from models.set_to_graph_siam import SetToGraphSiam 28 | from dataloaders.delaunay_loader import get_delaunay_loader, get_delaunay_loader_many_sizes 29 | 30 | 31 | def parse_args(): 32 | """ 33 | Define and retrieve command line arguements 34 | :return: argparser instance 35 | """ 36 | argparser = argparse.ArgumentParser(description=__doc__) 37 | argparser.add_argument('-g', '--gpu', default='0', help='The gpu to run on') 38 | argparser.add_argument('-e', '--epochs', default=100, type=int, help='The number of epochs to run') 39 | argparser.add_argument('-l', '--lr', default=0.001, type=float, help='The learning rate') 40 | argparser.add_argument('-b', '--bs', default=64, type=int, help='Batch size to use') 41 | argparser.add_argument('--n_examples_test', default=5000, type=int, help='number of test examples') 42 | argparser.add_argument('--n_examples', default=50000, type=int, help='number of training examples') 43 | argparser.add_argument('--res_dir', default='../experiments/delaunay_results', help='Results directory') 44 | argparser.add_argument('--method', default='lin2', type=str, help='method of transitioning from vectors to matrix') 45 | argparser.add_argument('--baseline', default=None, help='Run on baseline - siam, siam3, mlp or gnn') 46 | 47 | argparser.add_argument('--many_sizes', dest='many_sizes', action='store_true', help='Whether to use n in 20-80 or n=50.') 48 | argparser.add_argument('--one_size', dest='many_sizes', action='store_false') 49 | argparser.add_argument('--save', dest='save', action='store_true', help='Whether to save all to disk') 50 | argparser.add_argument('--no-save', dest='save', action='store_false') 51 | argparser.set_defaults(save=True, many_sizes=False) 52 | 53 | args = argparser.parse_args() 54 | return args 55 | 56 | 57 | def update_info(loss, pred, edges, accum_info): 58 | batch_size = pred.shape[0] 59 | accum_info['loss'] += loss.item() * batch_size 60 | pred_edges = pred.ge(0.).float() 61 | 62 | epsilon = 0.00000001 63 | tp = ((pred_edges == edges) * (pred_edges == 1)).sum(dim=2).sum(dim=1).float() 64 | tn = ((pred_edges == edges) * (pred_edges == 0)).sum(dim=2).sum(dim=1).float() 65 | fp = ((pred_edges != edges) * (pred_edges == 1)).sum(dim=2).sum(dim=1).float() 66 | fn = ((pred_edges != edges) * 
(pred_edges == 0)).sum(dim=2).sum(dim=1).float() 67 | 68 | accum_info['acc'] += ((tp + tn)/(tp + tn + fp + fn)).sum().item() 69 | accum_info['precision'] += (tp/(tp + fp + epsilon)).sum().item() 70 | accum_info['recall'] += (tp/(tp + fn + epsilon)).sum().item() 71 | accum_info['f1'] += (2 * tp / (2 * tp + fn + fp + epsilon)).sum().item() 72 | return accum_info 73 | 74 | 75 | def train_epoch(data, epoch, model, optimizer, device): 76 | model.train() 77 | 78 | # Iterate over batches 79 | 80 | accum_info = {k: 0.0 for k in ['loss', 'acc', 'precision', 'recall', 'f1']} 81 | for points, edges in data: 82 | # One Train step on the current batch 83 | points = points.to(device, torch.float) 84 | edges = edges.to(device, torch.float) 85 | 86 | if isinstance(model, SetToGraphTri): 87 | pred, loss = model(points, edges) 88 | else: 89 | pred = model(points).squeeze(1) # shape (B,N,N) 90 | pred = (pred + pred.transpose(1, 2)) / 2 91 | 92 | # calc loss 93 | loss = F.binary_cross_entropy_with_logits(pred, edges) 94 | 95 | # calc acc, precision, recall 96 | with torch.no_grad(): 97 | accum_info = update_info(loss, pred, edges, accum_info) 98 | 99 | optimizer.zero_grad() 100 | loss.backward() 101 | optimizer.step() 102 | 103 | data_len = len(data.dataset) 104 | accum_info['loss'] /= data_len 105 | accum_info['acc'] /= data_len 106 | accum_info['precision'] /= data_len 107 | accum_info['recall'] /= data_len 108 | accum_info['f1'] /= data_len 109 | print("train epoch %d loss %f acc %f precision %f recall %f f1 %f" % (epoch, accum_info['loss'], accum_info['acc'], 110 | accum_info['precision'], accum_info['recall'], 111 | accum_info['f1']), flush=True) 112 | 113 | return accum_info 114 | 115 | 116 | def evaluate(data, epoch, model, device): 117 | # train epoch 118 | model.eval() 119 | accum_info = {k: 0.0 for k in ['loss', 'acc', 'precision', 'recall', 'f1']} 120 | 121 | for points, edges in data: 122 | # One Train step on the current batch 123 | points = points.to(device, torch.float) 124 | edges = edges.to(device, torch.float) 125 | 126 | if isinstance(model, SetToGraphTri): 127 | pred, loss = model(points, edges) 128 | else: 129 | pred = model(points).squeeze(1) # shape (B,N,N) 130 | pred = (pred + pred.transpose(1, 2)) / 2 131 | 132 | loss = F.binary_cross_entropy_with_logits(pred, edges) 133 | 134 | # calc acc, precision, recall 135 | accum_info = update_info(loss, pred, edges, accum_info) 136 | 137 | data_len = data.dataset.__len__() 138 | accum_info['loss'] /= data_len 139 | accum_info['acc'] /= data_len 140 | accum_info['precision'] /= data_len 141 | accum_info['recall'] /= data_len 142 | accum_info['f1'] /= data_len 143 | print("validation epoch %d loss %f acc %f precision %f recall %f f1 %f" % (epoch, accum_info['loss'], 144 | accum_info['acc'], 145 | accum_info['precision'], 146 | accum_info['recall'], accum_info['f1']), flush=True) 147 | 148 | return accum_info 149 | 150 | 151 | def plot_val(df, output_dir, val): 152 | df.index.name = 'epochs' 153 | df.to_csv(os.path.join(output_dir, "metrics.csv"), index=False) 154 | df[['train_'+val, 'val_'+val]].plot(title=val, grid=True) 155 | plt.savefig(os.path.join(output_dir, val+".pdf")) 156 | 157 | 158 | def main(): 159 | config = parse_args() 160 | start_time = datetime.now() 161 | # os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # uncomment only for CUDA error debugging 162 | # os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu 163 | torch.cuda.set_device(int(config.gpu)) 164 | use_cuda = torch.cuda.is_available() # returns True 165 | device = 
torch.device('cuda' if use_cuda else 'cpu') 166 | seed = 1728 167 | np.random.seed(seed) 168 | random.seed(seed) 169 | torch.manual_seed(seed) 170 | torch.cuda.manual_seed(seed) 171 | torch.cuda.manual_seed_all(seed) 172 | # torch.backends.cudnn.deterministic = True # can impact performance 173 | # torch.backends.cudnn.benchmark = False # can impact performance 174 | 175 | config.exp_name = 'Delaunay' 176 | pprint(vars(config)) 177 | print(flush=True) 178 | 179 | # Load data 180 | many_sizes = config.many_sizes 181 | if many_sizes: 182 | print('Generating training data, n in 20-80...', flush=True) 183 | train_data = get_delaunay_loader_many_sizes(config, train=True)#, device=device) 184 | print('Generating validation data, n in 20-80...', flush=True) 185 | val_data = get_delaunay_loader_many_sizes(config, train=False)#, device=device) 186 | else: 187 | print('Generating training data, n=50...', flush=True) 188 | train_data = get_delaunay_loader(config, train=True) # , device=device) 189 | print('Generating validation data, n=50...', flush=True) 190 | val_data = get_delaunay_loader(config, train=False) # , device=device) 191 | 192 | # Create model instance 193 | 194 | # cfg = dict(agg=torch.mean, normalization='batchnorm', second_bias=False, mlp_with_relu=False) 195 | if config.baseline == 'mlp': 196 | maxnodes = 80 if config.many_sizes else 50 197 | model = SetToGraphMLP([500, 1000, 1000, 1000, 500, 80**2], in_features=2, max_nodes=maxnodes) 198 | elif config.baseline == 'gnn': 199 | model = SetToGraphGNN([1000, 1500, 1000], in_features=2, k=5) 200 | elif config.baseline == 'siam3': 201 | model = SetToGraphTri([500, 1000, 1500, 1250, 1000, 500, 500, 80], in_features=2) 202 | elif config.baseline == 'siam': 203 | cfg = dict(normalization='batchnorm', mlp_with_relu=False) 204 | model = SetToGraphSiam(2, 205 | [700, 700, 700, 1400, 700, 700, 112], 206 | hidden_mlp=[1000, 1000], 207 | cfg=cfg) 208 | else: 209 | cfg = dict(agg=torch.mean) 210 | model = SetToGraph(in_features=2, 211 | out_features=1, 212 | set_fn_feats=[500, 500, 500, 1000, 500, 500, 80], 213 | method=config.method, 214 | hidden_mlp=[1000, 1000], 215 | predict_diagonal=True, 216 | attention=False, 217 | cfg=cfg) 218 | model = model.to(device) 219 | print(f'Model: {model}') 220 | print(f'Num of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}') 221 | 222 | lr = config.lr 223 | if config.baseline == 'gnn': 224 | lr = 1e-4 225 | print(f'Changed learning rate to 1e-4 for GNN') 226 | optimizer = torch.optim.Adam(params=model.parameters(), lr=lr) 227 | 228 | # Metrics 229 | train_loss = np.empty(config.epochs, float) 230 | train_acc = np.empty(config.epochs, float) 231 | train_precision = np.empty(config.epochs, float) 232 | train_recall = np.empty(config.epochs, float) 233 | train_f1 = np.empty(config.epochs, float) 234 | 235 | val_loss = np.empty(config.epochs, float) 236 | val_acc = np.empty(config.epochs, float) 237 | val_precision = np.empty(config.epochs, float) 238 | val_recall = np.empty(config.epochs, float) 239 | val_f1 = np.empty(config.epochs, float) 240 | 241 | # Training and evaluation process 242 | for epoch in range(1, config.epochs + 1): 243 | train_info = train_epoch(train_data, epoch, model, optimizer, device) 244 | train_loss[epoch - 1], train_acc[epoch - 1], train_precision[epoch - 1], train_recall[epoch - 1], \ 245 | train_f1[epoch - 1] = train_info['loss'], train_info['acc'], train_info['precision'], train_info['recall'],\ 246 | train_info['f1'] 247 | with torch.no_grad(): 248 | val_info = 
evaluate(val_data, epoch, model, device) 249 | val_loss[epoch - 1], val_acc[epoch - 1], val_precision[epoch - 1], val_recall[epoch - 1], val_f1[epoch - 1] = \ 250 | val_info['loss'], val_info['acc'], val_info['precision'], val_info['recall'], val_info['f1'] 251 | 252 | # Saving to disk 253 | if config.save: 254 | if not os.path.exists(config.res_dir): 255 | os.makedirs(config.res_dir) 256 | exp_dir = f'Delaunay_{start_time:%Y%m%d_%H%M%S}' + config.method 257 | output_dir = os.path.join(config.res_dir, exp_dir) 258 | os.makedirs(output_dir) 259 | print(f'Saving all to {output_dir}') 260 | shutil.copyfile(__file__, os.path.join(output_dir, 'code.py')) 261 | 262 | results_dict = {'train_loss': train_loss, 263 | 'train_acc': train_acc, 264 | 'train_precision': train_precision, 265 | 'train_recall': train_recall, 266 | 'train_f1': train_f1, 267 | 'val_loss': val_loss, 268 | 'val_acc': val_acc, 269 | 'val_precision': val_precision, 270 | 'val_recall': val_recall, 271 | 'val_f1': val_f1} 272 | df = pd.DataFrame(results_dict) 273 | plot_val(df, output_dir, 'loss') 274 | plot_val(df, output_dir, 'acc') 275 | plot_val(df, output_dir, 'precision') 276 | plot_val(df, output_dir, 'recall') 277 | plot_val(df, output_dir, 'f1') 278 | 279 | torch.save(model.state_dict(), os.path.join(output_dir, 'model.pth')) 280 | if many_sizes: 281 | torch.save(train_data.dataset.Points, os.path.join(output_dir, 'train_Points.pth')) 282 | torch.save(train_data.dataset.Edges, os.path.join(output_dir, 'train_Edges.pth')) 283 | 284 | torch.save(val_data.dataset.Points, os.path.join(output_dir, 'val_Points.pth')) 285 | torch.save(val_data.dataset.Edges, os.path.join(output_dir, 'val_Edges.pth')) 286 | else: 287 | torch.save(train_data.dataset.tensors, os.path.join(output_dir, 'train.pth')) 288 | torch.save(val_data.dataset.tensors, os.path.join(output_dir, 'val.pth')) 289 | 290 | with open(os.path.join(output_dir, 'used_config.json'), 'w') as fp: 291 | json.dump(vars(config), fp) 292 | 293 | print(f'Total runtime: {str(datetime.now() - start_time).split(".")[0]}') 294 | 295 | 296 | if __name__ == '__main__': 297 | main() 298 | -------------------------------------------------------------------------------- /main_scripts/main_jets.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import sys 4 | import argparse 5 | import copy 6 | import shutil 7 | import json 8 | from pprint import pprint 9 | from datetime import datetime 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | 17 | """ 18 | How To: 19 | Example for running from command line: 20 | python /SetToGraph/main_scripts/main_jets.py 21 | """ 22 | # Change working directory to project's main directory, and add it to path - for library and config usages 23 | project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) 24 | sys.path.append(project_dir) 25 | os.chdir(project_dir) 26 | 27 | # Project dependencies 28 | from models.set_to_graph import SetToGraph 29 | from models.set_to_graph_siam import SetToGraphSiam 30 | from models.set_partition_mlp import SetPartitionMLP 31 | from models.set_partition_gnn import SetPartitionGNN 32 | from models.triplets_model import SetPartitionTri 33 | from dataloaders import jets_loader 34 | from performance_eval.eval_test_jets import eval_jets_on_test_set 35 | 36 | DEVICE = 'cuda' 37 | 38 | 39 | def parse_args(): 40 | """ 41 | Define and retrieve command line arguements 42 | 
:return: argparser instance 43 | """ 44 | argparser = argparse.ArgumentParser(description=__doc__) 45 | argparser.add_argument('-g', '--gpu', default='0', help='The gpu to run on') 46 | argparser.add_argument('-e', '--epochs', default=400, type=int, help='The number of epochs to run') 47 | argparser.add_argument('-l', '--lr', default=0.001, type=float, help='The learning rate') 48 | argparser.add_argument('-b', '--bs', default=2048, type=int, help='Batch size to use') 49 | argparser.add_argument('--method', default='lin2', help='Method to transfer from sets to graphs: lin2 for S2G, lin5 for S2G+') 50 | argparser.add_argument('--res_dir', default='../experiments/jets_results', help='Results directory') 51 | argparser.add_argument('--baseline', default=None, help='Use a baseline and not set2graph. mlp, gnn, siam or siam3.') 52 | 53 | argparser.add_argument('--debug_load', dest='debug_load', action='store_true', help='Load only a small subset of the data') 54 | argparser.add_argument('--save', dest='save', action='store_true', help='Whether to save all to disk') 55 | argparser.add_argument('--no_save', dest='save', action='store_false') 56 | argparser.set_defaults(save=True, debug_load=False) 57 | 58 | args = argparser.parse_args() 59 | 60 | assert args.baseline is None or args.baseline in ['mlp', 'gnn', 'siam', 'siam3'] 61 | 62 | return args 63 | 64 | 65 | def calc_metrics(pred_partitions, partitions_as_graph, partitions, accum_info): 66 | with torch.no_grad(): 67 | B, N = partitions.shape 68 | C = pred_partitions.max().item() + 1 69 | pred_partitions = pred_partitions[:, :, np.newaxis] 70 | pred_onehot = torch.zeros((B, N, C), dtype=torch.float, device=partitions.device) 71 | pred_onehot.scatter_(2, pred_partitions, 1) 72 | pred_matrices = torch.matmul(pred_onehot, pred_onehot.transpose(1, 2)) 73 | 74 | # calc fscore, precision, recall 75 | tp = (pred_matrices * partitions_as_graph).sum(dim=(1, 2)) - N # Don't care about diagonals 76 | fp = (pred_matrices * (1 - partitions_as_graph)).sum(dim=(1, 2)) 77 | fn = ((1 - pred_matrices) * partitions_as_graph).sum(dim=(1, 2)) 78 | accum_info['recall'] += (tp / (tp + fp + 1e-10)).sum().item() 79 | accum_info['precision'] += (tp / (tp + fn + 1e-10)).sum().item() 80 | accum_info['fscore'] += ((2 * tp) / (2 * tp + fp + fn + 1e-10)).sum().item() 81 | 82 | # calc RI 83 | equiv_pairs = (pred_matrices == partitions_as_graph).float() 84 | accum_info['accuracy'] += equiv_pairs.mean(dim=(1, 2)).sum().item() 85 | # ignore pairs of same node 86 | equiv_pairs[:, torch.arange(N), torch.arange(N)] = torch.zeros((N,), device=DEVICE) 87 | ri_results = equiv_pairs.sum(dim=(1, 2)) / (N*(N-1)) 88 | accum_info['ri'] += ri_results.sum().item() 89 | 90 | return accum_info 91 | 92 | 93 | def infer_clusters(edge_vals): 94 | ''' 95 | Infer the clusters. Enforce symmetry. 96 | :param edge_vals: predicted edge score values. shape (B, N, N) 97 | :return: long tensor shape (B, N) of the clusters. 98 | ''' 99 | # deployment - infer chosen clusters: 100 | b, n, _ = edge_vals.shape 101 | with torch.no_grad(): 102 | pred_matrices = edge_vals + edge_vals.transpose(1, 2) # to make symmetric 103 | pred_matrices = pred_matrices.ge(0.).float() # adj matrix - 0 as threshold 104 | pred_matrices[:, np.arange(n), np.arange(n)] = 1. 
# each node is always connected to itself 105 | ones_now = pred_matrices.sum() 106 | ones_before = ones_now - 1 107 | while ones_now != ones_before: # get connected components - each node connected to all in its component 108 | ones_before = ones_now 109 | pred_matrices = torch.matmul(pred_matrices, pred_matrices) 110 | pred_matrices = pred_matrices.bool().float() # remain as 0-1 matrices 111 | ones_now = pred_matrices.sum() 112 | 113 | clusters = -1 * torch.ones((b, n), device=edge_vals.device) 114 | tensor_1 = torch.tensor(1., device=edge_vals.device) 115 | for i in range(n): 116 | clusters = torch.where(pred_matrices[:, i] == 1, i * tensor_1, clusters) 117 | 118 | return clusters.long() 119 | 120 | 121 | def get_loss(y_hat, y): 122 | # No loss on diagonal 123 | B, N, _ = y_hat.shape 124 | y_hat[:, torch.arange(N), torch.arange(N)] = torch.finfo(y_hat.dtype).max # to be "1" after sigmoid 125 | 126 | # calc loss 127 | loss = F.binary_cross_entropy_with_logits(y_hat, y) # cross entropy 128 | 129 | y_hat = torch.sigmoid(y_hat) 130 | tp = (y_hat * y).sum(dim=(1, 2)) 131 | fn = ((1. - y_hat) * y).sum(dim=(1, 2)) 132 | fp = (y_hat * (1. - y)).sum(dim=(1, 2)) 133 | loss = loss - ((2 * tp) / (2 * tp + fp + fn + 1e-10)).sum() # fscore 134 | 135 | return loss 136 | 137 | 138 | def train(data, model, optimizer): 139 | train_info = do_epoch(data, model, optimizer) 140 | return train_info 141 | 142 | 143 | def evaluate(data, model): 144 | val_info = do_epoch(data, model, optimizer=None) 145 | return val_info 146 | 147 | 148 | def do_epoch(data, model, optimizer=None): 149 | if optimizer is not None: 150 | # train epoch 151 | model.train() 152 | else: 153 | # validation epoch 154 | model.eval() 155 | start_time = datetime.now() 156 | 157 | # Iterate over batches 158 | accum_info = {k: 0.0 for k in ['ri', 'loss', 'insts', 'accuracy', 'fscore', 'precision', 'recall']} 159 | for sets, partitions, partitions_as_graph in data: 160 | # One Train step on the current batch 161 | sets = sets.to(DEVICE, torch.float) 162 | partitions = partitions.to(DEVICE, torch.long) 163 | partitions_as_graph = partitions_as_graph.to(DEVICE, torch.float) 164 | batch_size = sets.shape[0] 165 | accum_info['insts'] += batch_size 166 | 167 | if isinstance(model, SetPartitionTri): 168 | pred_partitions, loss = model(sets, partitions) 169 | else: 170 | edge_vals = model(sets).squeeze(1) # B,N,N 171 | pred_partitions = infer_clusters(edge_vals) 172 | loss = get_loss(edge_vals, partitions_as_graph) 173 | 174 | if optimizer is not None: 175 | # backprop for training epochs only 176 | optimizer.zero_grad() 177 | loss.backward() 178 | optimizer.step() 179 | 180 | # calc ri 181 | accum_info = calc_metrics(pred_partitions, partitions_as_graph, partitions, accum_info) 182 | 183 | # update results from train_step func 184 | accum_info['loss'] += loss.item() * batch_size 185 | 186 | num_insts = accum_info.pop('insts') 187 | accum_info['loss'] /= num_insts 188 | accum_info['ri'] /= num_insts 189 | accum_info['accuracy'] /= num_insts 190 | accum_info['fscore'] /= num_insts 191 | accum_info['recall'] /= num_insts 192 | accum_info['precision'] /= num_insts 193 | 194 | accum_info['run_time'] = datetime.now() - start_time 195 | accum_info['run_time'] = str(accum_info['run_time']).split(".")[0] 196 | 197 | return accum_info 198 | 199 | 200 | def main(): 201 | start_time = datetime.now() 202 | 203 | seed = 42 204 | np.random.seed(seed) 205 | random.seed(seed) 206 | torch.manual_seed(seed) 207 | torch.cuda.manual_seed(seed) 208 | 
torch.cuda.manual_seed_all(seed) 209 | # torch.backends.cudnn.deterministic = True # can impact performance 210 | # torch.backends.cudnn.benchmark = False # can impact performance 211 | 212 | config = parse_args() 213 | 214 | # os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # uncomment only for CUDA error debugging 215 | # os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu 216 | torch.cuda.set_device(int(config.gpu)) 217 | 218 | pprint(vars(config)) 219 | print(flush=True) 220 | 221 | # Load data 222 | print('Loading training data...', end='', flush=True) 223 | train_data = jets_loader.get_data_loader('train', config.bs, config.debug_load) 224 | print('Loading validation data...', end='', flush=True) 225 | val_data = jets_loader.get_data_loader('validation', config.bs, config.debug_load) 226 | 227 | # Create model instance 228 | if config.baseline == 'siam3': 229 | model = SetPartitionTri(10, [384, 384, 384, 384, 20]) 230 | elif config.baseline == 'mlp': 231 | model = SetPartitionMLP([512, 256, 512, 15*15], 10) 232 | elif config.baseline == 'gnn': 233 | model = SetPartitionGNN([350, 350, 300, 20], 10) 234 | elif config.baseline == 'siam': 235 | model = SetToGraphSiam(10, [384, 384, 384, 384, 5], hidden_mlp=[256]) 236 | else: 237 | assert config.baseline is None 238 | model = SetToGraph(10, 239 | out_features=1, 240 | set_fn_feats=[256, 256, 256, 256, 5], 241 | method=config.method, 242 | hidden_mlp=[256], 243 | predict_diagonal=False, 244 | attention=True) 245 | print('Model:' , model) 246 | model = model.to(DEVICE) 247 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 248 | print(f'The nubmer of model parameters is {num_params}') 249 | 250 | # Optimizer 251 | optimizer = torch.optim.Adam(params=model.parameters(), lr=config.lr) 252 | 253 | # Metrics 254 | train_loss = np.empty(config.epochs, float) 255 | train_ri = np.empty(config.epochs, float) 256 | val_loss = np.empty(config.epochs, float) 257 | val_ri = np.empty(config.epochs, float) 258 | 259 | best_epoch = -1 260 | best_val_ri = -1 261 | best_val_fscore = -1 262 | best_model = None 263 | 264 | # Training and evaluation process 265 | for epoch in range(1, config.epochs + 1): 266 | train_info = train(train_data, model, optimizer) 267 | print(f"\tTraining - {epoch:4}", 268 | " loss:{loss:.6f} -- mean_ri:{ri:.4f} -- fscore:{fscore:.4f} -- recall:{recall:.4f}" 269 | " -- precision:{precision:.4f} -- runtime:{run_time}".format(**train_info), flush=True) 270 | train_loss[epoch-1], train_ri[epoch-1] = train_info['loss'], train_info['ri'] 271 | 272 | val_info = evaluate(val_data, model) 273 | print(f"\tVal - {epoch:4}", 274 | " loss:{loss:.6f} -- mean_ri:{ri:.4f} -- fscore:{fscore:.4f} -- recall:{recall:.4f}" 275 | " -- precision:{precision:.4f} -- runtime:{run_time}\n".format(**val_info), flush=True) 276 | val_loss[epoch-1], val_ri[epoch-1] = val_info['loss'], val_info['ri'] 277 | 278 | if val_info['fscore'] > best_val_fscore: 279 | best_val_fscore = val_info['fscore'] 280 | best_epoch = epoch 281 | best_model = copy.deepcopy(model) 282 | 283 | if best_epoch < epoch - 20: 284 | print('Early stopping training due to no improvement over the last 20 epochs...') 285 | break 286 | 287 | del train_data, val_data 288 | print(f'Best validation F-score: {best_val_fscore:.4f}, best epoch: {best_epoch}.') 289 | 290 | print(f'Training runtime: {str(datetime.now() - start_time).split(".")[0]}') 291 | print() 292 | 293 | # Saving to disk 294 | if config.save: 295 | if not os.path.isdir(config.res_dir): 296 | os.makedirs(config.res_dir) 
297 | exp_dir = f'jets_{start_time:%Y%m%d_%H%M%S}_0' 298 | output_dir = os.path.join(config.res_dir, exp_dir) 299 | 300 | i = 0 301 | while True: 302 | if not os.path.isdir(output_dir): 303 | os.makedirs(output_dir) # raises error if dir already exists 304 | break 305 | i += 1 306 | output_dir = output_dir[:-1] + str(i) 307 | if i > 9: 308 | print(f'Cannot save results on disk. (tried to save as {output_dir})') 309 | return 310 | 311 | print(f'Saving all to {output_dir}') 312 | torch.save(best_model.state_dict(), os.path.join(output_dir, "exp_model.pt")) 313 | shutil.copyfile(__file__, os.path.join(output_dir, 'code.py')) 314 | results_dict = {'train_loss': train_loss, 315 | 'train_ri': train_ri, 316 | 'val_loss': val_loss, 317 | 'val_ri': val_ri} 318 | df = pd.DataFrame(results_dict) 319 | df.index.name = 'epochs' 320 | df.to_csv(os.path.join(output_dir, "metrics.csv"), index=False) 321 | best_dict = {'best_val_ri': best_val_ri, 'best_epoch': best_epoch} 322 | best_df = pd.DataFrame(best_dict, index=[0]) 323 | best_df.to_csv(os.path.join(output_dir, "best_val_results.csv"), index=False) 324 | with open(os.path.join(output_dir, 'used_config.json'), 'w') as fp: 325 | json.dump(vars(config), fp) 326 | 327 | # print('Loading test data...', end='', flush=True) 328 | # test_data = jets_loader.get_data_loader('test', config.bs, config.debug_load) 329 | # test_info = evaluate(test_data, best_model) 330 | # print(f"\tTest - {best_epoch:4}", 331 | # " loss:{loss:.6f} -- mean_ri:{ri:.4f} -- fscore:{fscore:.4f} -- recall:{recall:.4f} " 332 | # "-- precision:{precision:.4f} -- runtime:{run_time}\n".format(**test_info)) 333 | 334 | print(f'Epoch {best_epoch} - evaluating over test set.') 335 | test_results = eval_jets_on_test_set(best_model) 336 | print('Test results:') 337 | print(test_results) 338 | if config.save: 339 | test_results.to_csv(os.path.join(output_dir, "test_results.csv"), index=True) 340 | 341 | print(f'Total runtime: {str(datetime.now() - start_time).split(".")[0]}') 342 | 343 | 344 | if __name__ == '__main__': 345 | main() 346 | -------------------------------------------------------------------------------- /models/deep_sets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.layers import Attention 4 | 5 | 6 | class DeepSet(nn.Module): 7 | def __init__(self, in_features, feats, attention, cfg=None): 8 | """ 9 | DeepSets implementation 10 | :param in_features: input's number of features 11 | :param feats: list of features for each deepsets layer 12 | :param attention: True/False to use attention 13 | :param cfg: configurations of second_bias and normalization method 14 | """ 15 | super(DeepSet, self).__init__() 16 | if cfg is None: 17 | cfg = {} 18 | 19 | layers = [] 20 | normalization = cfg.get('normalization', 'fro') 21 | second_bias = cfg.get('second_bias', True) 22 | 23 | layers.append(DeepSetLayer(in_features, feats[0], attention, normalization, second_bias)) 24 | for i in range(1, len(feats)): 25 | layers.append(nn.ReLU()) 26 | layers.append(DeepSetLayer(feats[i-1], feats[i], attention, normalization, second_bias)) 27 | 28 | self.sequential = nn.Sequential(*layers) 29 | 30 | def forward(self, x): 31 | return self.sequential(x) 32 | 33 | 34 | class DeepSetLayer(nn.Module): 35 | def __init__(self, in_features, out_features, attention, normalization, second_bias): 36 | """ 37 | DeepSets single layer 38 | :param in_features: input's number of features 39 | :param out_features: output's 
number of features 40 | :param attention: Whether to use attention 41 | :param normalization: normalization method - 'fro' or 'batchnorm' 42 | :param second_bias: use a bias in second conv1d layer 43 | """ 44 | super(DeepSetLayer, self).__init__() 45 | 46 | self.attention = None 47 | if attention: 48 | self.attention = Attention(in_features) 49 | self.layer1 = nn.Conv1d(in_features, out_features, 1) 50 | self.layer2 = nn.Conv1d(in_features, out_features, 1, bias=second_bias) 51 | 52 | self.normalization = normalization 53 | if normalization == 'batchnorm': 54 | self.bn = nn.BatchNorm1d(out_features) 55 | 56 | def forward(self, x): 57 | # x.shape = (B,C,N) 58 | 59 | # attention 60 | if self.attention: 61 | x_T = x.transpose(2, 1) # B,C,N -> B,N,C 62 | x = self.layer1(x) + self.layer2(self.attention(x_T).transpose(1, 2)) 63 | else: 64 | x = self.layer1(x) + self.layer2(x - x.mean(dim=2, keepdim=True)) 65 | 66 | # normalization 67 | if self.normalization == 'batchnorm': 68 | x = self.bn(x) 69 | else: 70 | x = x / torch.norm(x, p='fro', dim=1, keepdim=True) # BxCxN / Bx1xN 71 | 72 | return x 73 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class PsiSuffix(nn.Module): 8 | def __init__(self, features, predict_diagonal): 9 | super().__init__() 10 | layers = [] 11 | for i in range(len(features) - 2): 12 | layers.append(DiagOffdiagMLP(features[i], features[i + 1], predict_diagonal)) 13 | layers.append(nn.ReLU()) 14 | layers.append(DiagOffdiagMLP(features[-2], features[-1], predict_diagonal)) 15 | self.model = nn.Sequential(*layers) 16 | 17 | def forward(self, x): 18 | return self.model(x) 19 | 20 | 21 | class DiagOffdiagMLP(nn.Module): 22 | def __init__(self, in_features, out_features, seperate_diag): 23 | super(DiagOffdiagMLP, self).__init__() 24 | 25 | self.seperate_diag = seperate_diag 26 | self.conv_offdiag = nn.Conv2d(in_features, out_features, 1) 27 | if self.seperate_diag: 28 | self.conv_diag = nn.Conv1d(in_features, out_features, 1) 29 | 30 | def forward(self, x): 31 | # Assume x.shape == (B, C, N, N) 32 | if self.seperate_diag: 33 | return self.conv_offdiag(x) + (self.conv_diag(x.diagonal(dim1=2, dim2=3))).diag_embed(dim1=2, dim2=3) 34 | return self.conv_offdiag(x) 35 | 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, in_features): 39 | super().__init__() 40 | small_in_features = max(math.floor(in_features/10), 1) 41 | self.d_k = small_in_features 42 | 43 | self.query = nn.Sequential( 44 | nn.Linear(in_features, small_in_features), 45 | nn.Tanh(), 46 | ) 47 | self.key = nn.Linear(in_features, small_in_features) 48 | 49 | def forward(self, inp): 50 | # inp.shape should be (B,N,C) 51 | q = self.query(inp) # (B,N,C/10) 52 | k = self.key(inp) # B,N,C/10 53 | 54 | x = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.d_k) # B,N,N 55 | 56 | x = x.transpose(1, 2) # (B,N,N) 57 | x = x.softmax(dim=2) # over rows 58 | x = torch.matmul(x, inp) # (B, N, C) 59 | return x 60 | -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, in_features, feats, cfg=None): 8 | """ 9 | Element wise MLP 
implementation 10 | :param in_features: input's number of features 11 | :param feats: list of features for each linear layer 12 | :param cfg: configurations of to end with relu and normalization method 13 | """ 14 | super().__init__() 15 | 16 | if cfg is None: 17 | cfg = {} 18 | self.end_with_relu = cfg.get('mlp_with_relu', True) 19 | self.layers = nn.ModuleList() 20 | self.layers.append(nn.Conv1d(in_features, feats[0], 1)) 21 | for i in range(1, len(feats)): 22 | self.layers.append(nn.Conv1d(feats[i-1], feats[i], 1)) 23 | 24 | self.normalization = cfg.get('normalization', 'fro') 25 | if self.normalization == 'batchnorm': 26 | self.bns = nn.ModuleList([nn.BatchNorm1d(feat) for feat in feats]) 27 | 28 | def forward(self, x): 29 | for i, layer in enumerate(self.layers[:-1]): 30 | x = layer(x) 31 | if self.normalization == 'batchnorm': 32 | x = self.bns[i](x) 33 | else: 34 | x = x / torch.norm(x, p='fro', dim=1, keepdim=True) # BxCxN / Bx1xN 35 | x = F.relu(x) 36 | 37 | x = self.layers[-1](x) 38 | if self.normalization == 'batchnorm': 39 | x = self.bns[-1](x) 40 | else: 41 | x = x / torch.norm(x, p='fro', dim=1, keepdim=True) # BxCxN / Bx1xN 42 | if self.end_with_relu: 43 | x = F.relu(x) 44 | 45 | return x 46 | -------------------------------------------------------------------------------- /models/set_partition_gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch_geometric.nn import GraphConv 6 | 7 | 8 | class SetPartitionGNN(torch.nn.Module): 9 | def __init__(self, params, in_features=10): 10 | super().__init__() 11 | last = in_features 12 | self.convs = nn.ModuleList() 13 | for p in params: 14 | self.convs.append(GraphConv(last, p)) 15 | last = p 16 | 17 | self.tensor_1 = torch.tensor(1.) 
18 | 19 | def forward(self, x, k=5): 20 | b, n, c = x.shape 21 | 22 | if k >= n: 23 | k = n - 1 24 | 25 | # k nearest neighbors 26 | nbors = torch.topk(torch.norm(x.unsqueeze(1) - x.unsqueeze(2), dim=3), k+1, largest=False)[1][:, :,1:] # shape b,n,k 27 | src = torch.arange(n, device=x.device).reshape(1,n,1).repeat(b,1,1).repeat(1,1,k).flatten() # shape b*n*k 28 | edge_index = torch.stack([src, nbors.flatten()]) # shape 2, b*n*k 29 | batch = torch.arange(b, device=x.device).reshape(b,1).repeat(1,k*n).flatten() # shape b*n*k 30 | edge_index = edge_index + (batch.view(1, -1) * n) # batched graphes 31 | x = x.view(b*n, c) 32 | 33 | for conv in self.convs[:-1]: 34 | x = conv(x, edge_index) 35 | x = F.relu(x) 36 | 37 | x = self.convs[-1](x, edge_index) # shape b*n,c_new 38 | x = x.view(b, n, -1) # shape b,n,c_new 39 | edge_vals = x @ x.transpose(1, 2) # outer product, shape b,n,n 40 | edge_vals = edge_vals.unsqueeze(1) # shape b,1,n,n 41 | 42 | return edge_vals 43 | -------------------------------------------------------------------------------- /models/set_partition_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SetPartitionMLP(torch.nn.Module): 6 | def __init__(self, params, in_features=10): 7 | super().__init__() 8 | assert params[-1] == 15**2 9 | 10 | last = in_features * 15 11 | layers = [] 12 | for p in params: 13 | layers.append(nn.Linear(last, p)) 14 | layers.append(nn.ReLU()) 15 | last = p 16 | layers = layers[:-1] 17 | self.model = nn.Sequential(*layers) 18 | 19 | def forward(self, x): 20 | B, N, C = x.shape 21 | new_x = torch.zeros((B, 15, C), device=x.device) 22 | new_x[:, :N] = x 23 | 24 | new_x = new_x.view(B, 15*C).contiguous() 25 | edge_vals = self.model(new_x).view(B, 15, 15).contiguous()[:, :N, :N].unsqueeze(1) # B,1,N,N 26 | 27 | return edge_vals 28 | -------------------------------------------------------------------------------- /models/set_to_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.deep_sets import DeepSet 5 | from models.layers import PsiSuffix 6 | 7 | 8 | class SetToGraph(nn.Module): 9 | def __init__(self, in_features, out_features, set_fn_feats, method, hidden_mlp, predict_diagonal, attention, cfg=None): 10 | """ 11 | SetToGraph model. 12 | :param in_features: input set's number of features per data point 13 | :param out_features: number of output features. 14 | :param set_fn_feats: list of number of features for the output of each deepsets layer 15 | :param method: transformer method - quad, lin2 or lin5 16 | :param hidden_mlp: list[int], number of features in hidden layers mlp. 17 | :param predict_diagonal: Bool. True to predict the diagonal (diagonal needs a separate psi function). 18 | :param attention: Bool. Use attention in DeepSets 19 | :param cfg: configurations of using second bias in DeepSetLayer, normalization method and aggregation for lin5. 
20 | """ 21 | super(SetToGraph, self).__init__() 22 | assert method in ['lin2', 'lin5'] 23 | 24 | self.method = method 25 | if cfg is None: 26 | cfg = {} 27 | self.agg = cfg.get('agg', torch.sum) 28 | 29 | self.set_model = DeepSet(in_features=in_features, feats=set_fn_feats, attention=attention, cfg=cfg) 30 | 31 | # Suffix - from last number of features, to 1 feature per entrance 32 | d2 = (2 if method == 'lin2' else 5) * set_fn_feats[-1] 33 | hidden_mlp = [d2] + hidden_mlp + [out_features] 34 | self.suffix = PsiSuffix(hidden_mlp, predict_diagonal=predict_diagonal) 35 | 36 | def forward(self, x): 37 | x = x.transpose(2, 1) # from BxNxC to BxCxN 38 | u = self.set_model(x) # Bx(out_features)xN 39 | n = u.shape[2] 40 | 41 | if self.method == 'lin2': 42 | m1 = u.unsqueeze(2).repeat(1, 1, n, 1) # broadcast to rows 43 | m2 = u.unsqueeze(3).repeat(1, 1, 1, n) # broadcast to cols 44 | block = torch.cat((m1, m2), dim=1) 45 | elif self.method == 'lin5': 46 | m1 = u.unsqueeze(2).repeat(1, 1, n, 1) # broadcast to rows 47 | m2 = u.unsqueeze(3).repeat(1, 1, 1, n) # broadcast to cols 48 | m3 = self.agg(u, dim=2, keepdim=True).unsqueeze(3).repeat(1, 1, n, n) # sum over N, put on all 49 | m4 = u.diag_embed(dim1=2, dim2=3) # assign values to diag only 50 | m5 = self.agg(u, dim=2, keepdim=True).repeat(1, 1, n).diag_embed(dim1=2, dim2=3) # sum over N, put on diag 51 | block = torch.cat((m1, m2, m3, m4, m5), dim=1) 52 | edge_vals = self.suffix(block) # shape (B,out_features,N,N) 53 | 54 | return edge_vals 55 | -------------------------------------------------------------------------------- /models/set_to_graph_gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch_geometric.nn import GraphConv 6 | 7 | 8 | class SetToGraphGNN(torch.nn.Module): 9 | def __init__(self, params, in_features=2, k=10): 10 | super().__init__() 11 | self.k = k 12 | last = in_features 13 | self.convs = nn.ModuleList() 14 | for p in params: 15 | self.convs.append(GraphConv(last, p)) 16 | last = p 17 | 18 | def forward(self, x): 19 | b, n, c = x.shape 20 | 21 | k = self.k 22 | if k >= n: 23 | k = n - 1 24 | 25 | # k nearest neighbors 26 | nbors = torch.topk(torch.norm(x.unsqueeze(1) - x.unsqueeze(2), dim=3), k + 1, largest=False)[1][:, :, 27 | 1:] # shape b,n,k 28 | src = torch.arange(n, device=x.device).reshape(1, n, 1).repeat(b, 1, 1).repeat(1, 1, k).flatten() # shape b*n*k 29 | edge_index = torch.stack([src, nbors.flatten()]) # shape 2, b*n*k 30 | batch = torch.arange(b, device=x.device).reshape(b, 1).repeat(1, k * n).flatten() # shape b*n*k 31 | edge_index = edge_index + (batch.view(1, -1) * n) # batched graphes 32 | x = x.view(b * n, c) 33 | 34 | for conv in self.convs[:-1]: 35 | x = conv(x, edge_index) 36 | x = F.relu(x) 37 | 38 | x = self.convs[-1](x, edge_index) # shape b*n,c_new 39 | x = x.view(b, n, -1) # shape b,n,c_new 40 | graphs = x @ x.transpose(1, 2) # outer product, shape b,n,n 41 | graphs = graphs.unsqueeze(1) # shape b,1,n,n 42 | 43 | return graphs 44 | -------------------------------------------------------------------------------- /models/set_to_graph_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SetToGraphMLP(torch.nn.Module): 6 | def __init__(self, params, in_features=2, max_nodes=80): 7 | super().__init__() 8 | self.max_nodes = max_nodes 9 | last = in_features * max_nodes 10 | layers 
= [] 11 | for p in params: 12 | layers.append(nn.Linear(last, p)) 13 | layers.append(nn.ReLU()) 14 | last = p 15 | layers = layers[:-1] 16 | self.model = nn.Sequential(*layers) 17 | 18 | def forward(self, x): 19 | B, N, C = x.shape 20 | new_x = torch.zeros((B, self.max_nodes, C), device=x.device) 21 | new_x[:, :N] = x 22 | 23 | new_x = new_x.view(B, self.max_nodes * C).contiguous() 24 | graph = self.model(new_x).view(B, 80, 80).contiguous()[:, :N, :N] # B,N,N 25 | graph = (graph + graph.transpose(1, 2)) / 2 # symmetric 26 | graph = graph.unsqueeze(1) # shape b,1,n,n 27 | 28 | return graph 29 | -------------------------------------------------------------------------------- /models/set_to_graph_siam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.layers import PsiSuffix 5 | from models.mlp import MLP 6 | 7 | 8 | class SetToGraphSiam(nn.Module): 9 | def __init__(self, in_features, set_fn_feats, hidden_mlp, cfg=None): 10 | """ 11 | SetToGraph model. 12 | :param in_features: input set's number of features per data point 13 | :param set_fn_feats: list of number of features for the output of each deepsets layer 14 | :param hidden_mlp: list[int], number of features in hidden layers mlp. 15 | :param cfg: configurations of mlp to end with relu and normalization method 16 | """ 17 | super().__init__() 18 | 19 | # For comparison - in DeepSet we use 2 mlps each layer, here only 1, so double up. 20 | if cfg is None: 21 | cfg = {} 22 | self.set_model = MLP(in_features=in_features, feats=set_fn_feats, cfg=cfg) 23 | 24 | # Suffix - from last number of features, to 1 feature per entrance 25 | d2 = 2 * set_fn_feats[-1] 26 | hidden_mlp = [d2] + hidden_mlp + [1] 27 | self.suffix = PsiSuffix(hidden_mlp, predict_diagonal=False) 28 | 29 | def forward(self, x): 30 | x = x.transpose(2, 1) # from B,N,C to B,C,N 31 | u = self.set_model(x) # Bx(out_features)xN 32 | n = u.shape[2] 33 | 34 | m1 = u.unsqueeze(2).repeat(1, 1, n, 1) # broadcast to rows 35 | m2 = u.unsqueeze(3).repeat(1, 1, 1, n) # broadcast to cols 36 | block = torch.cat((m1, m2), dim=1) 37 | edge_vals = self.suffix(block) # shape (B,1,N,N) 38 | 39 | return edge_vals 40 | -------------------------------------------------------------------------------- /models/set_to_graph_triplets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SetToGraphTri(nn.Module): 6 | def __init__(self, params, in_features=2): 7 | """ 8 | Triplets model. 
9 | """ 10 | super().__init__() 11 | last = in_features 12 | layers = [] 13 | for p in params: 14 | layers.append(nn.Linear(last, p)) 15 | layers.append(nn.ReLU()) 16 | last = p 17 | layers = layers[:-1] 18 | self.model = nn.Sequential(*layers) 19 | 20 | def forward(self, x, gt): 21 | device = x.device 22 | x = self.model(x) # shape (B,N,C_out) 23 | 24 | B, N, C = x.shape 25 | loss = torch.tensor(0., requires_grad=True, device=device) 26 | 27 | dists = (x.unsqueeze(1) - x.unsqueeze(2)).pow(2).sum(3) # shape (B,N,N) 28 | 29 | tri = torch.randint(0, N, (200, 3), device=device) 30 | tri = tri[tri[:, 0] != tri[:, 1]] 31 | tri = tri[tri[:, 0] != tri[:, 2]] 32 | tri = tri[tri[:, 1] != tri[:, 2]] 33 | 34 | if gt is not None: 35 | for i in range(B): 36 | if gt[i].unique().numel() == 1: 37 | continue # only one label, cant learn from it 38 | tri_i = tri[(gt[i, tri[:, 0], tri[:, 1]].bool()) & (~gt[i, tri[:, 0], tri[:, 2]].bool())] 39 | tri_i = tri_i.unique(dim=0) 40 | if len(tri_i) == 0: 41 | continue 42 | anch, pos, neg = tri_i.t() 43 | loss = loss + torch.clamp_min(dists[i, anch, pos]-dists[i, anch, neg]+2, 0.).mean() 44 | 45 | #graphs = (dists + dists.transpose(1, 2)) / 2 # to make symmetric 46 | graphs = dists.le(1.).float() # adj matrix - 1 as threshold 47 | graphs = graphs - 0.5 # the main script uses 0 as threshold 48 | return graphs, loss 49 | -------------------------------------------------------------------------------- /models/triplets_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from models.mlp import MLP 6 | 7 | 8 | class SetPartitionTri(nn.Module): 9 | def __init__(self, in_features, mlp_features): 10 | """ 11 | SetPartitionTri model. 
12 | """ 13 | super().__init__() 14 | cfg = dict(mlp_with_relu=False) 15 | self.mlp = MLP(in_features=in_features, feats=mlp_features, cfg=cfg) 16 | self.tensor_1 = torch.tensor(1., device='cuda') 17 | 18 | def forward(self, x, labels, margin=2.): 19 | device = x.device 20 | 21 | x = x.transpose(2, 1) # from BxNxC to BxCxN 22 | u = self.mlp(x) # Bx(out_features)xN 23 | u = u.transpose(2, 1) # shape BxNx(out_features) 24 | 25 | B, N, C = u.shape 26 | loss = torch.tensor(0., requires_grad=True, device=device) 27 | 28 | dists = (u.unsqueeze(1) - u.unsqueeze(2)).pow(2).sum(3) # shape (B,N,N) 29 | 30 | tri = torch.randint(0, N, (200, 3), device=device) 31 | tri = tri[tri[:, 0] != tri[:, 1]] 32 | 33 | if labels is not None: 34 | for i in range(B): 35 | if labels[i].max().item() == 0: 36 | continue # only one cluster, cant learn from it 37 | tri_i = tri[(labels[i, tri[:, 0]] == labels[i, tri[:, 1]]) & (labels[i, tri[:, 1]] != labels[i, tri[:, 2]])] 38 | tri_i = tri_i.unique(dim=0) 39 | if len(tri_i) == 0: 40 | continue 41 | 42 | anch, pos, neg = tri_i.t() 43 | loss = loss + torch.clamp_min(dists[i, anch, pos]-dists[i, anch, neg]+2, 0.).mean() 44 | 45 | # deployment - infer chosen clusters: 46 | with torch.no_grad(): 47 | pred_matrices = (dists + dists.transpose(1, 2)) / 2 # to make symmetric 48 | pred_matrices = pred_matrices.le(1.).float() # adj matrix - 1 as threshold 49 | pred_matrices[:, np.arange(N), np.arange(N)] = self.tensor_1 # each node is always connected to itself 50 | ones_now = pred_matrices.sum() 51 | ones_before = ones_now - 1 52 | while ones_now != ones_before: # get connected components - each node connected to all in its component 53 | ones_before = ones_now 54 | pred_matrices = torch.matmul(pred_matrices, pred_matrices) 55 | pred_matrices = pred_matrices.bool().float() # remain as 0-1 matrices 56 | ones_now = pred_matrices.sum() 57 | 58 | clusters = -1 * torch.ones((B, N), device=device) 59 | for i in range(N): 60 | clusters = torch.where(pred_matrices[:, i] == 1, i * self.tensor_1, clusters) 61 | 62 | return clusters.long(), loss 63 | 64 | def generate_triplets(self, labels, n_triplets): 65 | tries = 0 66 | triplets = [] 67 | labels = labels.cpu().numpy() 68 | while tries < 25 and len(triplets) < n_triplets: 69 | tries += 1 70 | idx = np.random.randint(0, labels.shape[0]) 71 | idx_matches = np.where(labels == labels[idx])[0] 72 | idx_no_matches = np.where(labels != labels[idx])[0] 73 | if len(idx_matches) > 1 and len(idx_no_matches) > 0: 74 | idx_a, idx_p = np.random.choice(idx_matches, 2, replace=False) 75 | idx_n = np.random.choice(idx_no_matches, 1)[0] 76 | triplets.append([idx_a, idx_p, idx_n]) 77 | return np.array(triplets) 78 | -------------------------------------------------------------------------------- /performance_eval/eval_test_jets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import uproot 3 | import numpy as np 4 | import pandas as pd 5 | from datetime import datetime 6 | from sklearn import metrics 7 | 8 | from dataloaders.jets_loader import JetGraphDataset 9 | from models.triplets_model import SetPartitionTri 10 | 11 | 12 | def _get_rand_index(labels, predictions): 13 | n_items = len(labels) 14 | if (n_items < 2): 15 | return 1 16 | n_pairs = (n_items * (n_items - 1)) / 2 17 | 18 | correct_pairs = 0 19 | for item_i in range(n_items): 20 | for item_j in range(item_i + 1, n_items): 21 | label_true = labels[item_i] == labels[item_j] 22 | pred_true = predictions[item_i] == predictions[item_j] 23 | if 
(label_true and pred_true) or ((not label_true) and (not pred_true)): 24 | correct_pairs += 1 25 | 26 | return correct_pairs / n_pairs 27 | 28 | 29 | def _error_count(labels, predictions): 30 | n_items = len(labels) 31 | 32 | true_positives = 0 33 | false_positive = 0 34 | false_negative = 0 35 | 36 | for item_i in range(n_items): 37 | for item_j in range(item_i + 1, n_items): 38 | label_true = labels[item_i] == labels[item_j] 39 | pred_true = predictions[item_i] == predictions[item_j] 40 | if (label_true and pred_true): 41 | true_positives += 1 42 | if (not label_true) and pred_true: 43 | false_positive += 1 44 | if label_true and (not pred_true): 45 | false_negative += 1 46 | return true_positives, false_positive, false_negative 47 | 48 | 49 | def _get_recall(labels, predictions): 50 | true_positives, false_positive, false_negative = _error_count(labels, predictions) 51 | 52 | if true_positives + false_negative == 0: 53 | return 0 54 | 55 | return true_positives / (true_positives + false_negative) 56 | 57 | 58 | def _get_precision(labels, predictions): 59 | true_positives, false_positive, false_negative = _error_count(labels, predictions) 60 | 61 | if true_positives + false_positive == 0: 62 | return 0 63 | return true_positives / (true_positives + false_positive) 64 | 65 | 66 | def _f_measure(labels, predictions): 67 | precision = _get_precision(labels, predictions) 68 | recall = _get_recall(labels, predictions) 69 | 70 | if precision == 0 or recall == 0: 71 | return 0 72 | 73 | return 2 * (precision * recall) / (recall + precision) 74 | 75 | 76 | def eval_jets_on_test_set(model): 77 | 78 | pred = _predict_on_test_set(model) 79 | 80 | test_ds = uproot.open('data/jets_data_all/test/test_data.root') 81 | jet_df = test_ds['tree'].pandas.df(['jet_flav', 'trk_vtx_index'], flatten=False) 82 | jet_flav = jet_df['jet_flav'] 83 | 84 | target = [x for x in jet_df['trk_vtx_index'].values] 85 | 86 | print('Calculating scores on test set... 
', end='') 87 | start = datetime.now() 88 | model_scores = {} 89 | model_scores['RI'] = np.vectorize(_get_rand_index)(target, pred) 90 | model_scores['ARI'] = np.vectorize(metrics.adjusted_rand_score)(target, pred) 91 | model_scores['P'] = np.vectorize(_get_precision)(target, pred) 92 | model_scores['R'] = np.vectorize(_get_recall)(target, pred) 93 | model_scores['F1'] = np.vectorize(_f_measure)(target, pred) 94 | 95 | end = datetime.now() 96 | print(f': {str(end - start).split(".")[0]}') 97 | 98 | flavours = {5: 'b jets', 4: 'c jets', 0: 'light jets'} 99 | metrics_to_table = ['P', 'R', 'F1', 'RI', 'ARI'] 100 | 101 | df = pd.DataFrame(index=flavours.values(), columns=metrics_to_table) 102 | 103 | for flav_n, flav in flavours.items(): 104 | for metric in metrics_to_table: 105 | mean_metric = np.mean(model_scores[metric][jet_flav == flav_n]) 106 | df.at[flav, metric] = mean_metric 107 | 108 | return df 109 | 110 | 111 | def _predict_on_test_set(model): 112 | test_ds = JetGraphDataset('test', random_permutation=False) 113 | model.eval() 114 | 115 | n_tracks = [test_ds[i][0].shape[0] for i in range(len(test_ds))] 116 | 117 | indx_list = [] 118 | predictions = [] 119 | 120 | for tracks_in_jet in range(2, np.amax(n_tracks)+1): 121 | trk_indxs = np.where(np.array(n_tracks) == tracks_in_jet)[0] 122 | if len(trk_indxs) < 1: 123 | continue 124 | indx_list += list(trk_indxs) 125 | 126 | input_batch = torch.stack([test_ds[i][0] for i in trk_indxs]) # shape (B, N_i, 10) 127 | if isinstance(model, SetPartitionTri): 128 | predictions += list(model(input_batch, None)[0].cpu().data.numpy()) 129 | else: 130 | edge_vals = model(input_batch).squeeze(1) 131 | predictions += list(infer_clusters(edge_vals).cpu().data.numpy()) # Shape 132 | 133 | sorted_predictions = [list(x) for _, x in sorted(zip(indx_list, predictions))] 134 | return sorted_predictions 135 | 136 | 137 | def infer_clusters(edge_vals): 138 | ''' 139 | Infer the clusters. Enforce symmetry. 140 | :param edge_vals: predicted edge score values. shape (B, N, N) 141 | :return: long tensor shape (B, N) of the clusters. 142 | ''' 143 | # deployment - infer chosen clusters: 144 | b, n, _ = edge_vals.shape 145 | with torch.no_grad(): 146 | pred_matrices = edge_vals + edge_vals.transpose(1, 2) # to make symmetric 147 | pred_matrices = pred_matrices.ge(0.).float() # adj matrix - 0 as threshold 148 | pred_matrices[:, np.arange(n), np.arange(n)] = 1. # each node is always connected to itself 149 | ones_now = pred_matrices.sum() 150 | ones_before = ones_now - 1 151 | while ones_now != ones_before: # get connected components - each node connected to all in its component 152 | ones_before = ones_now 153 | pred_matrices = torch.matmul(pred_matrices, pred_matrices) 154 | pred_matrices = pred_matrices.bool().float() # remain as 0-1 matrices 155 | ones_now = pred_matrices.sum() 156 | 157 | clusters = -1 * torch.ones((b, n), device=edge_vals.device) 158 | tensor_1 = torch.tensor(1., device=edge_vals.device) 159 | for i in range(n): 160 | clusters = torch.where(pred_matrices[:, i] == 1, i * tensor_1, clusters) 161 | 162 | return clusters.long() 163 | 164 | 165 | --------------------------------------------------------------------------------
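Every model above maps a set of N elements, shaped (B, N, C), to a (B, 1, N, N) tensor of edge scores, and performance_eval/eval_test_jets.py turns those scores into a hard partition with infer_clusters. The short sketch below is not part of the repository: the feature sizes and hidden widths are illustrative placeholders rather than the repository's training configuration, and it assumes the repository root is on PYTHONPATH with torch installed (importing the evaluation module additionally pulls in uproot, pandas and scikit-learn).

# Minimal usage sketch (illustrative only, not from the repository): score pairs
# with SetToGraph, then harden the scores into clusters with infer_clusters.
import torch

from models.set_to_graph import SetToGraph
from performance_eval.eval_test_jets import infer_clusters

# Placeholder hyperparameters - not the repository's training configuration.
model = SetToGraph(in_features=10,        # e.g. 10 features per track, as in the jets data
                   out_features=1,        # one score per (i, j) pair
                   set_fn_feats=[256, 256, 5],
                   method='lin2',
                   hidden_mlp=[256],
                   predict_diagonal=False,
                   attention=True)
model.eval()

x = torch.rand(4, 12, 10)                 # (B, N, C): 4 random sets of 12 elements
with torch.no_grad():
    edge_vals = model(x).squeeze(1)       # (B, 1, N, N) -> (B, N, N)
clusters = infer_clusters(edge_vals)      # (B, N) long tensor of cluster indices
print(clusters.shape)                     # torch.Size([4, 12])

The same pattern applies to SetToGraphSiam and the GNN/MLP baselines, since they also return edge scores that infer_clusters symmetrizes and thresholds at zero.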
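The pairwise metrics at the top of eval_test_jets.py (_get_rand_index, _get_precision, _get_recall, _f_measure) all compare every pair of elements, so a tiny hand-checked case makes their behaviour concrete. The example below is illustrative, follows directly from the definitions in that file, and makes the same environment assumptions as the previous sketch.

# Hand-checked example of the pairwise clustering metrics defined above.
from performance_eval.eval_test_jets import (_get_rand_index, _get_precision,
                                             _get_recall, _f_measure)

labels = [0, 0, 1, 1]   # ground-truth cluster of each of the 4 elements
preds  = [0, 0, 0, 1]   # prediction merges element 2 into the wrong cluster

# Of the 6 pairs: (0,1) is a true positive, (0,2) and (1,2) are false positives,
# (2,3) is a false negative, and (0,3), (1,3) are correctly kept apart.
print(_get_precision(labels, preds))    # 1 / 3
print(_get_recall(labels, preds))       # 1 / 2
print(_f_measure(labels, preds))        # 0.4
print(_get_rand_index(labels, preds))   # (1 + 2) / 6 = 0.5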