├── LICENSE ├── README.md ├── baselines ├── models │ ├── dkl_modules.py │ └── dkl_run.py ├── sklearn_models.py ├── sklearn_tune.py └── utils │ ├── baseline_hyper_tuner.py │ └── hyper_tuning_utils.py ├── environment.yml ├── environment_no_gpu.yml ├── npt ├── batch_dataset.py ├── column_encoding_dataset.py ├── configs.py ├── constants.py ├── datasets │ ├── base.py │ ├── boston_housing.py │ ├── breast_cancer.py │ ├── cifar10.py │ ├── concrete.py │ ├── debug.py │ ├── forest_cover.py │ ├── higgs.py │ ├── image_utils.py │ ├── income.py │ ├── kick.py │ ├── mnist.py │ ├── poker_hand.py │ ├── protein.py │ └── yacht.py ├── distribution.py ├── loss.py ├── mask.py ├── model │ ├── __init__.py │ ├── image_patcher.py │ ├── npt.py │ └── npt_modules.py ├── optim.py ├── train.py └── utils │ ├── analyse_wandb_project.py │ ├── analysis.py │ ├── batch_utils.py │ ├── config_utils.py │ ├── cv_utils.py │ ├── data_loading_utils.py │ ├── debug.py │ ├── encode_utils.py │ ├── eval_checkpoint_utils.py │ ├── image_loading_utils.py │ ├── logging_utils.py │ ├── memory_utils.py │ ├── model_init_utils.py │ ├── optim_utils.py │ ├── plotting.py │ ├── preprocess_utils.py │ ├── train_utils.py │ └── viz_att_maps.py ├── run.py └── scripts ├── ablations.sh ├── image_data.sh ├── row_corruption_tests.sh └── uci_class_reg.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 The NPT Authors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning 2 | 3 | **[Overview](#overview)** 4 | | **[Abstract](#abstract)** 5 | | **[Installation](#installation)** 6 | | **[Examples](#examples)** 7 | | **[Citation](#citation)** 8 | 9 | [![arXiv](https://img.shields.io/badge/arXiv-2106.02584-b31b1b.svg)](https://arxiv.org/abs/2106.02584) 10 | [![Python 3.8](https://img.shields.io/badge/python-3.8-blue.svg)](https://www.python.org/downloads/release/python-380/) 11 | [![Pytorch](https://img.shields.io/badge/Pytorch-1.7-red.svg)](https://shields.io/) 12 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 13 | [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://GitHub.com/Naereen/StrapDown.js/graphs/commit-activity) 14 | 15 | 16 | ## Overview 17 | 18 | Hi, good to see you here! 👋 19 | 20 | Thanks for checking out the code for Non-Parametric Transformers (NPTs). 21 | 22 | This codebase will allow you to reproduce experiments from the paper as well as use NPTs for your own research. 23 | 24 | ## Abstract 25 | 26 | We challenge a common assumption underlying most supervised deep learning: that a model makes a prediction depending only on its parameters and the features of a single input. To this end, we introduce a general-purpose deep learning architecture that takes as input the entire dataset instead of processing one datapoint at a time. Our approach uses self-attention to reason about relationships between datapoints explicitly, which can be seen as realizing non-parametric models using parametric attention mechanisms. However, unlike conventional non-parametric models, we let the model learn end-to-end from the data how to make use of other datapoints for prediction. Empirically, our models solve cross-datapoint lookup and complex reasoning tasks unsolvable by traditional deep learning models. We show highly competitive results on tabular data, early results on CIFAR-10, and give insight into how the model makes use of the interactions between points. 27 | 28 | ## Installation 29 | 30 | Set up and activate the Python environment by executing 31 | 32 | ``` 33 | conda env create -f environment.yml 34 | conda activate npt 35 | ``` 36 | 37 | For now, we recommend installing CUDA <= 10.2: 38 | 39 | See [issue with CUDA >= 11.0 here](https://github.com/pytorch/pytorch/issues/47908). 40 | 41 | If you are running this on a system without a GPU, use the above with `environment_no_gpu.yml` instead. 42 | 43 | ## Examples 44 | 45 | We now give some basic examples of running NPT. 46 | 47 | NPT downloads all supported datasets automatically, so you don't need to worry about that. 48 | 49 | We use [wandb](http://wandb.com/) to log experimental results. 50 | Wandb allows us to conveniently track run progress online. 51 | If you do not want wandb enabled, you can run `wandb off` in the shell where you execute NPT. 
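That is (a minimal sketch; `wandb off` disables syncing for subsequent runs and `wandb on` re-enables it):

```
wandb off    # turn off wandb logging before launching runs
wandb on     # turn it back on later
```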
52 | 53 | For example, run the following to explore NPT with its default configuration on Breast Cancer: 54 | 55 | ``` 56 | python run.py --data_set breast-cancer 57 | ``` 58 | 59 | Another example: a run on the poker-hand dataset may look like this: 60 | 61 | ``` 62 | python run.py --data_set poker-hand \ 63 | --exp_batch_size 4096 \ 64 | --exp_print_every_nth_forward 100 65 | ``` 66 | 67 | You can find all possible config arguments and descriptions in `npt/configs.py` or using `python run.py --help`. 68 | 69 | In `scripts/`, we provide the run commands and hyperparameter configurations for the experiments presented in the paper. 70 | 71 | We hope you enjoy using the code. Please feel free to reach out with any questions 😊 72 | 73 | 74 | ## Citation 75 | 76 | If you find this code helpful for your work, please cite our 77 | [paper](https://arxiv.org/abs/2106.02584) as follows: 78 | 79 | ```bibtex 80 | @article{kossen2021self, 81 | title={Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning}, 82 | author={Kossen, Jannik and Band, Neil and Gomez, Aidan N. and Lyle, Clare and Rainforth, Tom and Gal, Yarin}, 83 | journal={arXiv:2106.02584}, 84 | year={2021} 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /baselines/models/dkl_modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import gpytorch 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from sklearn.cluster import KMeans 8 | from torch import nn 9 | 10 | 11 | # MLP feature extractor 12 | class MLP(nn.Module): 13 | def __init__( 14 | self, input_size, hidden_layer_sizes, output_size, 15 | dropout_prob=None): 16 | super(MLP, self).__init__() 17 | fc_layers = [] 18 | all_layer_sizes = [input_size] + hidden_layer_sizes 19 | for layer_size_idx in range(len(all_layer_sizes) - 1): 20 | fc_layers.append( 21 | nn.Linear(all_layer_sizes[layer_size_idx], 22 | all_layer_sizes[layer_size_idx + 1])) 23 | 24 | self.fc_layers = nn.ModuleList(fc_layers) 25 | self.output_layer = nn.Linear( 26 | hidden_layer_sizes[-1], output_size) 27 | 28 | if dropout_prob is not None: 29 | self.dropout = torch.nn.Dropout(p=dropout_prob) 30 | else: 31 | self.dropout = None 32 | 33 | def forward(self, x): 34 | for fc_layer in self.fc_layers: 35 | x = fc_layer(x) 36 | x = F.relu(x) 37 | 38 | if self.dropout is not None: 39 | x = self.dropout(x) 40 | 41 | output = self.output_layer(x) 42 | return output 43 | 44 | 45 | # GP Layer 46 | # Trains one GP per feature, as per the SV-DKL paper 47 | # The outputs of those GPs are mixed in the Softmax Likelihood for classification 48 | class GaussianProcessLayer(gpytorch.models.ApproximateGP): 49 | def __init__(self, num_dim, grid_bounds=(-10., 10.), grid_size=64): 50 | 51 | if num_dim > 1: 52 | batch_shape = torch.Size([num_dim]) 53 | else: 54 | batch_shape = torch.Size([]) 55 | 56 | variational_distribution = ( 57 | gpytorch.variational.CholeskyVariationalDistribution( 58 | num_inducing_points=grid_size, batch_shape=batch_shape)) 59 | 60 | # Our base variational strategy is a GridInterpolationVariationalStrategy, 61 | # which places variational inducing points on a Grid 62 | # We wrap it with a MultitaskVariationalStrategy so that our output is a vector-valued GP 63 | variational_strategy = gpytorch.variational.GridInterpolationVariationalStrategy( 64 | self, grid_size=grid_size, grid_bounds=[grid_bounds], 65 | variational_distribution=variational_distribution 66 | )
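# Note: grid_size sets the number of inducing points per output GP; the feature
# extractor's outputs are later rescaled into grid_bounds (see
# DKLClassificationModel.forward) so that they fall inside this interpolation grid.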
67 | 68 | if num_dim > 1: 69 | variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy( 70 | variational_strategy, num_tasks=num_dim) 71 | 72 | super().__init__(variational_strategy) 73 | 74 | self.covar_module = gpytorch.kernels.ScaleKernel( 75 | gpytorch.kernels.RBFKernel( 76 | lengthscale_prior=gpytorch.priors.SmoothedBoxPrior( 77 | math.exp(-1), math.exp(1), sigma=0.1, transform=torch.exp 78 | ) 79 | ) 80 | ) 81 | 82 | self.mean_module = gpytorch.means.ConstantMean() 83 | self.grid_bounds = grid_bounds 84 | 85 | def forward(self, x): 86 | mean = self.mean_module(x) 87 | covar = self.covar_module(x) 88 | return gpytorch.distributions.MultivariateNormal(mean, covar) 89 | 90 | 91 | # Stochastic Variational Deep Kernel Learning 92 | # Wilson et al. 2016 93 | # https://arxiv.org/abs/1611.00336 94 | # https://docs.gpytorch.ai/en/v1.2.1/examples/06_PyTorch_NN_Integration_DKL/ 95 | # Deep_Kernel_Learning_DenseNet_CIFAR_Tutorial.html 96 | class DKLClassificationModel(gpytorch.Module): 97 | def __init__(self, feature_extractor, num_dim, grid_bounds=(-10., 10.)): 98 | super(DKLClassificationModel, self).__init__() 99 | self.feature_extractor = feature_extractor 100 | self.gp_layer = GaussianProcessLayer( 101 | num_dim=num_dim, grid_bounds=grid_bounds) 102 | self.grid_bounds = grid_bounds 103 | self.num_dim = num_dim 104 | 105 | def forward(self, x): 106 | features = self.feature_extractor(x) 107 | features = gpytorch.utils.grid.scale_to_bounds(features, self.grid_bounds[0], self.grid_bounds[1]) 108 | # This next line makes it so that we learn a GP for each feature 109 | features = features.transpose(-1, -2).unsqueeze(-1) 110 | res = self.gp_layer(features) 111 | return res 112 | 113 | 114 | class DKLInducingPointGP(gpytorch.models.ApproximateGP): 115 | def __init__(self, n_inducing_points, feature_extractor, batch_size, X_train): 116 | inducing_points = self.get_inducing_points(X_train, n_inducing_points, feature_extractor, batch_size) 117 | variational_distribution = ( 118 | gpytorch.variational.CholeskyVariationalDistribution( 119 | inducing_points.size(0))) 120 | variational_strategy = gpytorch.variational.VariationalStrategy( 121 | self, inducing_points, variational_distribution, 122 | learn_inducing_locations=True) 123 | super(DKLInducingPointGP, self).__init__(variational_strategy) 124 | self.feature_extractor_batch_size = batch_size 125 | self.mean_module = gpytorch.means.ConstantMean() 126 | self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) 127 | 128 | def forward(self, x): 129 | mean_x = self.mean_module(x) 130 | covar_x = self.covar_module(x) 131 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 132 | 133 | def get_inducing_points(self, X_train, n_inducing_points, 134 | feature_extractor, feature_extractor_batch_size): 135 | if torch.cuda.is_available(): 136 | self.feature_extractor = self.feature_extractor.cuda() 137 | 138 | n_inducing_points = min(X_train.size(0), n_inducing_points) 139 | n_embeds = min(X_train.size(0), n_inducing_points * 10) 140 | feature_extractor_embeds = [] 141 | 142 | # Input indices to embed 143 | input_indices = np.random.choice( 144 | np.arange(X_train.size(0)), size=n_embeds, replace=False) 145 | 146 | with torch.no_grad(): 147 | for i in range(0, n_inducing_points, feature_extractor_batch_size): 148 | batch_indices = input_indices[i:i+feature_extractor_batch_size] 149 | feature_extractor_embeds.append( 150 | feature_extractor(X_train[batch_indices])) 151 | 152 | 
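# Concatenate the embedded batches and cluster them with k-means below; the
# resulting cluster centres (in feature space) become the initial inducing point
# locations, which are then learned further (learn_inducing_locations=True).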
feature_extractor_embeds = torch.cat(feature_extractor_embeds).numpy() 153 | km = KMeans(n_clusters=n_inducing_points) 154 | km.fit(feature_extractor_embeds) 155 | if True: 156 | a = 1 157 | inducing_points = torch.from_numpy(km.cluster_centers_) 158 | return inducing_points 159 | 160 | 161 | class DKLRegressionModel(gpytorch.Module): 162 | def __init__(self, feature_extractor, n_inducing_points, batch_size, X_train): 163 | super(DKLRegressionModel, self).__init__() 164 | self.feature_extractor = feature_extractor 165 | self.gp_layer = DKLInducingPointGP(n_inducing_points, feature_extractor, batch_size, X_train) 166 | 167 | def forward(self, x): 168 | features = self.feature_extractor(x) 169 | res = self.gp_layer(features) 170 | return res 171 | -------------------------------------------------------------------------------- /baselines/models/dkl_run.py: -------------------------------------------------------------------------------- 1 | """Train loop based on 2 | https://docs.gpytorch.ai/en/v1.2.1/examples/06_PyTorch_NN_Integration_DKL/ 3 | Deep_Kernel_Learning_DenseNet_CIFAR_Tutorial.html 4 | """ 5 | import gpytorch 6 | import torch 7 | import tqdm 8 | from torch.optim import SGD 9 | from torch.optim.lr_scheduler import MultiStepLR 10 | from torch.utils.data import TensorDataset, DataLoader 11 | 12 | from baselines.models.dkl_modules import MLP, DKLClassificationModel, \ 13 | DKLRegressionModel 14 | 15 | 16 | def main(c, dataset): 17 | tune_dkl(c, dataset, hyper_dict=None) 18 | 19 | 20 | def build_dataloaders(dataset, batch_size): 21 | data_dict = dataset.cv_dataset.data_dict 22 | metadata = dataset.metadata 23 | D = metadata['D'] 24 | cat_target_cols, num_target_cols = ( 25 | metadata['cat_target_cols'], metadata['num_target_cols']) 26 | target_cols = list(sorted(cat_target_cols + num_target_cols)) 27 | non_target_cols = sorted( 28 | list(set(range(D)) - set(target_cols))) 29 | train_indices, val_indices, test_indices = ( 30 | tuple(data_dict['new_train_val_test_indices'])) 31 | data_arrs = data_dict['data_arrs'] 32 | X = [] 33 | y = None 34 | 35 | for i, col in enumerate(data_arrs): 36 | if i in non_target_cols: 37 | col = col[:, :-1] 38 | X.append(col) 39 | else: 40 | col = col[:, :-1] 41 | y = col 42 | 43 | X = torch.cat(X, dim=-1) 44 | X_train, X_val, X_test = X[train_indices], X[val_indices], X[test_indices] 45 | 46 | if y.shape[1] > 1: 47 | dataset_is_classification = True 48 | num_classes = y.shape[1] 49 | y = torch.argmax(y.long(), dim=1) 50 | else: 51 | dataset_is_classification = False 52 | num_classes = None 53 | y = torch.squeeze(y) 54 | 55 | y_train, y_val, y_test = y[train_indices], y[val_indices], y[test_indices] 56 | train_dataset, val_dataset, test_dataset = ( 57 | TensorDataset(X_train, y_train), 58 | TensorDataset(X_val, y_val), 59 | TensorDataset(X_test, y_test)) 60 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False) 61 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False) 62 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False) 63 | return ( 64 | (train_loader, val_loader, test_loader), X.shape[1], 65 | dataset_is_classification, num_classes, X_train) 66 | 67 | 68 | def get_likelihood(dataset_is_classification, num_features, num_classes=None): 69 | if dataset_is_classification: 70 | likelihood = gpytorch.likelihoods.SoftmaxLikelihood( 71 | num_features=num_features, num_classes=num_classes) 72 | else: 73 | noise_prior = None 74 | if False: 
75 | noise_prior_loc = 0.1 76 | noise_prior_scale = 0.1 77 | noise_prior = gpytorch.priors.NormalPrior( 78 | loc=noise_prior_loc, scale=noise_prior_scale) 79 | likelihood = gpytorch.likelihoods.GaussianLikelihood( 80 | noise_prior=noise_prior) 81 | 82 | return likelihood 83 | 84 | 85 | def tune_dkl(c, dataset, hyper_dict): 86 | batch_size = c.exp_batch_size 87 | dataloaders, input_dims, is_classification, num_classes, X_train = ( 88 | build_dataloaders(dataset=dataset, batch_size=batch_size)) 89 | train_loader, val_loader, test_loader = dataloaders 90 | 91 | # Define some hypers here 92 | 93 | # This is the output of the feature extractor, which is then 94 | # transformed to grid space (in classification) or 95 | # is the size of the inducing points (regression) 96 | # We init the inducing points by projecting a random selection of 97 | # training points, and running KMeans on that 98 | num_features = 10 99 | hidden_layers = [100] 100 | dropout_prob = 0.1 101 | n_epochs = 1000 102 | lr = 0.001 103 | feature_extractor_weight_decay = 1e-4 104 | scheduler__milestones = [0.5 * n_epochs, 0.75 * n_epochs] 105 | scheduler__gamma = 0.1 106 | n_inducing_points = 1000 107 | 108 | likelihood = get_likelihood( 109 | dataset_is_classification=is_classification, 110 | num_features=num_features, num_classes=num_classes) 111 | 112 | # Define some hypers here 113 | 114 | feature_extractor = MLP( 115 | input_size=input_dims, hidden_layer_sizes=hidden_layers, 116 | output_size=num_features, dropout_prob=dropout_prob) 117 | 118 | if is_classification: 119 | model = DKLClassificationModel(feature_extractor, num_dim=num_features) 120 | else: 121 | model = DKLRegressionModel( 122 | feature_extractor, n_inducing_points, batch_size, X_train) 123 | 124 | # If you run this example without CUDA, I hope you like waiting! 
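# Move the model and likelihood to GPU if available. The SGD optimizer defined
# further below uses separate parameter groups: weight decay on the MLP feature
# extractor, a 100x smaller learning rate for the GP kernel hyperparameters, and
# the base learning rate for the variational and likelihood parameters.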
125 | if torch.cuda.is_available(): 126 | model = model.cuda() 127 | likelihood = likelihood.cuda() 128 | 129 | # Train loop 130 | optimizer = SGD([ 131 | {'params': model.feature_extractor.parameters(), 132 | 'weight_decay': feature_extractor_weight_decay}, 133 | {'params': model.gp_layer.hyperparameters(), 'lr': lr * 0.01}, 134 | {'params': model.gp_layer.variational_parameters()}, 135 | {'params': likelihood.parameters()}, 136 | ], lr=lr, momentum=0.9, nesterov=True, weight_decay=0) 137 | scheduler = MultiStepLR( 138 | optimizer, milestones=scheduler__milestones, gamma=scheduler__gamma) 139 | mll = gpytorch.mlls.VariationalELBO( 140 | likelihood, model.gp_layer, num_data=len(train_loader.dataset)) 141 | 142 | def train(epoch): 143 | model.train() 144 | likelihood.train() 145 | 146 | minibatch_iter = tqdm.tqdm(train_loader, desc=f"(Epoch {epoch}) Minibatch") 147 | with gpytorch.settings.num_likelihood_samples(8): 148 | for data, target in minibatch_iter: 149 | if torch.cuda.is_available(): 150 | data, target = data.cuda(), target.cuda() 151 | 152 | data = data.reshape(data.size(0), -1) 153 | 154 | optimizer.zero_grad() 155 | output = model(data) 156 | loss = -mll(output, target) 157 | loss.backward() 158 | optimizer.step() 159 | minibatch_iter.set_postfix(loss=loss.item()) 160 | 161 | def test(data_loader, mode): 162 | model.eval() 163 | likelihood.eval() 164 | 165 | if is_classification: 166 | correct = 0 167 | else: 168 | mse = torch.zeros(1) 169 | num_batches = 0 170 | 171 | # This gives us 16 samples from the predictive distribution 172 | with torch.no_grad(), gpytorch.settings.num_likelihood_samples(16): 173 | for data, target in data_loader: 174 | if torch.cuda.is_available(): 175 | data, target = data.cuda(), target.cuda() 176 | 177 | data = data.reshape(data.size(0), -1) 178 | 179 | if is_classification: 180 | output = likelihood(model(data)) 181 | pred = output.probs.mean(0) 182 | pred = pred.argmax(-1) # Taking the mean over all of the sample we've drawn 183 | correct += pred.eq(target.view_as(pred)).cpu().sum() 184 | else: 185 | preds = model(data) 186 | mse += torch.mean((preds.mean - target.cpu()) ** 2) 187 | num_batches += 1 188 | 189 | if is_classification: 190 | print('{} set: Accuracy: {}/{} ({}%)'.format( 191 | mode, 192 | correct, len(test_loader.dataset), 100. 
* correct / float(len(data_loader.dataset)) 193 | )) 194 | else: 195 | print('{} set: MSE: {}'.format(mode, mse / num_batches)) 196 | 197 | for epoch in range(1, n_epochs + 1): 198 | with gpytorch.settings.use_toeplitz(False): 199 | train(epoch) 200 | test(val_loader, mode='Val') 201 | test(data_loader=test_loader, mode='Test') 202 | scheduler.step() 203 | state_dict = model.state_dict() 204 | likelihood_state_dict = likelihood.state_dict() 205 | torch.save({'model': state_dict, 'likelihood': likelihood_state_dict}, 'dkl_cifar_checkpoint.dat') 206 | -------------------------------------------------------------------------------- /baselines/sklearn_tune.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | from sklearn.model_selection import GridSearchCV 4 | 5 | from baselines.sklearn_models import ( 6 | SKLEARN_CLASSREG_HYPERS, SKLEARN_CLASSREG_MODELS) 7 | from baselines.utils.baseline_hyper_tuner import BaselineHyperTuner 8 | from baselines.utils.hyper_tuning_utils import parse_string_list 9 | 10 | # Hyperparameter search algorithms 11 | HYPER_SEARCH_ALGS = { 12 | 'Grid': GridSearchCV 13 | } 14 | 15 | 16 | def run_sklearn_hypertuning( 17 | dataset, wandb_args, args, c, wandb_run): 18 | if c.sklearn_hyper_search == 'Random': 19 | raise NotImplementedError 20 | 21 | search_alg = HYPER_SEARCH_ALGS[c.sklearn_hyper_search] 22 | models = c.sklearn_model 23 | 24 | if models == 'All': 25 | models_dict = SKLEARN_CLASSREG_MODELS 26 | 27 | else: 28 | models = parse_string_list(models) 29 | models_dict = {} 30 | for model in models: 31 | try: 32 | models_dict[model] = SKLEARN_CLASSREG_MODELS[model] 33 | except KeyError: 34 | raise NotImplementedError( 35 | f'Have not implemented model {c.sklearn_model}') 36 | 37 | print('Running sklearn tuning loop with models:') 38 | pprint(models_dict) 39 | 40 | baseline_hyper_tuner = BaselineHyperTuner( 41 | dataset=dataset, wandb_args=wandb_args, args=args, c=c, 42 | wandb_run=wandb_run, models_dict=models_dict, 43 | hypers_dict=SKLEARN_CLASSREG_HYPERS, 44 | search_alg=search_alg, verbose=c.sklearn_verbose, 45 | n_jobs=c.sklearn_n_jobs) 46 | 47 | baseline_hyper_tuner.run_hypertuning() 48 | -------------------------------------------------------------------------------- /baselines/utils/hyper_tuning_utils.py: -------------------------------------------------------------------------------- 1 | from pytorch_tabnet.metrics import Metric 2 | from sklearn.metrics import log_loss 3 | 4 | 5 | def add_baseline_random_state(hypers_list, seed): 6 | random_seed_like_names = ['random_state', 'seed', 'random_seed'] 7 | for hypers_dict in hypers_list: 8 | for key in random_seed_like_names: 9 | if key in hypers_dict.keys(): 10 | hypers_dict[key] = [seed] 11 | 12 | return hypers_list 13 | 14 | 15 | def parse_string_list(string): 16 | if ',' in string: 17 | string = string.replace('[', '').replace(']', '') 18 | string = [i.strip() for i in string.split(',')] 19 | else: 20 | string = [string] 21 | return string 22 | 23 | 24 | def modified_tabnet(TabNetModel): 25 | # Add max_epochs, patience, batch_size as member variables to TabNet 26 | 27 | attributes = [ 28 | 'n_d', 'n_a', 'n_steps', 'gamma', 'cat_idxs', 'cat_dims', 29 | 'cat_emb_dim', 'n_independent', 'n_shared', 'epsilon', 'momentum', 30 | 'lambda_sparse', 'seed', 'clip_value', 'verbose', 'optimizer_fn', 31 | 'optimizer_params', 'scheduler_fn', 'scheduler_params', 'mask_type', 32 | 'input_dim', 'output_dim', 'device_name', 'labels'] 33 | 34 | class 
ModifiedTabNetModel(TabNetModel): 35 | """""" 36 | def __init__( 37 | # need to list all params here, s.t. sklearn is happy 38 | self, 39 | n_d='dummy', 40 | n_a='dummy', 41 | n_steps='dummy', 42 | gamma='dummy', 43 | cat_idxs='dummy', 44 | cat_dims='dummy', 45 | cat_emb_dim='dummy', 46 | n_independent='dummy', 47 | n_shared='dummy', 48 | epsilon='dummy', 49 | momentum='dummy', 50 | lambda_sparse='dummy', 51 | seed='dummy', 52 | clip_value='dummy', 53 | verbose='dummy', 54 | optimizer_fn='dummy', 55 | optimizer_params='dummy', 56 | scheduler_fn='dummy', 57 | scheduler_params='dummy', 58 | mask_type='dummy', 59 | input_dim='dummy', 60 | output_dim='dummy', 61 | device_name='dummy', 62 | max_epochs='dummy', 63 | patience='dummy', 64 | batch_size='dummy', 65 | virtual_batch_size='dummy', 66 | eval_metric='dummy', 67 | labels='dummy'): 68 | 69 | # intercept kwargs and remove injection attributes 70 | # set injection attributes (if used at init) 71 | # however, sklearn does not do this!! sklearn inits with 72 | # default and then uses set_params() 73 | self.injected_attributes = dict( 74 | max_epochs=200, 75 | patience=15, 76 | batch_size=1024, 77 | virtual_batch_size=128, 78 | eval_metric=None) 79 | 80 | for attribute in self.injected_attributes: 81 | value = eval(attribute) 82 | if value != 'dummy': 83 | self.injected_attributes[attribute] = value 84 | setattr(self, attribute, value) 85 | else: 86 | # need to write default value, s.t. parameter is present 87 | setattr( 88 | self, attribute, 89 | self.injected_attributes[attribute]) 90 | 91 | # filter out non-dummy, s.t. default initialisation can still work 92 | pass_on_kwargs = dict() 93 | for attribute in attributes: 94 | value = eval(attribute) 95 | if value != 'dummy' and attribute != 'labels': 96 | pass_on_kwargs[attribute] = value 97 | if value != 'dummy' and attribute == 'labels': 98 | print('Passing labels explicitly to TabNet.') 99 | setattr(self, attribute, value) 100 | 101 | super().__init__(**pass_on_kwargs) 102 | 103 | def fit(self, *args, **kwargs): 104 | 105 | # inject desired epochs/patience 106 | # sklearn does not use __init__ to set params 107 | # 108 | injected_attributes = { 109 | attribute: getattr(self, attribute) 110 | for attribute in self.injected_attributes} 111 | 112 | kwargs.update(injected_attributes) 113 | 114 | # Need to switch the metric here. 115 | try: 116 | kwargs['eval_metric'] = [ 117 | get_label_log_loss_metric(self.labels)] 118 | print('Labels were provided to TabNet run. 
' 119 | 'Injecting a logloss with explicit labels to ' 120 | 'avoid edge case bugs.') 121 | # The above deals with cases like the validation set not 122 | # covering the full set of labels 123 | except Exception as e: 124 | print(e) 125 | 126 | return super().fit(*args, **kwargs) 127 | 128 | def predict_proba(self, *args, **kwargs): 129 | return super().predict_proba(*args, **kwargs).astype('float64') 130 | 131 | def set_params(self, **kwargs): 132 | """Used in the Sklearn grid search to set parameters 133 | before fit and evaluation on any cross-validation split.""" 134 | print('Setting TabNet parameters.') 135 | for var_name, value in kwargs.items(): 136 | if var_name == 'labels': 137 | continue 138 | 139 | # setattr(self, param_key, param_value) 140 | try: 141 | exec(f"global previous_val; previous_val = self.{var_name}") 142 | if previous_val != value: # noqa 143 | wrn_msg = ( 144 | f"NPT Hyperparameter Tuning: {var_name} changed " 145 | f"from {previous_val} to {value}") # noqa 146 | print(wrn_msg) 147 | exec(f"self.{var_name} = value") 148 | except AttributeError: 149 | exec(f"self.{var_name} = value") 150 | 151 | return self 152 | 153 | return ModifiedTabNetModel 154 | 155 | 156 | def get_label_log_loss_metric(labels): 157 | class LabelLogLoss(Metric): 158 | """ 159 | LogLoss with explicitly specified label set, to avoid edge cases 160 | in which a batch does not contain all different categories (e.g., 161 | this commonly happens in the heavily imbalanced poker-hand 162 | synthetic dataset). 163 | 164 | Code from TabNet 165 | https://github.com/dreamquark-ai/tabnet/blob/ 166 | 5e4e8099335ebddd6b297b16aa40cf0bad145b4a/pytorch_tabnet/metrics.py#L444 167 | """ 168 | 169 | def __init__(self): 170 | self._name = "labellogloss" 171 | self._maximize = False 172 | 173 | def __call__(self, y_true, y_score): 174 | """ 175 | Compute LogLoss of predictions. 176 | Parameters 177 | ---------- 178 | y_true : np.ndarray 179 | Target matrix or vector 180 | y_score : np.ndarray 181 | Score matrix or vector 182 | Returns 183 | ------- 184 | float 185 | LogLoss of predictions vs targets. 
186 | """ 187 | return log_loss(y_true, y_score, labels=labels) 188 | 189 | return LabelLogLoss 190 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: npt 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8 7 | - cudatoolkit=10.1 8 | - pytorch=1.7 9 | - torchvision 10 | - pip 11 | - pip: 12 | - wandb 13 | - scikit-learn 14 | - pandas 15 | - pycodestyle 16 | - dotmap 17 | - line-profiler 18 | - tqdm 19 | - pytorch-lightning 20 | - patool 21 | - wget 22 | - scikit-optimize 23 | - xgboost 24 | - catboost 25 | - transformers 26 | - pytorch-tabnet 27 | - numba 28 | - ogb 29 | - memory-profiler 30 | - seaborn 31 | - tensorboardX 32 | - seaborn 33 | - fairseq 34 | - dotmap 35 | - tabnet 36 | - lightgbm 37 | - jupyterlab 38 | - notebook 39 | - matplotlib 40 | - qhoptim 41 | - torchvision 42 | - gpytorch -------------------------------------------------------------------------------- /environment_no_gpu.yml: -------------------------------------------------------------------------------- 1 | name: npt 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8 7 | - pytorch=1.7 8 | - torchvision 9 | - pip 10 | - pip: 11 | - wandb 12 | - scikit-learn 13 | - pandas 14 | - pycodestyle 15 | - dotmap 16 | - line-profiler 17 | - tqdm 18 | - pytorch-lightning 19 | - patool 20 | - wget 21 | - scikit-optimize 22 | - xgboost 23 | - catboost 24 | - transformers 25 | - pytorch-tabnet 26 | - numba 27 | - ogb 28 | - memory-profiler 29 | - seaborn 30 | - tensorboardX 31 | - seaborn 32 | - fairseq 33 | - dotmap 34 | - tabnet 35 | - lightgbm 36 | - jupyterlab 37 | - notebook 38 | - matplotlib 39 | - qhoptim 40 | - torchvision 41 | - gpytorch -------------------------------------------------------------------------------- /npt/constants.py: -------------------------------------------------------------------------------- 1 | # ########## Classification/Regression Mask Modes ########## 2 | # #### Stochastic Label Masking #### 3 | # On these labels, stochastic masking may take place at training time, vali- 4 | # dation time, or at test time. 5 | DATA_MODE_TO_LABEL_BERT_MODE = { 6 | 'train': ['train'], 7 | 'val': ['train'], 8 | 'test': ['train', 'val'], 9 | } 10 | 11 | # However, even when we do stochastic label masking, some labels will be 12 | # masked out deterministically, to avoid information leaks. 13 | DATA_MODE_TO_LABEL_BERT_FIXED = { 14 | 'train': ['val', 'test'], 15 | 'val': ['val', 'test'], 16 | 'test': ['test'], 17 | } 18 | -------------------------------------------------------------------------------- /npt/datasets/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from abc import ABC, abstractmethod 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | 8 | 9 | class BaseDataset(ABC): 10 | """Abstract base dataset class. 11 | 12 | Requires subclasses to override the following: 13 | load, which should set all attributes below that are None 14 | which will be needed by the dataset (e.g. fixed_split_indices 15 | need only be set by a dataset that has a fully specified 16 | train, val, and test set for comparability to prior approaches). 17 | """ 18 | def __init__( 19 | self, fixed_test_set_index): 20 | """ 21 | Args: 22 | fixed_test_set_index: int, if specified, the dataset has a 23 | fixed test set starting at this index. 
Needed for 24 | comparability to other methods. 25 | """ 26 | self.fixed_test_set_index = fixed_test_set_index 27 | self.c = None 28 | self.data_table = None 29 | self.missing_matrix = None 30 | self.N = None 31 | self.D = None 32 | self.cat_features = None 33 | self.num_features = None 34 | self.cat_target_cols = None 35 | self.num_target_cols = None 36 | self.auroc_setting = None 37 | self.is_data_loaded = False 38 | self.tmp_file_or_dir_names = [] # Deleted if c.clear_tmp_files=True 39 | 40 | # fixed_split_indices: Dict[str, np.array], a fully specified 41 | # mapping from the dataset mode key (train, val, or test) 42 | # to a np.array containing the indices for the respective 43 | # mode. 44 | self.fixed_split_indices = None 45 | 46 | def get_data_dict(self, force_disable_auroc=None): 47 | if not self.is_data_loaded: 48 | self.load() 49 | 50 | self.auroc_setting = self.use_auroc(force_disable_auroc) 51 | 52 | # # # For some datasets, we should immediately delete temporary files 53 | # # # e.g. Higgs: zipped and unzipped file = 16 GB, CV split is 3 GB 54 | if self.c.data_clear_tmp_files: 55 | print('\nClearing tmp files.') 56 | path = Path(self.c.data_path) / self.c.data_set 57 | for file_or_dir in self.tmp_file_or_dir_names: 58 | file_dir_path = path / file_or_dir 59 | 60 | # Could be both file and a path! 61 | if os.path.isfile(file_dir_path): 62 | os.remove(file_dir_path) 63 | print(f'Removed file {file_or_dir}.') 64 | if os.path.isdir(file_dir_path): 65 | try: 66 | shutil.rmtree(file_dir_path) 67 | print(f'Removed dir {file_or_dir}.') 68 | except OSError as e: 69 | print("Error: %s - %s." % (e.filename, e.strerror)) 70 | 71 | return self.__dict__ 72 | 73 | @abstractmethod 74 | def load(self): 75 | pass 76 | 77 | def use_auroc(self, force_disable=None): 78 | """ 79 | Disable AUROC metric: 80 | (i) if we do not have a single categorical target column, 81 | (ii) if the single categorical target column is multiclass. 82 | """ 83 | if not self.is_data_loaded: 84 | self.load() 85 | 86 | disable = 'Disabling AUROC metric.' 87 | 88 | if force_disable: 89 | print(disable) 90 | return False 91 | 92 | if not self.c.metrics_auroc: 93 | print(disable) 94 | print("As per config argument 'metrics_auroc'.") 95 | return False 96 | 97 | num_target_cols, cat_target_cols = ( 98 | self.num_target_cols, self.cat_target_cols) 99 | n_cat_target_cols = len(cat_target_cols) 100 | 101 | if n_cat_target_cols != 1: 102 | print(disable) 103 | print( 104 | f'\tBecause dataset has {n_cat_target_cols} =/= 1 ' 105 | f'categorical target columns.') 106 | if n_cat_target_cols > 1: 107 | print( 108 | '\tNote that we have not decided how we want to handle ' 109 | 'AUROC among multiple categorical target columns.') 110 | return False 111 | elif num_target_cols: 112 | print(disable) 113 | print( 114 | '\tBecause dataset has a nonzero count of ' 115 | 'numerical target columns.') 116 | return False 117 | else: 118 | auroc_col = cat_target_cols[0] 119 | if n_classes := len(np.unique(self.data_table[:, auroc_col])) > 2: 120 | print(disable) 121 | print(f'\tBecause AUROC does not (in the current implem.) 
' 122 | f'support multiclass ({n_classes}) classification.') 123 | return False 124 | 125 | return True 126 | 127 | @staticmethod 128 | def get_num_cat_auto(data, cutoff): 129 | """Interpret all columns with < "cutoff" values as categorical.""" 130 | D = data.shape[1] 131 | cols = np.arange(0, D) 132 | unique_vals = np.array([np.unique(data[:, col]).size for col in cols]) 133 | 134 | num_feats = cols[unique_vals > cutoff] 135 | cat_feats = cols[unique_vals <= cutoff] 136 | 137 | assert np.intersect1d(cat_feats, num_feats).size == 0 138 | assert np.union1d(cat_feats, num_feats).size == D 139 | 140 | # we dump to json later, it will crie if not python dtypes 141 | num_feats = [int(i) for i in num_feats] 142 | cat_feats = [int(i) for i in cat_feats] 143 | 144 | return num_feats, cat_feats 145 | 146 | @staticmethod 147 | def impute_missing_entries(cat_features, data_table, missing_matrix): 148 | """ 149 | Fill categorical missing entries with ? 150 | and numerical entries with the mean of the column. 151 | """ 152 | for col in range(data_table.shape[1]): 153 | # Get missing value locations 154 | curr_col = data_table[:, col] 155 | 156 | if curr_col.dtype == np.object_: 157 | col_missing = np.array( 158 | [True if str(n) == "nan" else False for n in curr_col]) 159 | else: 160 | col_missing = np.isnan(data_table[:, col]) 161 | 162 | # There are missing values 163 | if col_missing.sum() > 0: 164 | # Set in missing matrix (used to avoid using data augmentation 165 | # or predicting on those values 166 | missing_matrix[:, col] = col_missing 167 | 168 | if col in cat_features: 169 | missing_impute_val = '?' 170 | else: 171 | missing_impute_val = np.mean( 172 | data_table[~col_missing, col]) 173 | 174 | data_table[:, col] = np.array([ 175 | missing_impute_val if col_missing[i] else data_table[i, col] 176 | for i in range(data_table.shape[0])]) 177 | 178 | n_missing_values = missing_matrix.sum() 179 | print(f'Detected {n_missing_values} missing values in dataset.') 180 | 181 | return data_table, missing_matrix 182 | 183 | def make_missing(self, p): 184 | N = self.N 185 | D = self.D 186 | 187 | # drawn random indices (excluding the target columns) 188 | target_cols = self.num_target_cols + self.cat_target_cols 189 | D_miss = D - len(target_cols) 190 | 191 | missing = np.zeros((N * D_miss), dtype=np.bool_) 192 | 193 | # draw random indices at which to set True do 194 | idxs = np.random.choice( 195 | a=range(0, N * D_miss), size=int(p * N * D_miss), replace=False) 196 | 197 | # set missing to true at these indices 198 | missing[idxs] = True 199 | 200 | assert missing.sum() == int(p * N * D_miss) 201 | 202 | # reshape to original shape 203 | missing = missing.reshape(N, D_miss) 204 | 205 | # add back target columns 206 | missing_complete = missing 207 | 208 | for col in target_cols: 209 | missing_complete = np.concatenate( 210 | [missing_complete[:, :col], 211 | np.zeros((N, 1), dtype=np.bool_), 212 | missing_complete[:, col:]], 213 | axis=1 214 | ) 215 | 216 | if len(target_cols) > 1: 217 | raise NotImplementedError( 218 | 'Missing matrix generation should work for multiple ' 219 | 'target cols as well, but this has not been tested. 
' 220 | 'Please test first.') 221 | 222 | return missing_complete 223 | -------------------------------------------------------------------------------- /npt/datasets/boston_housing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.datasets import load_boston 3 | 4 | from npt.datasets.base import BaseDataset 5 | 6 | 7 | class BostonHousingDataset(BaseDataset): 8 | def __init__(self, c): 9 | super().__init__( 10 | fixed_test_set_index=None) 11 | self.c = c 12 | 13 | def load(self): 14 | """ 15 | Regression dataset. 16 | 17 | Target in last column. 18 | 506 rows. 19 | 13 attributes. 20 | 21 | Feature types (Copied from sklearn description): 22 | idx name type num unique 23 | 0 - CRIM NUM 504 per capita crime rate by town 24 | 1 - ZN NUM 26 proportion of residential land zoned for 25 | lots over 25,000 sq.ft. (only has 26 unique 26 | values) 27 | 2 - INDUS NUM 76 proportion of non-retail business acres per 28 | town 29 | 3 - CHAS CAT 2 Charles River dummy variable (= 1 if tract 30 | bounds river; 0 otherwise) 31 | 4 - NOX NUM 81 nitric oxides concentration (parts per 10 32 | million) 33 | 5 - RM NUM 446 average number of rooms per dwelling 34 | 6 - AGE NUM 356 proportion of owner-occupied units built 35 | prior to 1940 36 | 7 - DIS NUM 412 weighted distances to five Boston 37 | employment centres 38 | 8 - RAD CAT 9 index of accessibility to radial highways 39 | 9 - TAX NUM 66 full-value property-tax rate per $10,000 40 | 10 - PTRATIO NUM 46 pupil-teacher ratio by town 41 | 11 - B NUM 357 1000(Bk - 0.63)^2 where Bk is the 42 | proportion of black individuals by town 43 | 12 - LSTAT NUM 455 % lower status of the population 44 | (T) 13 - MEDV NUM Median value of owner-occupied homes in $1000's 45 | 46 | Mean, std value of target column: 22.532806324110677, 9.188011545278203 47 | 48 | --> Just guessing the mean value on standardized data will always 49 | give you an MSE of 1. 
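(Reason: once the target column is standardized to zero mean and unit variance,
a constant prediction at the mean has expected squared error equal to that unit
variance, hence MSE = 1.)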
50 | 51 | """ 52 | 53 | x, y = load_boston(return_X_y=True) 54 | 55 | self.data_table = np.concatenate([x, y[:, np.newaxis]], 1) 56 | 57 | self.N = self.data_table.shape[0] 58 | self.D = self.data_table.shape[1] 59 | self.cat_features = [3, 8] 60 | self.num_features = [0, 1, 2, 4, 5, 6, 7, 9, 10, 11, 12, 13] 61 | 62 | # Target col is the last feature (numerical, "median housing value") 63 | self.num_target_cols = [self.D - 1] 64 | self.cat_target_cols = [] 65 | 66 | # TODO: add missing entries to sanity check 67 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 68 | 69 | self.is_data_loaded = True 70 | 71 | # No tmp files left by this dwnld method 72 | self.tmp_file_or_dir_names = [] 73 | -------------------------------------------------------------------------------- /npt/datasets/breast_cancer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from npt.datasets.base import BaseDataset 7 | from npt.utils.data_loading_utils import download 8 | 9 | 10 | class BreastCancerClassificationDataset(BaseDataset): 11 | def __init__(self, c): 12 | super(BreastCancerClassificationDataset, self).__init__( 13 | fixed_test_set_index=None) 14 | self.c = c 15 | 16 | def load(self): 17 | (self.data_table, self.N, self.D, self.cat_features, self.num_features, 18 | self.missing_matrix) = load_and_preprocess_breast_cancer_dataset( 19 | self.c) 20 | 21 | # For breast cancer, target index is the first column 22 | self.num_target_cols = [] 23 | self.cat_target_cols = [0] 24 | 25 | self.is_data_loaded = True 26 | self.tmp_file_or_dir_names = ['wdbc.data'] 27 | 28 | # overwrite missing 29 | if (p := self.c.exp_artificial_missing) > 0: 30 | self.missing_matrix = self.make_missing(p) 31 | # this is not strictly necessary with our code, but safeguards 32 | # against bugs 33 | # TODO: maybe replace with np.nan 34 | self.data_table[self.missing_matrix] = 0 35 | 36 | 37 | def load_and_preprocess_breast_cancer_dataset(c): 38 | """Class imbalance is [357, 212].""" 39 | path = Path(c.data_path) / c.data_set 40 | data_name = 'wdbc.data' 41 | 42 | file = path / data_name 43 | 44 | if not file.is_file(): 45 | # download if does not exist 46 | url = ( 47 | 'https://archive.ics.uci.edu/ml/' 48 | + 'machine-learning-databases/' 49 | + 'breast-cancer-wisconsin/' 50 | + data_name) 51 | 52 | download(file, url) 53 | 54 | # Read dataset 55 | data_table = pd.read_csv(file, header=None).to_numpy() 56 | 57 | # Drop id col 58 | data_table = data_table[:, 1:] 59 | 60 | N = data_table.shape[0] 61 | D = data_table.shape[1] 62 | 63 | if c.exp_smoke_test: 64 | print('Running smoke test -- building simple breast cancer dataset.') 65 | dm = data_table[data_table[:, 0] == 'M'][:8, :5] 66 | db = data_table[data_table[:, 0] == 'B'][:8, :5] 67 | data_table = np.concatenate([dm, db], 0) 68 | N = data_table.shape[0] 69 | D = data_table.shape[1] 70 | 71 | # Speculate some spurious missing features 72 | missing_matrix = np.zeros((N, D)) 73 | missing_matrix[0, 1] = 1 74 | missing_matrix[2, 2] = 1 75 | missing_matrix = missing_matrix.astype(dtype=np.bool_) 76 | else: 77 | missing_matrix = np.zeros((N, D)) 78 | missing_matrix = missing_matrix.astype(dtype=np.bool_) 79 | 80 | cat_features = [0] 81 | num_features = list(range(1, D)) 82 | return data_table, N, D, cat_features, num_features, missing_matrix 83 | 84 | 85 | class BreastCancerDebugClassificationDataset(BaseDataset): 86 | """For debugging row interactions. 
Add two columns for index tracking.""" 87 | def __init__(self, c): 88 | super(BreastCancerClassificationDataset, self).__init__( 89 | fixed_test_set_index=None) 90 | self.c = c 91 | 92 | def load(self): 93 | raise 94 | # need to augment table and features and and and 95 | # (to contain the index rows!! can already write index rows as long 96 | # as permutation is random!) 97 | 98 | (self.data_table, self.N, self.D, self.cat_features, self.num_features, 99 | self.missing_matrix) = load_and_preprocess_breast_cancer_dataset( 100 | self.c) 101 | 102 | # For breast cancer, target index is the first column 103 | self.num_target_cols = [] 104 | self.cat_target_cols = [0] 105 | 106 | self.is_data_loaded = True 107 | self.tmp_file_or_dir_names = ['wdbc.data'] 108 | 109 | # overwrite missing 110 | if (p := self.c.exp_artificial_missing) > 0: 111 | self.missing_matrix = self.make_missing(p) 112 | # this is not strictly necessary with our code, but safeguards 113 | # against bugs 114 | # TODO: maybe replace with np.nan 115 | self.data_table[self.missing_matrix] = 0 116 | -------------------------------------------------------------------------------- /npt/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from npt.datasets.base import BaseDataset 4 | 5 | 6 | class CIFAR10Dataset(BaseDataset): 7 | def __init__(self, c): 8 | super().__init__( 9 | fixed_test_set_index=None) 10 | self.c = c 11 | 12 | def load(self): 13 | """ 14 | Classification dataset. 15 | 16 | Target in last column. 17 | 60 000 rows. 18 | 3072 attributes. 19 | 1 class (10 class labels) 20 | 21 | Author: Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton 22 | Source: [University of Toronto] 23 | (https://www.cs.toronto.edu/~kriz/cifar.html) - 2009 24 | Alex Krizhevsky (2009) Learning Multiple Layers of Features from 25 | Tiny Images, Tech Report. 26 | 27 | CIFAR-10 is a labeled subset of the [80 million tiny images dataset] 28 | (http://groups.csail.mit.edu/vision/TinyImages/). 29 | 30 | It (originally) consists 32x32 color images representing 31 | 10 classes of objects: 32 | 0. airplane 33 | 1. automobile 34 | 2. bird 35 | 3. cat 36 | 4. deer 37 | 5. dog 38 | 6. frog 39 | 7. horse 40 | 8. ship 41 | 9. truck 42 | 43 | CIFAR-10 contains 6000 images per class. 44 | Similar to the original train-test split, which randomly divided 45 | these classes into 5000 train and 1000 test images per class, 46 | we do 5-fold class-balanced cross-validation by default. 47 | 48 | The classes are completely mutually exclusive. 49 | There is no overlap between automobiles and trucks. 50 | "Automobile" includes sedans, SUVs, things of that sort. 51 | "Truck" includes only big trucks. Neither includes pickup trucks. 52 | 53 | ### Attribute description 54 | Each instance represents a 32x32 colour image as a 3072-value array. 55 | The first 1024 entries contain the red channel values, the next 56 | 1024 the green, and the final 1024 the blue. The image is stored 57 | in row-major order, so that the first 32 entries of the array are 58 | the red channel values of the first row of the image. 59 | 60 | The labels are encoded as integers in the range 0-9, 61 | corresponding to the numbered classes listed above. 
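(In the flattened table used here, each row therefore has D = 3073 columns:
the 3072 pixel values followed by the class label in the final column.)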
62 | """ 63 | self.N = 60000 64 | self.D = 3073 65 | self.cat_features = [self.D - 1] 66 | self.num_features = list(range(0, self.D - 1)) 67 | 68 | # Target col is the last feature 69 | self.num_target_cols = [] 70 | self.cat_target_cols = [self.D - 1] 71 | 72 | # TODO: add missing entries to sanity check 73 | self.missing_matrix = torch.zeros((self.N, self.D), dtype=torch.bool) 74 | self.is_data_loaded = True 75 | 76 | self.input_feature_dims = [1] * 3072 77 | self.input_feature_dims += [10] 78 | -------------------------------------------------------------------------------- /npt/datasets/concrete.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.datasets import fetch_openml 6 | 7 | from npt.datasets.base import BaseDataset 8 | 9 | 10 | class ConcreteDataset(BaseDataset): 11 | def __init__(self, c): 12 | super().__init__( 13 | fixed_test_set_index=None) 14 | self.c = c 15 | 16 | def load(self): 17 | """ 18 | Regression dataset. 19 | 20 | Target in last column. 21 | 1030 rows. 22 | 8 attributes. 23 | 1 target (Residuary.resistance) (256 unique numbers) 24 | 25 | Features n_unique encode as 26 | 0 Cement 278 NUM 27 | 1 Blast Furnace Slag 185 NUM 28 | 2 Fly Ash 156 NUM 29 | 3 Water 195 NUM 30 | 4 Superplasticizer 111 NUM 31 | 5 Coarse Aggregate 284 NUM 32 | 6 Fine Aggregate 302 NUM 33 | 7 Age 14 NUM 34 | 8 Concrete compressive strength 845 NUM 35 | 36 | Std of Target Col 16.697630409134263. 37 | """ 38 | 39 | # Load data from https://www.openml.org/d/4353 40 | data_home = Path(self.c.data_path) / self.c.data_set 41 | x, _ = fetch_openml( 42 | 'Concrete_data', 43 | version=1, return_X_y=True, data_home=data_home) 44 | 45 | if isinstance(x, np.ndarray): 46 | pass 47 | elif isinstance(x, pd.DataFrame): 48 | x = x.to_numpy() 49 | 50 | self.data_table = x 51 | self.N = self.data_table.shape[0] 52 | self.D = self.data_table.shape[1] 53 | 54 | # Target col is the last feature 55 | self.num_target_cols = [self.D - 1] 56 | self.cat_target_cols = [] 57 | 58 | self.num_features = [0, 1, 2, 3, 4, 5, 6, 7, 8] 59 | self.cat_features = [] 60 | 61 | # TODO: add missing entries to sanity check 62 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 63 | self.is_data_loaded = True 64 | self.tmp_file_or_dir_names = ['openml'] -------------------------------------------------------------------------------- /npt/datasets/debug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from npt.datasets.base import BaseDataset 4 | 5 | 6 | class DebugDataset(BaseDataset): 7 | """Used for debugging of row interactions. 8 | Will dynamically overwrite stuff. In each batch. 9 | But need to set some fake data here such that metadata gets written 10 | correctly. 11 | """ 12 | def __init__(self, c): 13 | super().__init__(fixed_test_set_index=-5) 14 | self.c = c 15 | 16 | def load(self): 17 | """Debug dataset. Has four columns. 18 | 19 | The first are copies from one throw of a random dice. 20 | I.e. the entire column contains the same data. 21 | Model has to masked out values by reading off dice value from other 22 | rows. (Only makes sense in semi-supervised). 23 | The other three of which are just random data we don't care about. 
24 | """ 25 | 26 | # Load data from https://www.openml.org/d/4535 27 | 28 | self.N = 20 29 | self.D = 6 30 | 31 | data = np.zeros((self.N, self.D)) 32 | # populate with all possible choices 33 | for i in range(self.D): 34 | data[:18, i] = np.repeat(range(6), 3) 35 | 36 | self.data_table = data.astype('int') 37 | 38 | self.num_target_cols = [] 39 | self.cat_target_cols = [0] 40 | 41 | self.num_features = [] 42 | self.cat_features = list(range(self.D)) 43 | 44 | # TODO: add missing entries to sanity check 45 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 46 | self.is_data_loaded = True 47 | -------------------------------------------------------------------------------- /npt/datasets/forest_cover.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import patoolib 6 | 7 | from npt.datasets.base import BaseDataset 8 | from npt.utils.data_loading_utils import download 9 | 10 | 11 | class ForestCoverClassificationDataset(BaseDataset): 12 | def __init__(self, c): 13 | super(ForestCoverClassificationDataset, self).__init__( 14 | fixed_test_set_index=None) 15 | self.c = c 16 | 17 | def load(self): 18 | (self.data_table, self.N, self.D, self.cat_features, self.num_features, 19 | self.missing_matrix) = load_and_preprocess_forest_cover_dataset( 20 | self.c) 21 | 22 | # Target col is the last feature -- multiclass classification 23 | self.num_target_cols = [] 24 | self.cat_target_cols = [self.D - 1] 25 | 26 | self.tmp_file_or_dir_names = ['covtype.data', 'covtype.data.gz'] 27 | self.is_data_loaded = True 28 | 29 | 30 | def load_and_preprocess_forest_cover_dataset(c): 31 | """ForestCoverDataset. 32 | 33 | Used in TabNet. 34 | 35 | Multi-class classification. 36 | Target in last column. (7 different categories of forest cover.) 37 | 581,012 rows. 38 | Each row has 54 features (55th column is the target). 39 | 40 | Feature types: 41 | 10 continuous features, 4 binary "wilderness area" features, 42 | 40 binary "soil type" variables. 43 | 44 | Classical usage: 45 | first 11,340 records used for training data subset 46 | next 3,780 records used for validation data subset 47 | last 565,892 records used for testing data subset 48 | 49 | WE DON'T DO THE ABOVE, following the TabNet and XGBoost baselines 50 | Just do (0.8, 0.2) (train, test) split. 51 | 52 | Class imbalance: Yes. 53 | [211840, 283301, 35754, 2747, 9493, 17367, 20510] 54 | Guessing performance is 0.488 percent accuracy. 55 | Getting the two most frequent classes gives 0.729 percent accuracy. 56 | Top three most frequent gets 0.914. 57 | 58 | """ 59 | 60 | path = Path(c.data_path) / c.data_set 61 | data_name = 'covtype.data' 62 | file = path / data_name 63 | 64 | if not file.is_file(): 65 | # download if does not exist 66 | download_name = 'covtype.data.gz' 67 | url = ( 68 | 'https://archive.ics.uci.edu/ml/' 69 | + 'machine-learning-databases/covtype/' 70 | + download_name 71 | ) 72 | download_file = path / download_name 73 | download(download_file, url) 74 | # Forest cover comes compressed. 
75 | patoolib.extract_archive(str(download_file), outdir=str(path)) 76 | 77 | data_table = pd.read_csv(file, header=None).to_numpy() 78 | 79 | # return 80 | if c.exp_smoke_test: 81 | print( 82 | 'Running smoke test -- building simple forest cover dataset.') 83 | class_datasets = [] 84 | for class_type in [1, 2, 3, 4, 5, 6, 7]: 85 | class_datasets.append([data_table[ 86 | data_table[:, -1] == class_type][0]]) 87 | 88 | data_table = np.concatenate(class_datasets, axis=0) 89 | 90 | N = data_table.shape[0] 91 | D = data_table.shape[1] 92 | num_features = list(range(10)) 93 | cat_features = list(range(10, D)) 94 | 95 | # TODO: add missing entries to sanity check 96 | missing_matrix = np.zeros((N, D), dtype=np.bool_) 97 | 98 | return data_table, N, D, cat_features, num_features, missing_matrix 99 | -------------------------------------------------------------------------------- /npt/datasets/higgs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import patoolib 6 | 7 | from npt.datasets.base import BaseDataset 8 | from npt.utils.data_loading_utils import download 9 | from os import remove 10 | 11 | 12 | class HiggsClassificationDataset(BaseDataset): 13 | def __init__(self, c): 14 | super(HiggsClassificationDataset, self).__init__( 15 | fixed_test_set_index=-500000) # Test set: last 500,000 examples 16 | self.c = c 17 | 18 | def load(self): 19 | (self.data_table, self.N, self.D, self.cat_features, self.num_features, 20 | self.missing_matrix) = load_and_preprocess_higgs_dataset( 21 | self.c) 22 | 23 | self.num_target_cols = [] 24 | self.cat_target_cols = [0] # Binary classification 25 | self.is_data_loaded = True 26 | self.tmp_file_names = ['HIGGS.csv'] 27 | 28 | 29 | def load_and_preprocess_higgs_dataset(c): 30 | """HIGGS dataset as used by NODE. 31 | 32 | Binary classification. 33 | First column is categorical target column, 34 | all remaining 28 columns are continuous features. 35 | 11.000.000 rows in total. 36 | The last 500,000 rows are commonly used as a test set. 37 | Separate training and test set. 38 | 39 | No class imbalance (array([0., 1.]), array([5170877, 5829123])). 40 | """ 41 | path = Path(c.data_path) / c.data_set 42 | data_name = 'HIGGS.csv' 43 | file = path / data_name 44 | 45 | # For breast cancer, target index is the first column 46 | if not file.is_file(): 47 | # download if does not exist 48 | download_name = 'HIGGS.csv.gz' 49 | url = ( 50 | 'https://archive.ics.uci.edu/ml/' 51 | + 'machine-learning-databases/00280/' 52 | + download_name 53 | ) 54 | download_file = path / download_name 55 | download(download_file, url) 56 | 57 | # Higgs comes compressed. 58 | print('Decompressing...') 59 | patoolib.extract_archive(str(download_file), outdir=str(path)) 60 | print('... 
done.') 61 | 62 | # Delete the compressed file (Higgs is very large) 63 | remove(download_file) 64 | print(f'Removed compressed file {download_name}.') 65 | 66 | data_table = pd.read_csv(file, header=None).to_numpy() 67 | N, D = data_table.shape 68 | cat_features = [0] 69 | num_features = list(range(1, D)) 70 | missing_matrix = np.zeros((N, D), dtype=np.bool_) 71 | 72 | return data_table, N, D, cat_features, num_features, missing_matrix 73 | -------------------------------------------------------------------------------- /npt/datasets/image_utils.py: -------------------------------------------------------------------------------- 1 | def load_image_dataloaders(c): 2 | batch_size = c.exp_batch_size 3 | from npt.utils.image_loading_utils import get_dataloaders 4 | 5 | if c.data_set in ['cifar10']: 6 | # For CIFAR, let's just use 10% of the training set for validation. 7 | # That is, 10% of 50,000 rows = 5,000 rows 8 | val_perc = 0.10 9 | else: 10 | raise NotImplementedError 11 | 12 | _, trainloader, validloader, testloader = get_dataloaders( 13 | c.data_set, batch=batch_size, dataroot=f'{c.data_path}/{c.data_set}', 14 | c=c, split=val_perc, split_idx=0) 15 | data_dict = { 16 | 'trainloader': trainloader, 17 | 'validloader': validloader, 18 | 'testloader': testloader} 19 | 20 | return data_dict 21 | -------------------------------------------------------------------------------- /npt/datasets/income.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.datasets import fetch_openml 6 | 7 | from npt.datasets.base import BaseDataset 8 | 9 | 10 | class IncomeDataset(BaseDataset): 11 | def __init__(self, c): 12 | super().__init__( 13 | fixed_test_set_index=-99762) 14 | self.c = c 15 | 16 | def load(self): 17 | """KDD Income Dataset 18 | 19 | Possibly used in VIME and TabNet. 20 | 21 | There are multiple datasets called income. 22 | https://archive.ics.uci.edu/ml/datasets/census+income 23 | https://archive.ics.uci.edu/ml/datasets/Census-Income+%28KDD%29 24 | The KDD One is significantly larger than the other one. 25 | 26 | We will take KDD one. Both TabNet and VIME are not super explicit about 27 | which dataset they use. 28 | TabNet cite Oza et al "Online Bagging and Boosting", which use the 29 | bigger one. So we will start with that. 30 | (But there is no full TabNet Code to confirm.) 31 | 32 | Binary classification. 33 | 34 | Target in last column. 35 | 299.285 rows. 36 | 42 attributes. Use get_num_cat_auto to assign. 
37 | 1 target 38 | """ 39 | 40 | # Load data from https://www.openml.org/d/4535 41 | data_home = Path(self.c.data_path) / self.c.data_set 42 | data = fetch_openml('Census-income', version=1, data_home=data_home) 43 | 44 | # target in 'data' 45 | self.data_table = data['data'] 46 | 47 | if isinstance(self.data_table, np.ndarray): 48 | pass 49 | elif isinstance(self.data_table, pd.DataFrame): 50 | self.data_table = self.data_table.to_numpy() 51 | 52 | self.N = self.data_table.shape[0] 53 | self.D = self.data_table.shape[1] 54 | 55 | # Target col is the last feature 56 | # last column is target (V42) 57 | # (binary classification, if income > or < 50k) 58 | self.num_target_cols = [] 59 | self.cat_target_cols = [self.D - 1] 60 | 61 | self.num_features, self.cat_features = BaseDataset.get_num_cat_auto( 62 | self.data_table, cutoff=55) 63 | print('income num cat features') 64 | print(len(self.num_features)) 65 | print(len(self.cat_features)) 66 | 67 | # TODO: add missing entries to sanity check 68 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 69 | self.is_data_loaded = True 70 | self.tmp_file_or_dir_names = ['openml'] -------------------------------------------------------------------------------- /npt/datasets/kick.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.datasets import fetch_openml 6 | 7 | from npt.datasets.base import BaseDataset 8 | 9 | 10 | class KickDataset(BaseDataset): 11 | def __init__(self, c): 12 | super().__init__( 13 | fixed_test_set_index=None) 14 | self.c = c 15 | 16 | def load(self): 17 | """ 18 | The challenge of this Kaggle competition is to predict if the car 19 | purchased at the Auction is a Kick (bad buy). 20 | 21 | Accessed through OpenML 22 | https://www.openml.org/d/41162 23 | 24 | 72983 rows 25 | 32 features, 1 target column. 26 | 27 | Binary classification. 28 | 29 | Target in last column. 
30 | 31 | Features n_unique encode as 32 | 0 PurchDate 517 NUM 33 | 1 Auction 3 CAT 34 | 2 VehYear 10 NUM 35 | 3 VehicleAge 10 NUM 36 | 4 Make 33 CAT 37 | 5 Model 1063 CAT 38 | 6 Trim 134 CAT 39 | 7 SubModel 863 CAT 40 | 8 Color 16 CAT 41 | 9 Transmission 3 CAT 42 | 10 WheelTypeID 4 CAT 43 | 11 WheelType 3 CAT 44 | 12 VehOdo 39947 NUM 45 | 13 Nationality 4 CAT 46 | 14 Size 12 CAT 47 | 15 TopThreeAmericanName 4 CAT 48 | 16 MMRAcquisitionAuctionAveragePrice 10342 NUM 49 | 17 MMRAcquisitionAuctionCleanPrice 11379 NUM 50 | 18 MMRAcquisitionRetailAveragePrice 12725 NUM 51 | 19 MMRAcquisitonRetailCleanPrice 13456 NUM 52 | 20 MMRCurrentAuctionAveragePrice 10315 NUM 53 | 21 MMRCurrentAuctionCleanPrice 11265 NUM 54 | 22 MMRCurrentRetailAveragePrice 12493 NUM 55 | 23 MMRCurrentRetailCleanPrice 13192 NUM 56 | 24 PRIMEUNIT 2 CAT 57 | 25 AUCGUART 2 CAT 58 | 26 BYRNO 74 CAT 59 | 27 VNZIP1 153 CAT 60 | 28 VNST 37 CAT 61 | 29 VehBCost 2010 NUM 62 | 30 IsOnlineSale 2 CAT 63 | 31 WarrantyCost 281 NUM 64 | """ 65 | 66 | # Load data from https://www.openml.org/d/4353 67 | data_home = Path(self.c.data_path) / self.c.data_set 68 | x, y = fetch_openml( 69 | 'kick', 70 | version=1, return_X_y=True, data_home=data_home) 71 | 72 | if isinstance(x, np.ndarray): 73 | pass 74 | elif isinstance(x, pd.DataFrame): 75 | x = x.to_numpy() 76 | 77 | x = np.concatenate((x, np.expand_dims(y, -1)), axis=1) 78 | print(x.shape) 79 | 80 | self.data_table = x 81 | self.N = self.data_table.shape[0] 82 | self.D = self.data_table.shape[1] 83 | 84 | # Target col is the last feature 85 | self.num_target_cols = [] 86 | self.cat_target_cols = [x.shape[1] - 1] 87 | 88 | self.num_features = [ 89 | 0, 2, 3, 12, 16, 17, 18, 19, 20, 21, 22, 23, 29, 31] 90 | self.cat_features = [ 91 | 1, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 24, 25, 26, 27, 28, 30, 32] 92 | 93 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 94 | self.data_table, self.missing_matrix = self.impute_missing_entries( 95 | cat_features=self.cat_features, data_table=self.data_table, 96 | missing_matrix=self.missing_matrix) 97 | 98 | self.is_data_loaded = True 99 | self.tmp_file_or_dir_names = ['openml'] 100 | -------------------------------------------------------------------------------- /npt/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from sklearn.datasets import fetch_openml 5 | 6 | from npt.datasets.base import BaseDataset 7 | 8 | 9 | class MNISTDataset(BaseDataset): 10 | def __init__(self, c): 11 | super().__init__( 12 | fixed_test_set_index=-10000) 13 | self.c = c 14 | 15 | def load(self): 16 | """ 17 | Classification dataset. 18 | 19 | Target in last column. 20 | 70 000 rows. 21 | 784 attributes. 22 | 1 class (10 class labels) 23 | 24 | Class imbalance: Not really. 
25 | array([6903, 7877, 6990, 7141, 6824, 6313, 6876, 7293, 6825, 6958]) 26 | 27 | """ 28 | 29 | # Load data from https://www.openml.org/d/554 30 | data_home = Path(self.c.data_path) / self.c.data_set 31 | x, y = fetch_openml( 32 | 'mnist_784', version=1, return_X_y=True, data_home=data_home) 33 | 34 | self.data_table = np.hstack((x, np.expand_dims(y, -1))) 35 | 36 | self.N = self.data_table.shape[0] 37 | self.D = self.data_table.shape[1] 38 | self.cat_features = [self.D-1] 39 | self.num_features = list(range(0, self.D-1)) 40 | 41 | # Target col is the last feature 42 | self.num_target_cols = [] 43 | self.cat_target_cols = [self.D - 1] 44 | 45 | # TODO: add missing entries to sanity check 46 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 47 | self.is_data_loaded = True 48 | self.tmp_file_or_dir_names = ['openml'] 49 | -------------------------------------------------------------------------------- /npt/datasets/poker_hand.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from npt.datasets.base import BaseDataset 7 | from npt.utils.data_loading_utils import download 8 | 9 | 10 | class PokerHandDataset(BaseDataset): 11 | def __init__(self, c): 12 | super(PokerHandDataset, self).__init__( 13 | fixed_test_set_index=None) # Set when load is called 14 | self.c = c 15 | 16 | def load(self): 17 | """Poker Hand data set as used by TabNet. 18 | 19 | 10-fold classification. (What kind of Poker Hand?) 20 | Target in last column. 21 | 1025010 rows. 22 | Each row has 10 features which describe the poker hand. 23 | 5 are numerical (the rank), 5 categorical (the suit). 24 | Last column is the label. 25 | Separate training and test set. 26 | 27 | This dataset has extreme class imbalance 28 | [513702, 433097, 48828, 21634, 3978, 2050, 1460, 236, 17, 8] 29 | such that guessing performance is 0.5. 30 | 31 | This also means that TabNet can get 99% performance by just getting 32 | the first 4 predictions right. 33 | 34 | # NOTE: UCI lists 'Ranks of Cards' as numerical feature 35 | # NOTE: but it seems categorical to me. 36 | 37 | """ 38 | path = Path(self.c.data_path) / self.c.data_set 39 | 40 | data_names = ['poker-hand-training-true.data', 41 | 'poker-hand-testing.data'] 42 | files = [path / data_name for data_name in data_names] 43 | 44 | files_exist = [file.is_file() for file in files] 45 | 46 | if not all(files_exist): 47 | url = ( 48 | 'https://archive.ics.uci.edu/ml/' 49 | + 'machine-learning-databases/poker/' 50 | ) 51 | 52 | urls = [url + data_name for data_name in data_names] 53 | 54 | download(files, urls) 55 | 56 | data_tables = [ 57 | pd.read_csv(file, header=None).to_numpy() for file in files] 58 | self.fixed_test_set_index = -data_tables[1].shape[0] 59 | self.data_table = np.concatenate(data_tables, 0) 60 | 61 | self.N, self.D = self.data_table.shape 62 | 63 | self.num_target_cols = [] 64 | self.cat_target_cols = [self.D - 1] 65 | 66 | # It turns out that not all features in Poker Hands are categorical -- 67 | # i.e., we can encode the rank of a card with a numerical variable. 
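        # For reference (our hedged reading of the UCI column layout): columns
        # alternate suit/rank pairs S1, C1, ..., S5, C5 with the hand class
        # last, so even indices 0, 2, 4, 6, 8 are suits in {1..4}, odd indices
        # 1, 3, 5, 7, 9 are ranks in {1..13}, and index 10 is the target;
        # this is what the assignments below encode.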
68 | # self.cat_features = list(range(self.D)) 69 | # self.num_features = [] 70 | 71 | self.cat_features = [0, 2, 4, 6, 8, 10] 72 | self.num_features = [1, 3, 5, 7, 9] 73 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 74 | 75 | self.is_data_loaded = True 76 | self.tmp_file_or_dir_names = data_names 77 | -------------------------------------------------------------------------------- /npt/datasets/protein.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from npt.datasets.base import BaseDataset 7 | from npt.utils.data_loading_utils import download 8 | 9 | 10 | def load_protein(c, data_name): 11 | """Protein Dataset 12 | 13 | Used in Gal et al., 'Dropout as Bayesian Approximation'. 14 | 15 | Physicochemical Properties of Protein Tertiary Structure Data Set 16 | 17 | Regression Dataset 18 | Number of Rows 45730 19 | Number of Attributes 9 20 | 21 | RMSD-Size of the residue. 22 | F1 - Total surface area. 23 | F2 - Non polar exposed area. 24 | F3 - Fractional area of exposed non polar residue. 25 | F4 - Fractional area of exposed non polar part of residue. 26 | F5 - Molecular mass weighted exposed area. 27 | F6 - Average deviation from standard exposed area of residue. 28 | F7 - Euclidian distance. 29 | F8 - Secondary structure penalty. 30 | F9 - Spatial Distribution constraints (N,K Value). 31 | 32 | There may be a fixed test set as suggested by 'more-documentation. 33 | names' but it does not seem like Hernandez-Lobato et al. (whose setup 34 | Gal et al. repeat), respect that. 35 | 36 | https://www.kaggle.com/c/pcon-ml seems to suggest that RMSD is target. 37 | 38 | Target Col has std of 6.118244779017878. 39 | """ 40 | path = Path(c.data_path) / c.data_set 41 | 42 | file = path / data_name 43 | 44 | if not file.is_file(): 45 | # download if does not exist 46 | url = ( 47 | 'https://archive.ics.uci.edu/ml/' 48 | 'machine-learning-databases/00265/' 49 | + data_name 50 | ) 51 | download_file = path / data_name 52 | download(download_file, url) 53 | 54 | return pd.read_csv(file).to_numpy() 55 | 56 | 57 | class ProteinDataset(BaseDataset): 58 | def __init__(self, c): 59 | super().__init__( 60 | fixed_test_set_index=None) 61 | 62 | self.c = c 63 | 64 | def load(self): 65 | data_name = 'CASP.csv' 66 | self.data_table = load_protein(self.c, data_name) 67 | self.N, self.D = self.data_table.shape 68 | self.num_target_cols = [0] 69 | self.cat_target_cols = [] 70 | 71 | # have checked this with get_num_cat_auto as well 72 | self.cat_features = [] 73 | self.num_features = list(range(0, self.D)) 74 | 75 | if (p := self.c.exp_artificial_missing) > 0: 76 | self.missing_matrix = self.make_missing(p) 77 | # this is not strictly necessary with our code, but safeguards 78 | # against bugs 79 | # TODO: maybe replace with np.nan 80 | # self.data_table[self.missing_matrix] = 0 81 | 82 | else: 83 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 84 | 85 | self.is_data_loaded = True 86 | self.tmp_file_or_dir_names = [data_name] 87 | 88 | def make_missing(self, p): 89 | N = self.N 90 | D = self.D 91 | 92 | # drawn random indices (excluding the target columns) 93 | target_cols = self.num_target_cols + self.cat_target_cols 94 | D_miss = D - len(target_cols) 95 | 96 | missing = np.zeros((N * D_miss), dtype=np.bool_) 97 | 98 | # draw random indices at which to set True do 99 | idxs = np.random.choice( 100 | a=range(0, N * D_miss), size=int(p * N * D_miss), 
replace=False) 101 | 102 | # set missing to true at these indices 103 | missing[idxs] = True 104 | 105 | assert missing.sum() == int(p * N * D_miss) 106 | 107 | # reshape to original shape 108 | missing = missing.reshape(N, D_miss) 109 | 110 | # add back target columns 111 | missing_complete = missing 112 | 113 | for col in target_cols: 114 | missing_complete = np.concatenate( 115 | [missing_complete[:, :col], 116 | np.zeros((N, 1), dtype=np.bool_), 117 | missing_complete[:, col:]], 118 | axis=1 119 | ) 120 | 121 | if len(target_cols) > 1: 122 | raise NotImplementedError( 123 | 'Missing matrix generation should work for multiple ' 124 | 'target cols as well, but this has not been tested. ' 125 | 'Please test first.') 126 | 127 | print(missing_complete.shape) 128 | return missing_complete 129 | -------------------------------------------------------------------------------- /npt/datasets/yacht.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from sklearn.datasets import fetch_openml 5 | 6 | from npt.datasets.base import BaseDataset 7 | 8 | 9 | class YachtDataset(BaseDataset): 10 | def __init__(self, c): 11 | super().__init__( 12 | fixed_test_set_index=None) 13 | self.c = c 14 | 15 | def load(self): 16 | """ 17 | Regression dataset. 18 | 19 | Target in last column. 20 | 308 rows. 21 | 6 attributes. 22 | 1 target (Residuary.resistance) (256 unique numbers) 23 | 24 | Features n_unique encode as 25 | Logitudinal.position 5 CAT 26 | Prismatic.coefficient 10 CAT 27 | Length.displacement.ratio 8 CAT 28 | Beam.draught.ratio 17 CAT 29 | Length.beam.ratio 10 CAT 30 | Froude.number' 14 NUM 31 | 32 | 33 | Std of Target Col 15.135858907655322. 34 | """ 35 | 36 | # Load data from https://www.openml.org/d/554 37 | data_home = Path(self.c.data_path) / self.c.data_set 38 | x, y = fetch_openml( 39 | 'yacht_hydrodynamics', 40 | version=1, return_X_y=True, data_home=data_home) 41 | 42 | self.data_table = np.concatenate([x, y[:, np.newaxis]], 1) 43 | 44 | self.N = self.data_table.shape[0] 45 | self.D = self.data_table.shape[1] 46 | 47 | # Target col is the last feature 48 | self.num_target_cols = [self.D - 1] 49 | self.cat_target_cols = [] 50 | 51 | self.num_features = [self.D - 1] 52 | self.cat_features = list(range(0, self.D - 1)) 53 | 54 | # TODO: add missing entries to sanity check 55 | self.missing_matrix = np.zeros((self.N, self.D), dtype=np.bool_) 56 | self.is_data_loaded = True 57 | self.tmp_file_or_dir_names = ['openml'] 58 | -------------------------------------------------------------------------------- /npt/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributed as dist 4 | import wandb 5 | 6 | from npt.column_encoding_dataset import NPTDataset 7 | from npt.train import Trainer 8 | from npt.utils.model_init_utils import ( 9 | init_model_opt_scaler_from_dataset, setup_ddp_model) 10 | 11 | 12 | def distributed_train_wrapper(gpu, args): 13 | wandb_args = args['wandb_args'] 14 | 15 | if gpu == 0: 16 | wandb_run = wandb.init(**wandb_args) 17 | wandb.config.update(args, allow_val_change=True) 18 | 19 | c = args['c'] 20 | rank = c.mp_nr * c.mp_gpus + gpu 21 | world_size = c.mp_gpus * c.mp_nodes 22 | 23 | dist.init_process_group( 24 | backend='nccl', 25 | init_method='env://', 26 | world_size=world_size, 27 | rank=rank) 28 | torch.manual_seed(c.torch_seed) 29 | np.random.seed(c.np_seed) 30 | 31 | dataset 
= args['dataset'] 32 | torch.cuda.set_device(gpu) 33 | model, optimizer, scaler = init_model_opt_scaler_from_dataset( 34 | dataset=dataset, c=c, device=gpu) 35 | model = setup_ddp_model(model=model, c=c, device=gpu) 36 | 37 | distributed_dataset = NPTDataset(dataset) 38 | dist_args = { 39 | 'world_size': world_size, 40 | 'rank': rank, 41 | 'gpu': gpu} 42 | 43 | trainer = Trainer( 44 | model=model, optimizer=optimizer, scaler=scaler, c=c, 45 | cv_index=0, wandb_run=None, 46 | dataset=dataset, 47 | torch_dataset=distributed_dataset, distributed_args=dist_args) 48 | trainer.train_and_eval() 49 | 50 | if gpu == 0: 51 | wandb_run.finish() -------------------------------------------------------------------------------- /npt/model/__init__.py: -------------------------------------------------------------------------------- 1 | from npt.model.npt import NPTModel -------------------------------------------------------------------------------- /npt/model/image_patcher.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ImagePatcher(nn.Module): 8 | def __init__( 9 | self, dim_hidden, input_feature_dims, c, patcher_type): 10 | super(ImagePatcher, self).__init__() 11 | 12 | self.c = c 13 | 14 | # Global setting, avoids non-patching logic in NPT init 15 | self.model_image_patching = True 16 | 17 | D = len(input_feature_dims) # Flattened size of an image 18 | 19 | # We reduce what the core model sees to a sequence of patches 20 | # (unordered, but with patch index embeddings), until the decoder 21 | self.image_n_patches = self.c.model_image_n_patches 22 | 23 | # Includes the target column 24 | self.num_input_features = self.image_n_patches + 1 25 | 26 | # Options: {'linear'} 27 | self.image_patch_type = self.c.model_image_patch_type 28 | 29 | # Share embedding weights across patches, or separate 30 | self.image_share_embed = self.c.model_image_share_embed 31 | 32 | # e.g. 3 for RGB 33 | self.image_n_channels = self.c.model_image_n_channels 34 | 35 | # If we use BERT augmentation, this must be 2, for 36 | # the continuous pixel intensity and the mask value. 37 | # Otherwise, this should be 1. 38 | self.dim_intensity = 1 + bool(self.c.model_bert_augmentation) 39 | self.dim_target_col = self.c.model_image_n_classes + bool( 40 | self.c.model_bert_augmentation) 41 | 42 | # Exclude target column (we assume it is right concatenated) 43 | image_input_shape = (D - 1, self.dim_intensity) 44 | 45 | # The number of patches must divide the number of pixels 46 | assert image_input_shape[0] % self.image_n_patches == 0 47 | 48 | # This is in raw intensities, i.e. counting each pixel in an 49 | # RGB image thrice 50 | self.patch_size = image_input_shape[0] // self.image_n_patches 51 | 52 | # Compute resizing constants 53 | n_features = len(input_feature_dims) - 1 54 | assert n_features % self.image_n_channels == 0 55 | 56 | if patcher_type == 'linear': 57 | # H = height, note that we are expecting square images for now 58 | # H = height = W = width 59 | flattened_image_size = n_features // self.image_n_channels 60 | self.image_H = int(flattened_image_size ** 0.5) 61 | assert flattened_image_size // self.image_H == self.image_H 62 | 63 | # Get number of rows of patches 64 | n_patches_per_side = self.image_n_patches ** 0.5 65 | assert int(n_patches_per_side) == n_patches_per_side 66 | n_patches_per_side = int(n_patches_per_side) 67 | 68 | # Get length of patches 69 | # (i.e. 
each patch is patch_side_length x patch_side_length) 70 | assert self.image_H % n_patches_per_side == 0 71 | self.patch_side_length = self.image_H // n_patches_per_side 72 | 73 | # ### Embeddings ### 74 | 75 | # Always use a linear out-embedding 76 | if self.image_share_embed: 77 | # Output into the number of intensities in a patch 78 | # (no mask dim needed), applied in a sliding fashion 79 | self.out_feature_embedding = nn.ModuleList([ 80 | nn.Linear(dim_hidden, self.patch_size)]) 81 | else: 82 | # Separate linear embedding for each patch 83 | self.out_feature_embedding = nn.ModuleList([ 84 | nn.Linear(dim_hidden, self.patch_size) 85 | for _ in range(self.image_n_patches)]) 86 | 87 | self.out_target_embedding = nn.Linear( 88 | dim_hidden, c.model_image_n_classes) 89 | 90 | def decode(self, X): 91 | # We receive a tensor of shape (N, n_patches + 1, E) 92 | 93 | # Feature Patch De-Embedding 94 | if self.image_share_embed: 95 | de_embeds = cycle(self.out_feature_embedding) 96 | else: 97 | de_embeds = self.out_feature_embedding 98 | 99 | X_ragged = [] 100 | 101 | # Projects each batched feature patch of shape (N, E) to (N, 102 | for patch_index in range(X.shape[1] - 1): 103 | # X_patch.shape = (N, E) 104 | X_patch = X[:, patch_index, :] 105 | 106 | # de_embed.shape = (E, p) where p = patch size 107 | de_embed = next(de_embeds) 108 | 109 | # X_de_embed.shape = (N, p) 110 | X_de_embed = de_embed(X_patch) 111 | 112 | # Split into p columns of shape (N, 1) 113 | X_de_embed = torch.split(X_de_embed, 1, dim=1) 114 | X_ragged += X_de_embed 115 | 116 | # Append projection of target column 117 | X_ragged.append(self.out_target_embedding(X[:, -1, :])) 118 | 119 | return X_ragged 120 | 121 | def get_npt_attrs(self): 122 | """Send a few key attributes back to the main model.""" 123 | return {'num_input_features': self.num_input_features, 124 | 'image_n_patches': self.image_n_patches, 125 | 'patch_size': self.patch_size} 126 | 127 | def preprocess_flattened_image(self, X_ragged): 128 | """ 129 | Prior to applying the Linear transforms, we wish to reshape 130 | our features, which constitute the image: 131 | * D = total number of columns (including the target) 132 | (N, D - 1, dim_intensity) 133 | where dim_intensity is 2 if we are using masking, 1 otherwise 134 | to (N, (D - 1) // n_channels, dim_intensity * n_channels) 135 | 136 | This is necessary because, e.g., CIFAR-10 flattens images to be of 137 | format 1024 R, 1024 G, 1024 B. We must reshape to make sure 138 | the patching has the correct receptive fields. 139 | 140 | Returns: 141 | Reshaped X_features, X_target column 142 | """ 143 | # Shape (N, D - 1, dim_intensity) 144 | # where dim_intensity = 2 if we have continuous pixel intensity + mask 145 | # or 1 if we just have the pixel intensity (no BERT augmentation mask) 146 | X_features = torch.stack(X_ragged[:-1], 1) 147 | 148 | # Reshape to (N, (D - 1) // n_channels, dim_intensity * n_channels) 149 | X_features = torch.reshape( 150 | X_features, 151 | (X_features.size(0), 152 | X_features.size(1) // self.image_n_channels, 153 | self.dim_intensity * self.image_n_channels)) 154 | 155 | # Shape (N, 1, H_j) where H_j = num_categories + bool(is_mask) 156 | # (e.g. 
2, for image regression with BERT augmentation) 157 | X_target = X_ragged[-1] 158 | 159 | return X_features, X_target 160 | 161 | 162 | class LinearImagePatcher(ImagePatcher): 163 | def __init__(self, input_feature_dims, dim_hidden, c): 164 | super(LinearImagePatcher, self).__init__( 165 | dim_hidden, input_feature_dims, c, patcher_type='linear') 166 | 167 | self.patch_n_pixels = self.patch_side_length * self.patch_side_length 168 | pixel_input_dims = self.dim_intensity * self.image_n_channels 169 | 170 | # Each patch embedding should be shape 171 | # (patch_n_pixels, (1 + bool(is_mask)) * n_channels, dim_feature_embedding) 172 | if self.image_share_embed: 173 | self.in_feature_embedding = nn.ParameterList([ 174 | nn.Parameter(torch.empty( 175 | self.patch_n_pixels, pixel_input_dims, 176 | dim_hidden))]) 177 | else: 178 | self.in_feature_embedding = nn.ParameterList([ 179 | nn.Parameter(torch.empty( 180 | self.patch_n_pixels, pixel_input_dims, 181 | dim_hidden)) 182 | for _ in range(self.image_n_patches)]) 183 | 184 | for embed in self.in_feature_embedding: 185 | nn.init.xavier_uniform_(embed) 186 | 187 | self.in_target_embedding = nn.Linear( 188 | self.dim_target_col, dim_hidden) 189 | 190 | def encode(self, X_ragged): 191 | # Feature Patch Embedding 192 | # Embed to a list of n_patch tensors, 193 | # each of size (N, dim_feature_embedding) 194 | 195 | X_features, X_target = self.preprocess_flattened_image(X_ragged) 196 | 197 | if self.image_share_embed: 198 | embeds = cycle(self.in_feature_embedding) 199 | else: 200 | embeds = self.in_feature_embedding 201 | 202 | X_embeds = [] 203 | for pixel_index in range(0, X_features.shape[1], self.patch_n_pixels): 204 | # Projection: 205 | # n: batch dimension, number of rows 206 | # p: patch size in number of locations (e.g., num RGB pixels) 207 | # h: dim_intensity * n_channels 208 | # = (1 + 1) * n_channels if we use BERT masking, 209 | # = 1 * n_channels otherwise 210 | # e: dim_feature_embedding, NPT hidden dimensions 211 | 212 | # X_input.shape = (n, p, h) 213 | X_input = X_features[ 214 | :, pixel_index:pixel_index+self.patch_n_pixels, :] 215 | 216 | # embed.shape = (p, h, e) 217 | embed = next(embeds) 218 | 219 | X_embeds.append(torch.einsum('nph,phe->ne', X_input, embed)) 220 | 221 | X_embeds.append(self.in_target_embedding(X_target)) 222 | X_embed = torch.stack(X_embeds, 1) 223 | 224 | return X_embed 225 | -------------------------------------------------------------------------------- /npt/model/npt_modules.py: -------------------------------------------------------------------------------- 1 | """Contains base attention modules.""" 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SaveAttMaps(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.curr_att_maps = None 13 | self.Q = None 14 | self.K = None 15 | self.V = None 16 | self.out = None 17 | self.out_pre_res = None 18 | 19 | def forward(self, X, Q, K, V): 20 | self.curr_att_maps = nn.Parameter(X) 21 | self.Q = nn.Parameter(Q) 22 | self.K = nn.Parameter(K) 23 | self.V = nn.Parameter(V) 24 | 25 | return X 26 | 27 | 28 | class MAB(nn.Module): 29 | """Multi-head Attention Block. 30 | 31 | Based on Set Transformer implementation 32 | (Lee et al. 2019, https://github.com/juho-lee/set_transformer). 
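    In brief (our summary of the forward pass implemented below): queries,
    keys and values are produced by fc_q / fc_k / fc_v and split across
    num_heads heads; each head computes softmax(Q K^T / sqrt(dim_KV)) V
    (or a constant-normalised variant, depending on model_att_score_norm);
    the heads are re-concatenated, optionally mixed by fc_mix_heads, and
    combined with residual connections, LayerNorm and a row-wise
    feed-forward network (rff).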
33 | """ 34 | def __init__( 35 | self, dim_Q, dim_KV, dim_emb, dim_out, c): 36 | """ 37 | 38 | Inputs have shape (B_A, N_A, F_A), where 39 | * `B_A` is a batch dimension, along we parallelise computation, 40 | * `N_A` is the number of samples in each batch, along which we perform 41 | attention, and 42 | * `F_A` is dimension of the embedding at input 43 | * `F_A` is `dim_Q` for query matrix 44 | * `F_A` is `dim_KV` for key and value matrix. 45 | 46 | Q, K, and V then all get embedded to `dim_emb`. 47 | `dim_out` is the output dimensionality of the MAB which has shape 48 | (B_A, N_A, dim_out), this can be different from `dim_KV` due to 49 | the head_mixing. 50 | 51 | This naming scheme is inherited from set-transformer paper. 52 | """ 53 | super(MAB, self).__init__() 54 | mix_heads = c.model_mix_heads 55 | num_heads = c.model_num_heads 56 | sep_res_embed = c.model_sep_res_embed 57 | ln = c.model_att_block_layer_norm 58 | rff_depth = c.model_rff_depth 59 | self.att_score_norm = c.model_att_score_norm 60 | self.pre_layer_norm = c.model_pre_layer_norm 61 | self.viz_att_maps = c.viz_att_maps 62 | self.model_ablate_rff = c.model_ablate_rff 63 | 64 | if self.viz_att_maps: 65 | self.save_att_maps = SaveAttMaps() 66 | 67 | if dim_out is None: 68 | dim_out = dim_emb 69 | elif (dim_out is not None) and (mix_heads is None): 70 | print('Warning: dim_out transformation does not apply.') 71 | dim_out = dim_emb 72 | 73 | self.num_heads = num_heads 74 | self.dim_KV = dim_KV 75 | self.dim_split = dim_emb // num_heads 76 | self.fc_q = nn.Linear(dim_Q, dim_emb) 77 | self.fc_k = nn.Linear(dim_KV, dim_emb) 78 | self.fc_v = nn.Linear(dim_KV, dim_emb) 79 | self.fc_mix_heads = nn.Linear(dim_emb, dim_out) if mix_heads else None 80 | self.fc_res = nn.Linear(dim_Q, dim_out) if sep_res_embed else None 81 | 82 | if ln: 83 | if self.pre_layer_norm: # Applied to X 84 | self.ln0 = nn.LayerNorm(dim_Q, eps=c.model_layer_norm_eps) 85 | else: # Applied after MHA and residual 86 | self.ln0 = nn.LayerNorm(dim_out, eps=c.model_layer_norm_eps) 87 | 88 | self.ln1 = nn.LayerNorm(dim_out, eps=c.model_layer_norm_eps) 89 | else: 90 | self.ln0 = None 91 | self.ln1 = None 92 | 93 | self.hidden_dropout = ( 94 | nn.Dropout(p=c.model_hidden_dropout_prob) 95 | if c.model_hidden_dropout_prob else None) 96 | 97 | self.att_scores_dropout = ( 98 | nn.Dropout(p=c.model_att_score_dropout_prob) 99 | if c.model_att_score_dropout_prob else None) 100 | 101 | self.init_rff(dim_out, rff_depth) 102 | 103 | def init_rff(self, dim_out, rff_depth): 104 | # Linear layer with 4 times expansion factor as in 'Attention is 105 | # all you need'! 
106 | self.rff = [nn.Linear(dim_out, 4 * dim_out), nn.GELU()] 107 | 108 | if self.hidden_dropout is not None: 109 | self.rff.append(self.hidden_dropout) 110 | 111 | for i in range(rff_depth - 1): 112 | self.rff += [nn.Linear(4 * dim_out, 4 * dim_out), nn.GELU()] 113 | 114 | if self.hidden_dropout is not None: 115 | self.rff.append(self.hidden_dropout) 116 | 117 | self.rff += [nn.Linear(4 * dim_out, dim_out)] 118 | 119 | if self.hidden_dropout is not None: 120 | self.rff.append(self.hidden_dropout) 121 | 122 | self.rff = nn.Sequential(*self.rff) 123 | 124 | def forward(self, X, Y): 125 | if self.pre_layer_norm and self.ln0 is not None: 126 | X_multihead = self.ln0(X) 127 | else: 128 | X_multihead = X 129 | 130 | Q = self.fc_q(X_multihead) 131 | 132 | if self.fc_res is None: 133 | X_res = Q 134 | else: 135 | X_res = self.fc_res(X) # Separate embedding for residual 136 | 137 | K = self.fc_k(Y) 138 | V = self.fc_v(Y) 139 | 140 | Q_ = torch.cat(Q.split(self.dim_split, 2), 0) 141 | K_ = torch.cat(K.split(self.dim_split, 2), 0) 142 | V_ = torch.cat(V.split(self.dim_split, 2), 0) 143 | 144 | # TODO: track issue at 145 | # https://github.com/juho-lee/set_transformer/issues/8 146 | # A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 147 | A = torch.einsum('ijl,ikl->ijk', Q_, K_) 148 | 149 | if self.att_score_norm == 'softmax': 150 | A = torch.softmax(A / math.sqrt(self.dim_KV), 2) 151 | elif self.att_score_norm == 'constant': 152 | A = A / self.dim_split 153 | else: 154 | raise NotImplementedError 155 | 156 | if self.viz_att_maps: 157 | A = self.save_att_maps(A, Q_, K_, V_) 158 | 159 | # Attention scores dropout is applied to the N x N_v matrix of 160 | # attention scores. 161 | # Hence, it drops out entire rows/cols to attend to. 162 | # This follows Vaswani et al. 2017 (original Transformer paper). 163 | 164 | if self.att_scores_dropout is not None: 165 | A = self.att_scores_dropout(A) 166 | 167 | multihead = A.bmm(V_) 168 | multihead = torch.cat(multihead.split(Q.size(0), 0), 2) 169 | 170 | # Add mixing of heads in hidden dim. 171 | 172 | if self.fc_mix_heads is not None: 173 | H = self.fc_mix_heads(multihead) 174 | else: 175 | H = multihead 176 | 177 | # Follow Vaswani et al. 2017 in applying dropout prior to 178 | # residual and LayerNorm 179 | if self.hidden_dropout is not None: 180 | H = self.hidden_dropout(H) 181 | 182 | # True to the paper would be to replace 183 | # self.fc_mix_heads = nn.Linear(dim_V, dim_Q) 184 | # and Q_out = X 185 | # Then, the output dim is equal to input dim, just like it's written 186 | # in the paper. We should definitely check if that boosts performance. 187 | # This will require changes to downstream structure (since downstream 188 | # blocks expect input_dim=dim_V and not dim_Q) 189 | 190 | # Residual connection 191 | Q_out = X_res 192 | H = H + Q_out 193 | 194 | # Post Layer-Norm, as in SetTransformer and BERT. 
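        # (Recap of the two orderings handled here: with pre_layer_norm,
        # LayerNorm was already applied to the block inputs before attention
        # and before the rFF, and the residual sums stay un-normalised;
        # without it, the residual sums are normalised afterwards, as in the
        # original Transformer.)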
195 | if not self.pre_layer_norm and self.ln0 is not None: 196 | H = self.ln0(H) 197 | 198 | if self.pre_layer_norm and self.ln1 is not None: 199 | H_rff = self.ln1(H) 200 | else: 201 | H_rff = H 202 | 203 | if self.model_ablate_rff: 204 | expanded_linear_H = H_rff 205 | else: 206 | # Apply row-wise feed forward network 207 | expanded_linear_H = self.rff(H_rff) 208 | 209 | # Residual connection 210 | expanded_linear_H = H + expanded_linear_H 211 | 212 | if not self.pre_layer_norm and self.ln1 is not None: 213 | expanded_linear_H = self.ln1(expanded_linear_H) 214 | 215 | if self.viz_att_maps: 216 | self.save_att_maps.out = nn.Parameter(expanded_linear_H) 217 | self.save_att_maps.out_pre_res = nn.Parameter(H) 218 | 219 | return expanded_linear_H 220 | 221 | 222 | class MHSA(nn.Module): 223 | """ 224 | Multi-head Self-Attention Block. 225 | 226 | Based on implementation from Set Transformer (Lee et al. 2019, 227 | https://github.com/juho-lee/set_transformer). 228 | Alterations detailed in MAB method. 229 | """ 230 | has_inducing_points = False 231 | 232 | def __init__(self, dim_in, dim_emb, dim_out, c): 233 | super(MHSA, self).__init__() 234 | self.mab = MAB(dim_in, dim_in, dim_emb, dim_out, c) 235 | 236 | def forward(self, X): 237 | return self.mab(X, X) 238 | -------------------------------------------------------------------------------- /npt/optim.py: -------------------------------------------------------------------------------- 1 | """Learning rate scheduler.""" 2 | 3 | import numpy as np 4 | import torch 5 | from dotmap import DotMap 6 | from fairseq.optim.fairseq_optimizer import FairseqOptimizer 7 | from fairseq.optim.lr_scheduler import cosine_lr_scheduler 8 | from torch import nn 9 | from torch.optim.lr_scheduler import ( 10 | LambdaLR, CosineAnnealingLR) 11 | from transformers import ( 12 | get_constant_schedule, 13 | get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup) 14 | 15 | 16 | def clip_gradient(model, clip: float): 17 | nn.utils.clip_grad_norm_(model.parameters(), clip) 18 | 19 | 20 | class ConcatLR(torch.optim.lr_scheduler._LRScheduler): 21 | """ 22 | From Over9000 23 | https://github.com/mgrankin/over9000/blob/master/train.py 24 | """ 25 | def __init__(self, optimizer, scheduler1, scheduler2, total_steps, 26 | pct_start=0.5, last_epoch=-1): 27 | self.scheduler1 = scheduler1 28 | self.scheduler2 = scheduler2 29 | self.step_start = float(pct_start * total_steps) - 1 30 | self.curr_epoch = 0 31 | super(ConcatLR, self).__init__(optimizer, last_epoch) 32 | 33 | def step(self): 34 | if self.curr_epoch <= self.step_start: 35 | self.scheduler1.step() 36 | else: 37 | self.scheduler2.step() 38 | self.curr_epoch += 1 39 | super().step() 40 | 41 | def get_lr(self): 42 | if self.curr_epoch <= self.step_start: 43 | return self.scheduler1.get_last_lr() 44 | else: 45 | return self.scheduler2.get_last_lr() 46 | 47 | 48 | class TradeoffAnnealer: 49 | def __init__(self, c, num_steps=None): 50 | """ 51 | Anneal the tradeoff between label and augmentation loss according 52 | to some schedule. 53 | 54 | :param c: config 55 | :param num_steps: int, provide when loading from checkpoint to fast- 56 | forward to that tradeoff value. 
57 | """ 58 | self.c = c 59 | self.name = self.c.exp_tradeoff_annealing 60 | 61 | self.num_steps = 0 62 | self.init_tradeoff = self.c.exp_tradeoff 63 | self.curr_tradeoff = self.c.exp_tradeoff 64 | self.max_steps = self.get_max_steps() 65 | self.step_map = { 66 | 'constant': self.constant_step, 67 | 'cosine': self.cosine_step, 68 | 'linear_decline': self.linear_decline_step} 69 | 70 | if self.name not in self.step_map.keys(): 71 | raise NotImplementedError 72 | 73 | self.step = self.step_map[self.name] 74 | 75 | if num_steps > 0: 76 | # If we are loading a model from checkpoint, 77 | # should update the annealer to that number of steps. 78 | for _ in range(num_steps): 79 | self.step() 80 | 81 | print(f'Fast-forwarded tradeoff annealer to step {num_steps}.') 82 | 83 | print( 84 | f'Initialized "{self.name}" augmentation/label tradeoff annealer. ' 85 | f'Annealing to minimum value in {self.max_steps} steps.') 86 | 87 | def get_max_steps(self): 88 | # If annealing proportion is set to -1, 89 | if self.c.exp_tradeoff_annealing_proportion == -1: 90 | # and the optimizer proportion is set, we use the optimizer 91 | # proportion to determine how long it takes for the tradeoff to 92 | # anneal to 0. 93 | if self.c.exp_optimizer_warmup_proportion != -1: 94 | return int(np.ceil(self.c.exp_optimizer_warmup_proportion 95 | * self.c.exp_num_total_steps)) 96 | # and the optimizer proportion is not set, 97 | # we take all steps to anneal. 98 | else: 99 | return self.c.exp_num_total_steps 100 | 101 | if (self.c.exp_tradeoff_annealing_proportion < 0 102 | or self.c.exp_tradeoff_annealing_proportion > 1): 103 | raise Exception('Invalid tradeoff annealing proportion.') 104 | 105 | # Otherwise, we use the tradeoff annealing proportion to determine 106 | # for how long we anneal. 
107 | return int(np.ceil(self.c.exp_tradeoff_annealing_proportion 108 | * self.c.exp_num_total_steps)) 109 | 110 | def constant_step(self): 111 | self.num_steps += 1 112 | return self.curr_tradeoff 113 | 114 | def linear_decline_step(self): 115 | curr = self.num_steps 116 | max_val = self.init_tradeoff 117 | 118 | if self.num_steps <= self.max_steps: 119 | self.curr_tradeoff = max_val - (curr / self.max_steps) * max_val 120 | else: 121 | self.curr_tradeoff = 0 122 | 123 | self.num_steps += 1 124 | 125 | return self.curr_tradeoff 126 | 127 | def cosine_step(self): 128 | if self.num_steps <= self.max_steps: 129 | self.curr_tradeoff = self.init_tradeoff * (1 / 2) * ( 130 | np.cos(np.pi * (self.num_steps / self.max_steps)) + 1) 131 | else: 132 | self.curr_tradeoff = 0 133 | 134 | self.num_steps += 1 135 | 136 | return self.curr_tradeoff 137 | 138 | 139 | class LRScheduler: 140 | def __init__(self, c, name, optimizer): 141 | self.c = c 142 | self.name = name 143 | self.optimizer = optimizer 144 | self.num_steps = 0 145 | 146 | self.construct_auto_scheduler() 147 | 148 | print(f'Initialized "{name}" learning rate scheduler.') 149 | 150 | def construct_auto_scheduler(self): 151 | total_steps = self.c.exp_num_total_steps 152 | 153 | if self.c.exp_optimizer_warmup_proportion >= 0: 154 | num_warmup_steps = ( 155 | total_steps * self.c.exp_optimizer_warmup_proportion) 156 | else: 157 | num_warmup_steps = self.c.exp_optimizer_warmup_fixed_n_steps 158 | 159 | print(f'Warming up for {num_warmup_steps}/{total_steps} steps.') 160 | 161 | if self.name == 'constant': 162 | self.scheduler = get_constant_schedule(optimizer=self.optimizer) 163 | elif self.name == 'linear_warmup': 164 | self.scheduler = get_linear_schedule_with_warmup( 165 | optimizer=self.optimizer, 166 | num_warmup_steps=num_warmup_steps, 167 | num_training_steps=total_steps) 168 | elif self.name == 'cosine_cyclic': 169 | args = dict( 170 | warmup_updates=num_warmup_steps, 171 | warmup_init_lr=1e-7, 172 | max_lr=self.c.exp_lr, 173 | lr=[1e-7], 174 | t_mult=2., 175 | lr_period_updates=num_warmup_steps * 2, 176 | lr_shrink=0.5) 177 | optim = FairseqOptimizer(None) 178 | optim._optimizer = optim.optimizer = self.optimizer 179 | self.scheduler = cosine_lr_scheduler.CosineSchedule( 180 | optimizer=optim, args=DotMap(args)) 181 | elif self.name == 'polynomial_decay_warmup': 182 | # Based on the fairseq implementation, which is based on BERT 183 | self.scheduler = get_polynomial_decay_schedule_with_warmup( 184 | optimizer=self.optimizer, 185 | num_warmup_steps=num_warmup_steps, 186 | num_training_steps=total_steps, 187 | lr_end=1e-7, 188 | power=1.0) 189 | elif self.name == 'flat_and_anneal': 190 | def d(x): 191 | return 1 192 | 193 | assert self.c.exp_optimizer_warmup_proportion >= 0 194 | 195 | # We use exp_optimizer_warmup_proportion to denote the 196 | # flat LR regime, prior to annealing 197 | dummy = LambdaLR(self.optimizer, d) 198 | cosine = CosineAnnealingLR( 199 | self.optimizer, int(total_steps * ( 200 | 1 - self.c.exp_optimizer_warmup_proportion))) 201 | self.scheduler = ConcatLR( 202 | self.optimizer, dummy, cosine, total_steps, 203 | self.c.exp_optimizer_warmup_proportion) 204 | else: 205 | raise NotImplementedError 206 | 207 | def step(self): 208 | self.num_steps += 1 209 | c_lr = self.c.exp_lr 210 | num = self.num_steps 211 | tot = self.c.exp_num_total_steps 212 | 213 | if self.name == 'cosine_cyclic': 214 | self.scheduler.step_update(num_updates=num) 215 | else: 216 | self.scheduler.step() 217 | 
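A minimal, self-contained sketch of the cosine schedule that
TradeoffAnnealer.cosine_step implements above, handy for eyeballing the
annealing curve without building a config object; the function name and the
example numbers below are ours, not part of the repository:

    import numpy as np

    def cosine_tradeoff(step, max_steps, init_tradeoff):
        # Mirrors TradeoffAnnealer.cosine_step: decay init_tradeoff to 0
        # over max_steps with a half-cosine, then stay at 0.
        if step <= max_steps:
            return init_tradeoff * 0.5 * (np.cos(np.pi * step / max_steps) + 1)
        return 0.0

    # e.g. with init_tradeoff=0.5, max_steps=100:
    # step 0 -> 0.5, step 50 -> 0.25, step 100 -> 0.0, step 150 -> 0.0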
-------------------------------------------------------------------------------- /npt/utils/analyse_wandb_project.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import wandb 4 | 5 | 6 | def filter_col(col, df): 7 | return list(filter(lambda x: col in x, df.columns)) 8 | 9 | 10 | def get_df_from_project(project): 11 | """Load summary and config DataFrame for all runs of a project. 12 | 13 | Largely copied from wandb docs. 14 | """ 15 | api = wandb.Api() 16 | runs = api.runs(f'anonymous_tab/{project}') 17 | 18 | summary_list = [] 19 | config_list = [] 20 | name_list = [] 21 | 22 | for run in runs: 23 | # run.summary are the output key/values like accuracy. 24 | # We call ._json_dict to omit large files. 25 | summary_list.append(run.summary._json_dict) 26 | # run.config is the input metrics. 27 | # We remove special values that start with _. 28 | config_list.append( 29 | {k: v for k,v in run.config.items() if not k.startswith('_')}) 30 | # run.name is the name of the run. 31 | name_list.append(run.name) 32 | 33 | summary_df = pd.DataFrame.from_records(summary_list) 34 | config_df = pd.DataFrame.from_records(config_list) 35 | name_df = pd.DataFrame({'name': name_list}) 36 | df = pd.concat([name_df, config_df, summary_df], axis=1) 37 | 38 | return df 39 | 40 | 41 | def get_rankings(df, val_metric, test_metrics, higher_is_better=False): 42 | """Assign ranking acc to val_metric to each model per split.""" 43 | row_list = [] 44 | for split, split_df in df.groupby('cv_index'): 45 | 46 | # get models and performances for that split 47 | models = split_df['exp_group'] 48 | val_perfs = split_df[val_metric] 49 | test_perfs = split_df[test_metrics] 50 | 51 | # sort models, performances by val_perf 52 | sorting = np.argsort(val_perfs) 53 | if higher_is_better: 54 | sorting = sorting[::-1] 55 | models = models.values[sorting] 56 | val_perfs = val_perfs.values[sorting] 57 | test_perfs = test_perfs.values[sorting] 58 | 59 | # for each model report ranking/performance in each split 60 | for ranking, model in enumerate(models): 61 | row_list.append([ 62 | model, split, ranking, 63 | val_perfs[ranking], *test_perfs[ranking]]) 64 | 65 | rankings_df = pd.DataFrame( 66 | data=row_list, 67 | columns=[ 68 | 'exp_group', 'cv_index', 'ranking', val_metric, *test_metrics]) 69 | 70 | # add rmse 71 | for test_metric in test_metrics: 72 | if 'mse' in test_metric: 73 | rankings_df[test_metric.replace('mse', 'rmse')] = np.sqrt( 74 | rankings_df[test_metric]) 75 | 76 | return rankings_df 77 | 78 | 79 | def report_losses(rankings_df): 80 | """Select model in each split by ranking.""" 81 | losses = rankings_df[rankings_df.ranking == 0] 82 | for metric in rankings_df.columns: 83 | if metric in ['exp_group', 'cv_index', 'ranking']: 84 | continue 85 | loss = losses[metric] 86 | mean, std = loss.mean(), loss.std() 87 | if 'accuracy' in metric: 88 | mean *= 100 89 | std *= 100 90 | print(f'Metric {metric}: {mean:.2f} \\pm {std:.2f}') 91 | 92 | elif 'cat' in metric: 93 | print(f'Metric {metric}: {mean:.3f} \\pm {std:.3f}') 94 | 95 | else: 96 | print(f'Metric {metric}: {mean:.4f} \\pm {std:.4f}') 97 | 98 | if 'rmse_loss_unstd' in metric: 99 | std /= np.sqrt(len(loss)) 100 | print(f'Metric {metric}: {mean:.4f} \\pm {std:.4f} (std_err)') 101 | 102 | return losses 103 | -------------------------------------------------------------------------------- /npt/utils/batch_utils.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict, defaultdict 3 | 4 | import numpy as np 5 | from sklearn.utils.multiclass import type_of_target 6 | from sklearn.utils.validation import column_or_1d 7 | 8 | import torch 9 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 10 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 11 | 12 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 13 | from torch._six import container_abcs 14 | else: 15 | import collections.abc as container_abcs 16 | 17 | collate_with_pre_batching_err_msg_format = ( 18 | "collate_with_pre_batched_map: " 19 | "batch must be a list with one map element; found {}") 20 | 21 | 22 | def collate_with_pre_batching(batch): 23 | r""" 24 | Collate function used by our PyTorch dataloader (in both distributed and 25 | serial settings). 26 | 27 | We avoid adding a batch dimension, as for NPT we have pre-batched data, 28 | where each element of the dataset is a map. 29 | 30 | :arg batch: List[Dict] (not as general as the default collate fn) 31 | """ 32 | if len(batch) > 1: 33 | raise NotImplementedError 34 | 35 | elem = batch[0] 36 | elem_type = type(elem) 37 | 38 | if isinstance(elem, container_abcs.Mapping): 39 | return elem # Just return the dict, as there will only be one in NPT 40 | 41 | raise TypeError(collate_with_pre_batching_err_msg_format.format(elem_type)) 42 | 43 | 44 | # TODO: batching over features? 45 | 46 | class StratifiedIndexSampler: 47 | def __init__( 48 | self, y, n_splits, shuffle=True, label_col=None, 49 | train_indices=None): 50 | self.y = y 51 | self.n_splits = n_splits 52 | self.shuffle = shuffle 53 | self.label_col = label_col 54 | self.train_indices = train_indices 55 | if label_col is not None and train_indices is not None: 56 | self.stratify_class_labels = True 57 | print('Stratifying train rows in each batch on the class label.') 58 | else: 59 | self.stratify_class_labels = False 60 | 61 | def _make_test_folds(self, labels): 62 | """ 63 | Slight alterations from sklearn (StratifiedKFold) 64 | """ 65 | y, n_splits, shuffle = labels, self.n_splits, self.shuffle 66 | y = np.asarray(y) 67 | type_of_target_y = type_of_target(y) 68 | allowed_target_types = ('binary', 'multiclass') 69 | if type_of_target_y not in allowed_target_types: 70 | raise ValueError( 71 | 'Supported target types are: {}. Got {!r} instead.'.format( 72 | allowed_target_types, type_of_target_y)) 73 | 74 | y = column_or_1d(y) 75 | 76 | _, y_idx, y_inv = np.unique(y, return_index=True, return_inverse=True) 77 | # y_inv encodes y according to lexicographic order. We invert y_idx to 78 | # map the classes so that they are encoded by order of appearance: 79 | # 0 represents the first label appearing in y, 1 the second, etc. 80 | _, class_perm = np.unique(y_idx, return_inverse=True) 81 | y_encoded = class_perm[y_inv] 82 | 83 | n_classes = len(y_idx) 84 | y_counts = np.bincount(y_encoded) 85 | min_groups = np.min(y_counts) 86 | if np.all(n_splits > y_counts): 87 | raise ValueError("n_splits=%d cannot be greater than the" 88 | " number of members in each class." 89 | % (n_splits)) 90 | if n_splits > min_groups: 91 | warnings.warn(("The least populated class in y has only %d" 92 | " members, which is less than n_splits=%d." 93 | % (min_groups, n_splits)), UserWarning) 94 | 95 | # Determine the optimal number of samples from each class in each fold, 96 | # using round robin over the sorted y. 
(This can be done direct from 97 | # counts, but that code is unreadable.) 98 | y_order = np.sort(y_encoded) 99 | allocation = np.asarray( 100 | [np.bincount(y_order[i::n_splits], minlength=n_classes) 101 | for i in range(n_splits)]) 102 | 103 | # To maintain the data order dependencies as best as possible within 104 | # the stratification constraint, we assign samples from each class in 105 | # blocks (and then mess that up when shuffle=True). 106 | test_folds = np.empty(len(y), dtype='i') 107 | for k in range(n_classes): 108 | # since the kth column of allocation stores the number of samples 109 | # of class k in each test set, this generates blocks of fold 110 | # indices corresponding to the allocation for class k. 111 | folds_for_class = np.arange(n_splits).repeat(allocation[:, k]) 112 | if shuffle: 113 | np.random.shuffle(folds_for_class) 114 | test_folds[y_encoded == k] = folds_for_class 115 | return test_folds 116 | 117 | def get_stratified_test_array(self, X): 118 | """ 119 | Based on sklearn function StratifiedKFold._iter_test_masks. 120 | """ 121 | if self.stratify_class_labels: 122 | return self.get_train_label_stratified_test_array(X) 123 | 124 | test_folds = self._make_test_folds(self.y) 125 | 126 | # Inefficient for huge arrays, particularly when we need to materialize 127 | # the index order. 128 | # for i in range(n_splits): 129 | # yield test_folds == i 130 | 131 | batch_index_to_row_indices = OrderedDict() 132 | batch_index_to_row_index_count = defaultdict(int) 133 | for row_index, batch_index in enumerate(test_folds): 134 | if batch_index not in batch_index_to_row_indices.keys(): 135 | batch_index_to_row_indices[batch_index] = [row_index] 136 | else: 137 | batch_index_to_row_indices[batch_index].append(row_index) 138 | 139 | batch_index_to_row_index_count[batch_index] += 1 140 | 141 | # Keep track of the batch sizes for each batch -- this can vary 142 | # towards the end of the epoch, and will not be precisely what the 143 | # user specified. Doesn't matter because the model is equivariant 144 | # w.r.t. rows. 145 | batch_sizes = [] 146 | for batch_index in batch_index_to_row_indices.keys(): 147 | batch_sizes.append(batch_index_to_row_index_count[batch_index]) 148 | 149 | return ( 150 | X[np.concatenate(list(batch_index_to_row_indices.values()))], 151 | batch_sizes) 152 | 153 | def get_train_label_stratified_test_array(self, X): 154 | train_class_folds = self._make_test_folds( 155 | self.label_col[self.train_indices]) 156 | 157 | # Mapping from the size of a stratified batch of training rows 158 | # to the index of the batch. 
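# E.g. {128: [0, 2], 127: [1]} would mean train batches 0 and 2 each hold 128 stratified train rows; later on, a batch with 128 vacant train slots pops one of these indices to fill itself.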
159 | train_batch_size_to_train_batch_indices = defaultdict(list) 160 | 161 | # Mapping from a train batch index to all of the actual train indices 162 | train_batch_index_to_train_row_indices = OrderedDict() 163 | 164 | for train_row_index, train_batch_index in enumerate(train_class_folds): 165 | if (train_batch_index not in 166 | train_batch_index_to_train_row_indices.keys()): 167 | train_batch_index_to_train_row_indices[ 168 | train_batch_index] = [train_row_index] 169 | else: 170 | train_batch_index_to_train_row_indices[ 171 | train_batch_index].append(train_row_index) 172 | 173 | for train_batch_index, train_row_indices in ( 174 | train_batch_index_to_train_row_indices.items()): 175 | train_batch_size_to_train_batch_indices[ 176 | len(train_row_indices)].append(train_batch_index) 177 | 178 | test_folds = self._make_test_folds(self.y) 179 | 180 | # Mapping our actual batch indices to the val and test rows which 181 | # have been successfully assigned 182 | batch_index_to_val_test_row_indices = OrderedDict() 183 | 184 | # Mapping our actual batch indices to the total number of row indices 185 | # in each batch. We will have to assign the stratified train batches 186 | # to fulfill this constraint. 187 | batch_index_to_row_index_count = defaultdict(int) 188 | 189 | # Mapping our actual batch indices to how many train spots are 190 | # "vacant" in each batch. These we will fill with our stratified 191 | # train batches. 192 | batch_index_to_train_row_index_count = defaultdict(int) 193 | 194 | for row_index, (batch_index, dataset_mode) in enumerate( 195 | zip(test_folds, self.y)): 196 | batch_index_to_row_index_count[batch_index] += 1 197 | 198 | if dataset_mode == 0: # Train 199 | batch_index_to_train_row_index_count[batch_index] += 1 200 | else: 201 | if batch_index not in ( 202 | batch_index_to_val_test_row_indices.keys()): 203 | batch_index_to_val_test_row_indices[ 204 | batch_index] = [row_index] 205 | else: 206 | batch_index_to_val_test_row_indices[ 207 | batch_index].append(row_index) 208 | 209 | # For all of our actual batches, let's find a suitable batch 210 | # of stratified training data for us to use. 
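# We match on exact size: each batch pops a stratified train batch whose row count equals the number of its vacant train slots, and the pop below raises if no train batch of that size remains.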
211 | for batch_index, train_row_index_count in batch_index_to_train_row_index_count.items(): 212 | try: 213 | train_batch_index = ( 214 | train_batch_size_to_train_batch_indices[ 215 | train_row_index_count].pop()) 216 | except Exception as e: 217 | raise e 218 | batch_index_to_val_test_row_indices[batch_index] += ( 219 | train_batch_index_to_train_row_indices[train_batch_index]) 220 | 221 | for train_batch_arr in train_batch_size_to_train_batch_indices.values(): 222 | if len(train_batch_arr) != 0: 223 | raise Exception 224 | 225 | batch_sizes = [] 226 | for batch_index in batch_index_to_val_test_row_indices.keys(): 227 | batch_sizes.append(batch_index_to_row_index_count[batch_index]) 228 | 229 | batch_order_sorted_row_indices = X[ 230 | np.concatenate(list(batch_index_to_val_test_row_indices.values()))] 231 | assert ( 232 | len(set(batch_order_sorted_row_indices)) == 233 | len(batch_order_sorted_row_indices)) 234 | return batch_order_sorted_row_indices, batch_sizes 235 | -------------------------------------------------------------------------------- /npt/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | class Args: 2 | def __init__(self, data: dict = None): 3 | if data is None: 4 | self.data = {} 5 | else: 6 | self.data = data 7 | 8 | def __getattr__(self, key): 9 | return self.data.get(key, None) 10 | -------------------------------------------------------------------------------- /npt/utils/cv_utils.py: -------------------------------------------------------------------------------- 1 | """Cross-validation utils.""" 2 | 3 | from collections import Counter 4 | from enum import IntEnum 5 | 6 | import numpy as np 7 | from sklearn.model_selection import StratifiedKFold, train_test_split, KFold 8 | 9 | 10 | class DatasetMode(IntEnum): 11 | """Used in batching.""" 12 | TRAIN = 0 13 | VAL = 1 14 | TEST = 2 15 | 16 | 17 | DATASET_MODE_TO_ENUM = { 18 | 'train': DatasetMode.TRAIN, 19 | 'val': DatasetMode.VAL, 20 | 'test': DatasetMode.TEST 21 | } 22 | 23 | DATASET_ENUM_TO_MODE = { 24 | DatasetMode.TRAIN: 'train', 25 | DatasetMode.VAL: 'val', 26 | DatasetMode.TEST: 'test' 27 | } 28 | 29 | 30 | def get_class_reg_train_val_test_splits( 31 | label_rows, c, should_stratify, fixed_test_set_index): 32 | """"Obtain train, validation, and test indices. 33 | num_data = len(label_rows) 34 | 35 | Stratify Logic: 36 | 37 | Perform stratified splits if the target is a single categorical column; 38 | else, (even if we have multiple categorical targets, for example) 39 | perform standard splits. 40 | 41 | If fixed_test_set_index is not None, 42 | use the index to perform the test split 43 | """ 44 | if should_stratify and label_rows.dtype == np.object: 45 | from sklearn.preprocessing import LabelEncoder 46 | # Encode the label column 47 | label_rows = LabelEncoder().fit_transform(label_rows) 48 | print('Detected an object dtype label column. 
Encoded to ints.') 49 | 50 | N = len(label_rows) 51 | n_cv_splits = get_n_cv_splits(c) 52 | 53 | if fixed_test_set_index: 54 | all_indices = np.arange(N) 55 | train_test_splits = [ 56 | (all_indices[:fixed_test_set_index], 57 | all_indices[fixed_test_set_index:])] 58 | else: 59 | kf_class = StratifiedKFold if should_stratify else KFold 60 | kf = kf_class( 61 | n_splits=n_cv_splits, shuffle=True, random_state=c.np_seed) 62 | train_test_splits = kf.split(np.arange(N), label_rows) 63 | 64 | for train_val_indices, test_indices in train_test_splits: 65 | val_indices = [] 66 | if c.exp_val_perc > 0: 67 | normed_val_perc = c.exp_val_perc / (1 - c.exp_test_perc) 68 | 69 | if should_stratify: 70 | train_val_label_rows = label_rows[train_val_indices] 71 | else: 72 | train_val_label_rows = None 73 | 74 | train_indices, val_indices = train_test_split( 75 | train_val_indices, test_size=normed_val_perc, shuffle=True, 76 | random_state=c.np_seed, stratify=train_val_label_rows) 77 | else: 78 | train_indices = train_val_indices 79 | 80 | train_perc = len(train_indices) / N 81 | val_perc = len(val_indices) / N 82 | test_perc = len(test_indices) / N 83 | print( 84 | f'Percentage of each group: Train {train_perc:.2f} ' 85 | f'| {val_perc:.2f} | {test_perc:.2f}') 86 | 87 | if c.exp_show_empirical_label_dist: 88 | print('Empirical Label Distributions:') 89 | for split_name, split_indices in zip( 90 | ['train', 'val', 'test'], 91 | [train_indices, val_indices, test_indices]): 92 | num_elem = len(split_indices) 93 | class_counter = Counter(label_rows[split_indices]) 94 | class_proportions = { 95 | key: class_counter[key] / num_elem 96 | for key in sorted(class_counter.keys())} 97 | print(f'{split_name}:') 98 | print(class_proportions) 99 | 100 | yield train_indices, val_indices, test_indices 101 | 102 | 103 | def get_n_cv_splits(c): 104 | return int(1 / c.exp_test_perc) # Rounds down 105 | -------------------------------------------------------------------------------- /npt/utils/data_loading_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | import numpy as np 5 | import wget 6 | 7 | 8 | class NumpyEncoder(json.JSONEncoder): 9 | """ Special json encoder for numpy types """ 10 | def default(self, obj): 11 | if isinstance( 12 | obj, 13 | (np.int_, np.intc, np.intp, np.int8, np.int16, np.int32, 14 | np.int64, np.uint8, np.uint16, np.uint32, np.uint64)): 15 | return int(obj) 16 | elif isinstance(obj, (np.float_, np.float16, np.float32, 17 | np.float64)): 18 | return float(obj) 19 | elif isinstance(obj, (np.ndarray,)): # This is the fix 20 | return obj.tolist() 21 | return json.JSONEncoder.default(self, obj) 22 | 23 | 24 | def convert_unserialized_data_to_serialized_json(data): 25 | return json.dumps(data, cls=NumpyEncoder) 26 | 27 | 28 | def stack_ragged(array_list, axis=1): 29 | """ 30 | Stack ragged data table for downstream npz storage. 31 | From https://tonysyu.github.io/ragged-arrays.html 32 | """ 33 | lengths = [np.shape(a)[axis] for a in array_list] 34 | idx = np.cumsum(lengths[:-1]) 35 | stacked = np.concatenate(array_list, axis=axis) 36 | return stacked, idx 37 | 38 | 39 | def bar_progress(current, total, width=80): 40 | """Display progress bar while downloading. 
41 | 42 | https://stackoverflow.com/questions/58125279/ 43 | python-wget-module-doesnt-show-progress-bar" 44 | """ 45 | 46 | progress_message = ( 47 | f'Downloading: {current/total * 100:.0f} %% ' 48 | f'[{current:.2e} / {total:.2e}] bytes') 49 | 50 | # Don't use print() as it will print in new line every time. 51 | sys.stdout.write("\r" + progress_message) 52 | sys.stdout.flush() 53 | 54 | 55 | def download(paths, urls): 56 | # Download URLs to paths. 57 | 58 | if not isinstance(urls, list): 59 | urls = [urls] 60 | 61 | if not isinstance(paths, list): 62 | paths = [paths] 63 | 64 | if not len(urls) == len(paths): 65 | raise ValueError('Need exactly one path per URL.') 66 | 67 | for path, url in zip(paths, urls): 68 | print(f'Downloading {url}.') 69 | path.parent.mkdir(parents=True, exist_ok=True) 70 | wget.download(url, out=str(path), bar=bar_progress) 71 | -------------------------------------------------------------------------------- /npt/utils/debug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def modify_data(c, batch_dict, dataset_mode, num_steps): 6 | """Modify data for debugging row interactions in synthetic experiments.""" 7 | 8 | if 'protein-duplicate' in c.debug_row_interactions_mode: 9 | return protein_duplicate( 10 | c, batch_dict, dataset_mode, c.debug_row_interactions_mode) 11 | else: 12 | raise ValueError 13 | 14 | 15 | def corrupt_rows(c, batch_dict, dataset_mode, row_index): 16 | """Corrupt rows: 17 | (i) Duplication experiments -- find the duplicated row of the specified 18 | `row_index`. Flip its label. 19 | (ii) Standard datasets -- for each column, apply an independent permutation 20 | over entries in all rows other than row `row_index`. 21 | """ 22 | if (c.debug_row_interactions and 23 | c.debug_row_interactions_mode == 'protein-duplicate'): 24 | return corrupt_duplicate_rows(c, batch_dict, dataset_mode, row_index) 25 | else: 26 | return corrupt_standard_dataset(c, batch_dict, dataset_mode, row_index) 27 | 28 | 29 | def duplicate_batch_dict(batch_dict): 30 | def recursive_clone(obj): 31 | if isinstance(obj, (int, float)): 32 | return obj 33 | elif isinstance(obj, list): 34 | return [recursive_clone(elem) for elem in obj] 35 | elif isinstance(obj, torch.Tensor): 36 | return obj.clone().detach() 37 | elif obj is None: 38 | return None 39 | else: 40 | raise NotImplementedError 41 | 42 | new_batch_dict = {} 43 | for key, value in batch_dict.items(): 44 | new_batch_dict[key] = recursive_clone(value) 45 | 46 | return new_batch_dict 47 | 48 | 49 | def corrupt_duplicate_rows(c, batch_dict, dataset_mode, row_index): 50 | """ 51 | The aim of this corruption is to show that using the `designated lookup 52 | row` (located at `row_index`, which is a duplicate of the row at 53 | `row_index` + N) is necessary to solve the task for duplicated datasets, 54 | like protein-duplication. 55 | 56 | We wish to remove the ability to perform a successful lookup, and 57 | accomplish this by "flipping" the label of the duplicated row. 58 | - We can't simply input just a single row to our model, because 59 | we'd be changing batch statistics, which could account for 60 | changes in the prediction. 61 | - We don't want to corrupt the features as well -- the model should 62 | still be able to lookup the right row, but then should fail 63 | because of the label alteration we made. 
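Note that only the label entries of the designated lookup row (in the first, unmasked copy of the data) are overwritten; its feature columns still match the query row exactly, so the lookup itself remains possible.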
64 | 65 | We will select a new label to which we flip the label of the designated 66 | lookup row by selecting uniformly at random from other unmasked rows. 67 | These unmasked rows are specified by the label_matrix, which is aware 68 | of stochastic label masking changes. 69 | 70 | Finally, we restrict the label_matrix to assure that we are only 71 | evaluating a loss on the `row_index`. 72 | """ 73 | # Avoid overwriting things we will need in corruptions for other rows 74 | bd = duplicate_batch_dict(batch_dict) 75 | 76 | if bd['label_mask_matrix'] is not None: 77 | # Triggers for stochastic label masking. 78 | # Only not None in train mode. In which case we can use it to only 79 | # reveal those train indices that have been masked. 80 | label_matrix = 'label' 81 | else: 82 | # We are in val/test mode. In which case all val/test labels are masked 83 | # and need to be revealed at val/test time, to check that model is 84 | # actually learning interactions! 85 | label_matrix = dataset_mode 86 | 87 | # (Note that there may be stochastic masking on the train labels still. 88 | # but we do not reveal those anymore as there is no loss computed on 89 | # them.) 90 | 91 | if bd[f'{label_matrix}_mask_matrix'] is None: 92 | raise NotImplementedError 93 | 94 | num_cols = len(bd['data_arrs']) 95 | num_rows = bd['data_arrs'][0].shape[0] // 2 96 | 97 | # Keep track of target columns -- we will need to zero out the 98 | # label_matrix, and then set only the row_index in the specified 99 | # target columns so that we are only evaluating loss on our chosen 100 | # row index 101 | target_cols = [] 102 | 103 | for col in range(num_cols): 104 | # get true values wherever the label matrix has masks 105 | locations = bd[f'{label_matrix}_mask_matrix'][:, col].nonzero( 106 | as_tuple=True)[0] 107 | if locations.nelement() == 0: 108 | continue 109 | 110 | target_cols.append(col) 111 | 112 | # These locations currently give us indexes where the loss should 113 | # be evaluated. We can determine the locations of the unmasked rows 114 | # by subtracting the original number of rows. 115 | locations -= num_rows 116 | 117 | # Next, we remove the provided row_index, as we do not want to flip its 118 | # label to itself -- this would of course be unsuccessful in corrupting 119 | # the label! 120 | locations = locations.tolist() 121 | locations = list(set(locations) - {row_index}) 122 | 123 | # Randomly select one of the locations 124 | flip_index = np.random.choice(locations) 125 | 126 | # Replace the label of the `designated lookup row` with that of the 127 | # flip_index row we have just randomly selected 128 | bd[ 129 | 'masked_tensors'][col][row_index] = bd[ 130 | 'masked_tensors'][col][flip_index] 131 | 132 | # Only evaluate loss on the row_index in appropriate target columns. 133 | 134 | # Obtain loss index as originally specified row_index + number of rows 135 | loss_index = row_index + num_rows 136 | rows_to_zero = list(set(range(int(num_rows * 2))) - {loss_index}) 137 | bd[f'{label_matrix}_mask_matrix'][rows_to_zero, :] = False 138 | 139 | return bd 140 | 141 | 142 | def corrupt_standard_dataset(c, batch_dict, dataset_mode, row_index): 143 | """ 144 | The aim of this corruption is to show that using row interactions improves 145 | performance on a standard dataset, such as protein, higgs, or forest-cover. 146 | 147 | To accomplish this corruption, we independently permute each of the columns 148 | over all row indices, __excluding__ the specified row index. 
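For example, with n_rows = 8 and row_index = 3, every column is independently permuted over rows {0, 1, 2, 4, 5, 6, 7}, while row 3 keeps its original (masked) inputs and remains the only row on which the loss is evaluated.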
149 | """ 150 | # Avoid overwriting things we will need in corruptions for other rows 151 | bd = duplicate_batch_dict(batch_dict) 152 | 153 | n_cols = len(bd['data_arrs']) 154 | n_rows = bd['data_arrs'][0].shape[0] 155 | 156 | # Row indices to shuffle -- exclude the given row_index 157 | row_indices = list(set(range(n_rows)) - {row_index}) 158 | 159 | # Shuffle all rows other than our selected one, row_index 160 | # Perform an independent permutation for each column so the row info 161 | # is destroyed (otherwise, our row-equivariant model won't have an 162 | # issue with permuted rows). 163 | for col in range(n_cols): 164 | # Test -- if we ablate shuffle, do not swap around elements 165 | if not c.debug_corrupt_standard_dataset_ablate_shuffle: 166 | shuffled_row_indices = np.random.permutation(row_indices) 167 | 168 | # Shuffle masked_tensors, which our model sees at input. 169 | # Don't need to shuffle data_arrs, because the row at which 170 | # we evaluate loss will be in the same place. 171 | bd['masked_tensors'][col][row_indices] = bd[ 172 | 'masked_tensors'][col][shuffled_row_indices] 173 | 174 | # We also zero out the 175 | # {dataset_mode}, augmentation, and label mask matrices at all 176 | # rows other than row_index 177 | for matrix in [dataset_mode, 'augmentation', 'label']: 178 | mask = f'{matrix}_mask_matrix' 179 | if bd[mask] is not None: 180 | bd[mask][:, col][row_indices] = False 181 | 182 | return bd 183 | 184 | 185 | def random_row_perm(N, batch_dict, dataset_mode): 186 | row_perm = torch.randperm(N) 187 | num_cols = len(batch_dict['data_arrs']) 188 | for col in range(num_cols): 189 | bdc = batch_dict['data_arrs'][col] 190 | bdc[:] = bdc[row_perm] 191 | 192 | mt = batch_dict['masked_tensors'][col] 193 | mt[:] = mt[row_perm] 194 | 195 | batch_dict[f'{dataset_mode}_mask_matrix'] = ( 196 | batch_dict[f'{dataset_mode}_mask_matrix'][row_perm]) 197 | 198 | return batch_dict 199 | 200 | 201 | def leakage(c, batch_dict, masked_tensors, label_mask_matrix, dataset_mode): 202 | if c.data_set != 'breast-cancer': 203 | raise Exception 204 | 205 | if not (c.model_label_bert_mask_prob[dataset_mode] == 1): 206 | raise ValueError( 207 | 'Leakage check only supported for deterministic label masking.') 208 | 209 | target_col = masked_tensors[0] 210 | assert target_col[:, -1].sum() == masked_tensors[0].size(0) 211 | assert target_col[:, 0].sum() == 0 212 | assert target_col[:, 1].sum() == 0 213 | assert label_mask_matrix is None 214 | 215 | n_label_loss_entries = batch_dict[ 216 | f'{dataset_mode}_mask_matrix'].sum() 217 | 218 | print(f'{dataset_mode} mode:') 219 | print(f'Inputs over {masked_tensors[0].size(0)} rows.') 220 | print( 221 | f'Computing label loss at {n_label_loss_entries} entries.') 222 | 223 | 224 | def protein_duplicate(c, batch_dict, dataset_mode, duplication_mode): 225 | """Append unmasked copy to the dataset. 226 | Allows for perfect loss if model exploits row interactions. 227 | This is version that respects dataset mode. 228 | Only unveil labels of current dataset mode. 229 | Currently does not unveil bert masks in copy. 230 | """ 231 | verbose = True 232 | if verbose: 233 | print('Protein-duplicate mode', duplication_mode) 234 | 235 | 236 | N_in, D = batch_dict['data_arrs'][0].shape 237 | num_cols = len(batch_dict['data_arrs']) 238 | N_out = 2 * N_in 239 | bd = batch_dict 240 | 241 | if bd['label_mask_matrix'] is not None: 242 | # Triggers for stochastic label masking. 243 | # Only not None in train mode. 
In which case we can use it to only 244 | # reveal those train indices that have been masked. 245 | label_matrix = 'label' 246 | else: 247 | # We are in val/test mode. In which case all val/test labels are masked 248 | # and need to be revealed at val/test time, to check that model is 249 | # actually learning interactions! 250 | label_matrix = dataset_mode 251 | 252 | # (Note that there may be stochastic masking on the train labels still. 253 | # but we do not reveal those anymore as there is no loss computed on 254 | # them.) 255 | 256 | # do the same for each col 257 | for col in range(num_cols): 258 | 259 | # duplicate real data 260 | bd['data_arrs'][col] = torch.cat([ 261 | bd['data_arrs'][col], 262 | bd['data_arrs'][col]], 0) 263 | 264 | # create new copy of data where masks are removed for everything that 265 | # is currently part of dataset_mode mask matrix 266 | # (i.e. all the labels) 267 | 268 | # append masked data again 269 | predict_rows = bd['masked_tensors'][col] 270 | if ('no-nn' in duplication_mode) and col > 2: 271 | lookup_rows = torch.ones_like(predict_rows) 272 | lookup_rows[:, 0] = torch.normal( 273 | mean=torch.Tensor(N_in*[1.]), 274 | std=torch.Tensor(N_in*[1.])) 275 | 276 | bd['masked_tensors'][col] = torch.cat([ 277 | lookup_rows, predict_rows], 0) 278 | else: 279 | lookup_rows = bd['masked_tensors'][col] 280 | bd['masked_tensors'][col] = torch.cat([ 281 | lookup_rows, predict_rows], 0) 282 | 283 | # now unveil relevant values 284 | for matrix in [label_matrix, 'augmentation']: 285 | if bd[f'{matrix}_mask_matrix'] is None: 286 | continue 287 | 288 | # get true values wherever current train/aug matrix has masks 289 | locations = bd[f'{matrix}_mask_matrix'][:, col].nonzero( 290 | as_tuple=True)[0] 291 | # in these locations replace masked tensors with true data 292 | dtype = bd['masked_tensors'][col].dtype 293 | bd['masked_tensors'][col][locations] = ( 294 | bd['data_arrs'][col][locations].type(dtype)) 295 | 296 | if ('target-add' in duplication_mode) and (col in bd['target_cols']): 297 | bd['masked_tensors'][col][locations] += 1 298 | 299 | 300 | # now modify the mask_matrices to fit dimensions of new data 301 | # (all zeros, don't need to predict on that new data) 302 | for matrix in [dataset_mode, 'augmentation', 'label']: 303 | mask = f'{matrix}_mask_matrix' 304 | if bd[mask] is not None: 305 | bd[mask] = torch.cat([ 306 | torch.zeros_like(bd[mask]), 307 | bd[mask]], 0) 308 | 309 | return batch_dict 310 | -------------------------------------------------------------------------------- /npt/utils/encode_utils.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | import torch 5 | from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelEncoder 6 | 7 | 8 | def construct_encoded_col( 9 | non_missing_col_filter, encoded_non_missing_col_values): 10 | """ 11 | Construct encoded column. For missing values: 12 | If column is numeric, take the entry from original col. 13 | If column is categorical, take an empty numpy array of one-hot length. 
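In the implementation below, both cases are filled with np.zeros of the encoded width (length 1 for numerical columns); when BERT-style masking is enabled, encode_data additionally flags these rows via the appended mask-token column.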
14 | 15 | :param non_missing_col_filter: 16 | np.array[np.bool_], True in entries that are not missing 17 | :param encoded_non_missing_col_values: 18 | np.array, Encoded column for non-missing values 19 | :return: np.array, encoded_col 20 | """ 21 | encoded_col = [] 22 | encoded_queue = deque(encoded_non_missing_col_values) 23 | one_hot_length = encoded_non_missing_col_values.shape[1] 24 | 25 | for elem_was_encoded in non_missing_col_filter: 26 | if elem_was_encoded: 27 | encoded_col.append(encoded_queue.popleft()) 28 | else: 29 | # Replaces missing elements with a one-hot of correct length for 30 | # cat variables, and a single value for num variables 31 | encoded_col.append(np.zeros(one_hot_length)) 32 | 33 | return np.array(encoded_col) 34 | 35 | 36 | def get_compute_statistics_and_non_missing_matrix(data_dict, c): 37 | missing_matrix = data_dict['missing_matrix'] 38 | val_mask_matrix = data_dict['val_mask_matrix'] 39 | test_mask_matrix = data_dict['test_mask_matrix'] 40 | row_boundaries = data_dict['row_boundaries'] 41 | 42 | # Matrix with a 1 entry for all elements at which we 43 | # should compute a statistic / encode over 44 | compute_statistics_matrix = ( 45 | 1 - missing_matrix - val_mask_matrix - 46 | test_mask_matrix).astype(np.bool_) 47 | 48 | # If production, don't compute statistics using val/test 49 | if not c.model_is_semi_supervised: 50 | compute_statistics_matrix[row_boundaries['train']:] = False 51 | 52 | # Matrix with a 1 entry for all non-missing elements (i.e. those 53 | # we should transform) 54 | non_missing_matrix = ~missing_matrix 55 | 56 | return compute_statistics_matrix, non_missing_matrix, missing_matrix 57 | 58 | 59 | def encode_data( 60 | data_dict, compute_statistics_matrix, non_missing_matrix, 61 | missing_matrix, data_dtype, use_bert_masking, c): 62 | """ 63 | :return: 64 | Unpacked from data_dict: 65 | :param data_table: np.array, 2D unencoded data array 66 | :param N: int, number of rows 67 | :param D: int, number of columns 68 | :param cat_features: List[int], column indices with cat features 69 | :param num_features: List[int], column indices with num features 70 | :param compute_statistics_matrix: np.array[bool], True entries in 71 | locations that should be used to compute statistics / fit encoder 72 | :param non_missing_matrix: np.array[bool], True entries in 73 | locations that are not missing from data (don't attempt to encode 74 | missing entries, which could be NaN or an arbitrary missing token) 75 | :param missing_matrix: np.array[bool], inverse of the above 76 | :param np dtype to use for all data arrays 77 | :param use_bert_masking, if False, do not add an extra column to keep track 78 | of masked and missing values 79 | :return: Tuple[encoded_data, input_feature_dims] 80 | encoded_dataset: List[np.array], encoded columns 81 | input_feature_dims: List[int], Size of encoding for each feature. 82 | Needed to initialise embedding weights in NPT. 
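For a categorical column with H_j categories the encoded width is H_j (one-hot), and for a numerical column it is 1 (the standardised value); when BERT-style masking is enabled, one extra mask-token column is appended to each. (TabNet baselines instead label-encode non-target categorical columns to a single integer column.)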
83 | """ 84 | data_table = data_dict['data_table'] 85 | N = data_dict['N'] 86 | D = data_dict['D'] 87 | cat_features = data_dict['cat_features'] 88 | num_features = data_dict['num_features'] 89 | cat_target_cols = data_dict['cat_target_cols'] 90 | 91 | encoded_dataset = [] 92 | input_feature_dims = [] 93 | 94 | standardisation = np.nan * np.ones((D, 2)) 95 | tabnet_mode = (c.model_class == 'sklearn-baselines' and 96 | c.sklearn_model == 'TabNet') 97 | 98 | # Extract just the sigmas in a JSON-serializable format 99 | # we use this as metadata for numerical columns to unstandardize them 100 | sigmas = [] 101 | if tabnet_mode: 102 | cat_col_dims = [] 103 | 104 | for col_index in range(D): 105 | # The column over which we compute statistics 106 | stat_filter = compute_statistics_matrix[:, col_index] 107 | stat_col = data_table[stat_filter, col_index].reshape(-1, 1) 108 | 109 | # Non-missing entries, which we transform 110 | non_missing_filter = non_missing_matrix[:, col_index] 111 | non_missing_col = data_table[ 112 | non_missing_filter, col_index].reshape(-1, 1) 113 | 114 | # Fit on stat_col, transform non_missing_col 115 | is_cat = False 116 | if col_index in cat_features: 117 | is_cat = True 118 | if tabnet_mode and col_index not in cat_target_cols: 119 | # Use TabNet's label encoding 120 | # https://github.com/dreamquark-ai/tabnet/blob/develop/ 121 | # forest_example.ipynb 122 | l_enc = LabelEncoder() 123 | encoded_col = np.expand_dims( 124 | l_enc.fit_transform(non_missing_col), -1) 125 | num_classes = len(l_enc.classes_) 126 | cat_col_dims.append(num_classes) 127 | else: 128 | fitted_encoder = OneHotEncoder(sparse=False).fit( 129 | non_missing_col) 130 | encoded_col = fitted_encoder.transform( 131 | non_missing_col).astype(np.bool_) 132 | 133 | # Stand-in for a np.nan, but JSON-serializable 134 | sigmas.append(-1) 135 | 136 | elif col_index in num_features: 137 | fitted_encoder = StandardScaler().fit(stat_col) 138 | encoded_col = fitted_encoder.transform(non_missing_col) 139 | standardisation[col_index, 0] = fitted_encoder.mean_[0] 140 | standardisation[col_index, 1] = fitted_encoder.scale_[0] 141 | sigmas.append(fitted_encoder.scale_[0]) 142 | else: 143 | raise NotImplementedError 144 | 145 | # Construct encoded column 146 | # (we have only encoded non-missing entries! 
need to fill in missing) 147 | encoded_col = construct_encoded_col( 148 | non_missing_col_filter=non_missing_filter, 149 | encoded_non_missing_col_values=encoded_col) 150 | 151 | if use_bert_masking: 152 | # Add mask tokens to numerical and categorical data 153 | # Each col is now shape Nx(H_j+1) 154 | encoded_col = np.hstack( 155 | [encoded_col, np.zeros((N, 1))]) 156 | 157 | # Get missing indices to zero out values and set mask token 158 | # TODO: try randomly sampling for missing indices from a std normal 159 | missing_filter = missing_matrix[:, col_index] 160 | 161 | # Zero out all one-hots (or the single numerical val) for these entries 162 | encoded_col[missing_filter, :] = 0 163 | 164 | # Set their mask token to 1 165 | encoded_col[missing_filter, -1] = 1 166 | 167 | if not tabnet_mode: 168 | # If categorical column, convert to bool 169 | if is_cat: 170 | encoded_col = encoded_col.astype(np.bool_) 171 | else: 172 | encoded_col = encoded_col.astype(data_dtype) 173 | 174 | encoded_dataset.append(encoded_col) 175 | input_feature_dims.append(encoded_col.shape[1]) 176 | 177 | if tabnet_mode: 178 | return ( 179 | encoded_dataset, input_feature_dims, standardisation, 180 | sigmas, cat_col_dims) 181 | else: 182 | return encoded_dataset, input_feature_dims, standardisation, sigmas 183 | 184 | 185 | def encode_data_dict(data_dict, c): 186 | # * TODO: need to vectorize for huge datasets 187 | # * TODO: (i.e. can't fit in CPU memory) 188 | compute_statistics_matrix, non_missing_matrix, missing_matrix = ( 189 | get_compute_statistics_and_non_missing_matrix(data_dict, c)) 190 | 191 | data_dtype = get_numpy_dtype(dtype_name=c.data_dtype) 192 | 193 | return encode_data( 194 | data_dict, compute_statistics_matrix, non_missing_matrix, 195 | missing_matrix, data_dtype, c.model_bert_augmentation, c) 196 | 197 | 198 | def get_numpy_dtype(dtype_name): 199 | if dtype_name == 'float32': 200 | dtype = np.float32 201 | elif dtype_name == 'float64': 202 | dtype = np.float64 203 | else: 204 | raise NotImplementedError 205 | 206 | return dtype 207 | 208 | 209 | def get_torch_dtype(dtype_name): 210 | if dtype_name == 'float32': 211 | dtype = torch.float32 212 | elif dtype_name == 'float64': 213 | dtype = torch.float64 214 | else: 215 | raise NotImplementedError 216 | 217 | return dtype 218 | 219 | 220 | def get_torch_tensor_type(dtype_name): 221 | if dtype_name == 'float32': 222 | dtype = torch.FloatTensor 223 | elif dtype_name == 'float64': 224 | dtype = torch.DoubleTensor 225 | else: 226 | raise NotImplementedError 227 | 228 | return dtype 229 | 230 | 231 | def torch_cast_to_dtype(obj, dtype_name): 232 | if dtype_name == 'float32': 233 | obj = obj.float() 234 | elif dtype_name == 'float64': 235 | obj = obj.double() 236 | elif dtype_name == 'long': 237 | obj = obj.long() 238 | else: 239 | raise NotImplementedError 240 | 241 | return obj 242 | -------------------------------------------------------------------------------- /npt/utils/eval_checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from enum import Enum 3 | from pathlib import Path 4 | 5 | import torch 6 | import wandb 7 | from time import sleep 8 | from npt.utils.model_init_utils import ( 9 | init_model_opt_scaler, setup_ddp_model) 10 | 11 | 12 | class EarlyStopSignal(Enum): 13 | CONTINUE = 0 14 | STOP = 1 # Early stopping has triggered 15 | END = 2 # We have reached the final epoch 16 | 17 | 18 | class EarlyStopCounter: 19 | def __init__(self, c, data_cache_prefix, metadata, wandb_run, 
cv_index, 20 | n_splits, device=None): 21 | """ 22 | :param c: config 23 | :param data_cache_prefix: str; cache path for the dataset. Used for 24 | model checkpoints 25 | :param metadata: Dict, used for model initialization 26 | :param device: str; set in the distributed setting, otherwise uses 27 | config option c.exp_device. 28 | """ 29 | # The number of contiguous epochs for which validation 30 | # loss has not improved (early stopping) 31 | self.num_inc_valid_loss_epochs = 0 32 | 33 | # The number of times validation loss has improved since last 34 | # caching the model -- used for (infrequent) model checkpointing 35 | self.num_valid_improvements_since_cache = 0 36 | 37 | # The number of times validation loss must improve prior to our 38 | # caching of the model 39 | if c.exp_cache_cadence == -1: 40 | self.cache_cadence = float('inf') # We will never cache 41 | else: 42 | self.cache_cadence = c.exp_cache_cadence 43 | 44 | # Minimum validation loss that the counter has observed 45 | self.min_val_loss = float('inf') 46 | 47 | self.patience = c.exp_patience 48 | self.c = c 49 | self.wandb_run = wandb_run 50 | self.cv_index = cv_index 51 | self.n_splits = n_splits 52 | 53 | self.metadata = metadata 54 | 55 | self.stop_signal_message = ( 56 | f'Validation loss has not improved ' 57 | f'for {self.patience} contiguous epochs. ' 58 | f'Stopping evaluation now..') 59 | 60 | # Only needed for distribution 61 | self.device = device 62 | 63 | # Model caching 64 | """ 65 | checkpoint_setting: Union[str, None]; have options None, 66 | best_model, and all_checkpoints. 67 | None will never checkpoint models. 68 | best_model will only have in cache at any given time the best 69 | performing model yet evaluated. 70 | all_checkpoints will avoid overwriting, storing each best 71 | performing model. Can incur heavy memory load. 
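In all cases, checkpoints are written only from the rank-0 process (or in serial mode), only after the validation loss has improved cache_cadence times since the last cache, and never during the row-corruption debug experiments.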
72 | """ 73 | # Cache models to separate directories for each CV split 74 | # (for any dataset in which we have multiple splits) 75 | if self.n_splits > 1: 76 | data_cache_prefix += f'__cv_{self.cv_index}' 77 | 78 | self.checkpoint_setting = c.exp_checkpoint_setting 79 | self.model_cache_path = Path(data_cache_prefix) / 'model_checkpoints' 80 | self.best_model_path = None 81 | 82 | # Only interact with file system in serial mode, or with first GPU 83 | if self.device is None or self.device == 0: 84 | # Create cache path, if it doesn't exist 85 | if not os.path.exists(self.model_cache_path): 86 | os.makedirs(self.model_cache_path) 87 | 88 | if not self.c.exp_load_from_checkpoint and not self.c.viz_att_maps: 89 | # Clear cache path, just in case there was a 90 | # previous run with same config 91 | self.clear_cache_path() 92 | 93 | def update( 94 | self, val_loss, model, optimizer, scaler, epoch, end_experiment, 95 | tradeoff_annealer=None): 96 | 97 | if val_loss < self.min_val_loss: 98 | self.min_val_loss = val_loss 99 | self.num_inc_valid_loss_epochs = 0 100 | self.num_valid_improvements_since_cache += 1 101 | 102 | # Only cache: 103 | # * If not performing a row corr experiment 104 | # * If in serial mode, or distributed mode with the GPU0 process 105 | # * AND when the validation loss has improved self.cache_cadence 106 | # times since the last model caching 107 | if not self.c.debug_eval_row_interactions: 108 | if ((self.device is None or self.device == 0) and 109 | (self.num_valid_improvements_since_cache >= 110 | self.cache_cadence)): 111 | print( 112 | f'Validation loss has improved ' 113 | f'{self.num_valid_improvements_since_cache} times since ' 114 | f'last caching the model. Caching now.') 115 | self.cache_model( 116 | model=model, optimizer=optimizer, scaler=scaler, 117 | val_loss=val_loss, epoch=epoch, 118 | tradeoff_annealer=tradeoff_annealer) 119 | self.num_valid_improvements_since_cache = 0 120 | else: 121 | self.num_inc_valid_loss_epochs += 1 122 | 123 | # Disallow early stopping with patience == -1 124 | if end_experiment: 125 | del model 126 | return EarlyStopSignal.END, self.load_cached_model() 127 | elif self.patience == -1: 128 | return EarlyStopSignal.CONTINUE, None 129 | elif self.num_inc_valid_loss_epochs > self.patience: 130 | del model 131 | return EarlyStopSignal.STOP, self.load_cached_model() 132 | 133 | return EarlyStopSignal.CONTINUE, None 134 | 135 | def load_cached_model(self): 136 | print('\nLoading cached model.') 137 | 138 | # Initialize model and optimizer objects 139 | model, optimizer, scaler = init_model_opt_scaler( 140 | self.c, metadata=self.metadata, 141 | device=self.device) 142 | 143 | # Distribute model, if in distributed setting 144 | if self.c.mp_distributed: 145 | model = setup_ddp_model(model=model, c=self.c, device=self.device) 146 | 147 | # Load from checkpoint, populate state dicts 148 | checkpoint = torch.load(self.best_model_path, map_location=self.device) 149 | # Strict setting -- allows us to load saved attention maps 150 | # when we wish to visualize them 151 | model.load_state_dict(checkpoint['model_state_dict'], 152 | strict=(not self.c.viz_att_maps)) 153 | 154 | if self.c.viz_att_maps: 155 | optimizer = None 156 | scaler = None 157 | else: 158 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 159 | scaler.load_state_dict(checkpoint['scaler_state_dict']) 160 | 161 | print( 162 | f'Successfully loaded cached model from best performing epoch ' 163 | f'{checkpoint["epoch"]}.') 164 | 165 | try: 166 | num_steps = 
checkpoint['num_steps'] 167 | except KeyError: 168 | num_steps = None 169 | 170 | return model, optimizer, scaler, num_steps 171 | 172 | def clear_cache_path(self): 173 | file_list = [ 174 | f for f in os.listdir(self.model_cache_path)] 175 | for f in file_list: 176 | os.remove(self.model_cache_path / f) 177 | 178 | def cache_model( 179 | self, model, optimizer, scaler, 180 | val_loss, epoch, tradeoff_annealer=None): 181 | if self.checkpoint_setting is None: 182 | return 183 | 184 | if self.checkpoint_setting not in [ 185 | 'best_model', 'all_checkpoints']: 186 | raise NotImplementedError 187 | 188 | # Delete all existing checkpoints 189 | if self.checkpoint_setting == 'best_model': 190 | print('Storing new best performing model.') 191 | self.clear_cache_path() 192 | 193 | val_loss = val_loss.item() 194 | checkpoint_dict = { 195 | 'epoch': epoch, 196 | 'model_state_dict': model.state_dict(), 197 | 'optimizer_state_dict': optimizer.state_dict(), 198 | 'scaler_state_dict': scaler.state_dict(), 199 | 'val_loss': val_loss} 200 | 201 | if tradeoff_annealer is not None: 202 | checkpoint_dict['num_steps'] = tradeoff_annealer.num_steps 203 | 204 | # Store the new model checkpoint 205 | self.best_model_path = self.model_cache_path / f'model_{epoch}.pt' 206 | 207 | # We encountered issues with the model being reliably checkpointed. 208 | # This is a clunky way of confirming it is / giving the script 209 | # "multiple tries", but, if it ain't broke... 210 | model_is_checkpointed = False 211 | counter = 0 212 | while model_is_checkpointed is False and counter < 10000: 213 | if counter % 10 == 0: 214 | print(f'Model checkpointing attempts: {counter}.') 215 | 216 | # Attempt to save 217 | torch.save(checkpoint_dict, self.best_model_path) 218 | 219 | # If we find the file there, continue on 220 | if os.path.isfile(self.best_model_path): 221 | model_is_checkpointed = True 222 | 223 | # If the file is not yet found, sleep to avoid bothering the server 224 | if model_is_checkpointed is False: 225 | sleep(0.5) 226 | 227 | counter += 1 228 | 229 | # # Save as a wandb artifact 230 | # artifact = wandb.Artifact(self.c.model_checkpoint_key, type='model') 231 | # artifact.add_file(str(self.best_model_path)) 232 | # self.wandb_run.log_artifact(artifact) 233 | # self.wandb_run.join() 234 | 235 | print( 236 | f'Stored epoch {epoch} model checkpoint to ' 237 | f'{self.best_model_path}.') 238 | print(f'Val loss: {val_loss}.') 239 | 240 | def is_model_checkpoint(self, file_name): 241 | return ( 242 | os.path.isfile(self.model_cache_path / file_name) and 243 | file_name.startswith('model') and file_name.endswith('.pt')) 244 | 245 | @staticmethod 246 | def get_epoch_from_checkpoint_name(checkpoint_name): 247 | return int(checkpoint_name.split('.')[0].split('_')[1]) 248 | 249 | def get_most_recent_checkpoint(self): 250 | if not os.path.isdir(self.model_cache_path): 251 | print( 252 | f'No cache path yet exists ' 253 | f'{self.model_cache_path}') 254 | return None 255 | 256 | checkpoint_names = [ 257 | file_or_dir for file_or_dir in os.listdir(self.model_cache_path) 258 | if self.is_model_checkpoint(file_or_dir)] 259 | 260 | if not checkpoint_names: 261 | print( 262 | f'Did not find a checkpoint at cache path ' 263 | f'{self.model_cache_path}') 264 | return None 265 | 266 | # We assume models stored later are strictly better (i.e. 
only 267 | # stored at an improvement in validation loss) 268 | max_checkpoint_epoch = max( 269 | [self.get_epoch_from_checkpoint_name(checkpoint_name) 270 | for checkpoint_name in checkpoint_names]) 271 | self.best_model_path = ( 272 | self.model_cache_path / f'model_{max_checkpoint_epoch}.pt') 273 | 274 | # Return the newest checkpointed model 275 | return max_checkpoint_epoch, self.load_cached_model() 276 | -------------------------------------------------------------------------------- /npt/utils/image_loading_utils.py: -------------------------------------------------------------------------------- 1 | """From https://github.com/ildoonet/pytorch-randaugment/blob/48b8f509c4bbda93bbe733d98b3fd052b6e4c8ae/RandAugment/data.py#L32""" 2 | 3 | import torch 4 | import torchvision 5 | from sklearn.model_selection import StratifiedShuffleSplit 6 | from torch.utils.data import SubsetRandomSampler, Sampler 7 | from torchvision.transforms import transforms 8 | 9 | _CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 10 | 11 | 12 | def get_dataloaders(dataset, batch, dataroot, c, split=0.15, split_idx=0): 13 | if 'cifar' in dataset: 14 | if c.model_image_random_crop_and_flip: 15 | print('Using random crops and flips in data augmentation.') 16 | transform_train = transforms.Compose([ 17 | transforms.RandomCrop(32, padding=4), 18 | transforms.RandomHorizontalFlip(), 19 | transforms.ToTensor(), 20 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 21 | ]) 22 | else: 23 | print('NOT using random crops and flips in data augmentation.') 24 | transform_train = transforms.Compose([ 25 | transforms.ToTensor(), 26 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 27 | ]) 28 | 29 | transform_test = transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 32 | ]) 33 | else: 34 | raise ValueError('dataset=%s' % dataset) 35 | 36 | if dataset == 'cifar10': 37 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train) 38 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test) 39 | else: 40 | raise ValueError('invalid dataset name=%s' % dataset) 41 | 42 | train_sampler = None 43 | if split > 0.0: 44 | sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=c.np_seed) 45 | sss = sss.split(list(range(len(total_trainset))), total_trainset.targets) 46 | for _ in range(split_idx + 1): 47 | train_idx, valid_idx = next(sss) 48 | 49 | train_sampler = SubsetRandomSampler(train_idx) 50 | valid_sampler = SubsetSampler(valid_idx) 51 | else: 52 | valid_sampler = SubsetSampler([]) 53 | 54 | data_loader_nprocs = c.data_loader_nprocs 55 | trainloader = torch.utils.data.DataLoader( 56 | total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False, num_workers=data_loader_nprocs, pin_memory=True, 57 | sampler=train_sampler, drop_last=False) 58 | validloader = torch.utils.data.DataLoader( 59 | total_trainset, batch_size=batch, shuffle=False, num_workers=data_loader_nprocs, pin_memory=True, 60 | sampler=valid_sampler, drop_last=False) 61 | 62 | testloader = torch.utils.data.DataLoader( 63 | testset, batch_size=batch, shuffle=False, num_workers=data_loader_nprocs, pin_memory=True, 64 | drop_last=False 65 | ) 66 | return train_sampler, trainloader, validloader, testloader 67 | 68 | 69 | class SubsetSampler(Sampler): 70 | r"""Samples elements from a given list of indices, without replacement. 
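Unlike SubsetRandomSampler, the indices are yielded in the order given, so the validation set is traversed deterministically.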
71 | Arguments: 72 | indices (sequence): a sequence of indices 73 | """ 74 | 75 | def __init__(self, indices): 76 | self.indices = indices 77 | 78 | def __iter__(self): 79 | return (i for i in self.indices) 80 | 81 | def __len__(self): 82 | return len(self.indices) 83 | -------------------------------------------------------------------------------- /npt/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | """Some logging utils.""" 2 | 3 | import time 4 | import torch 5 | import wandb 6 | 7 | 8 | class Logger: 9 | def __init__(self, c, optimizer, gpu, tradeoff_annealer): 10 | self.c = c 11 | self.optimizer = optimizer 12 | self.gpu = gpu # If not None, only log for GPU 0 13 | self.tradeoff_annealer = tradeoff_annealer 14 | 15 | def start_counting(self): 16 | self.train_start = time.time() 17 | self.checkpoint_start = self.train_start 18 | 19 | def log(self, train_loss, val_loss, test_loss, steps, epoch): 20 | dataset_mode_to_loss_dict = { 21 | 'train': train_loss, 22 | 'val': val_loss} 23 | if test_loss is not None: 24 | dataset_mode_to_loss_dict.update({'test': test_loss}) 25 | 26 | # Construct loggable dict 27 | wandb_loss_dict = self.construct_loggable_dict( 28 | dataset_mode_to_loss_dict) 29 | 30 | if self.tradeoff_annealer is not None: 31 | wandb_loss_dict['tradeoff'] = self.tradeoff_annealer.curr_tradeoff 32 | 33 | wandb_loss_dict['step'] = steps 34 | wandb_loss_dict['epoch'] = epoch 35 | wandb_loss_dict['lr'] = self.optimizer.param_groups[0]['lr'] 36 | wandb_loss_dict['checkpoint_time'] = ( 37 | f'{time.time() - self.checkpoint_start:.3f}') 38 | self.checkpoint_start = time.time() 39 | 40 | # Log to wandb 41 | if self.gpu is None or self.gpu == 0: 42 | wandb.log(wandb_loss_dict, step=steps) 43 | 44 | # Log to stdout 45 | self.print_loss_dict(wandb_loss_dict) 46 | 47 | return wandb_loss_dict 48 | 49 | def summary_log(self, loss_dict, new_min): 50 | # No summary metrics written 51 | if self.c.mp_distributed: 52 | return 0 53 | 54 | # Do not update summary metrics if not min (min already updated) 55 | if not new_min: 56 | return 0 57 | 58 | loss_dict.update({'time': time.time() - self.train_start}) 59 | # Always need to rewrite old summary loss dict, because wandb overrides 60 | # the summary dict when calling normal log 61 | lowest_dict = {f'best_{i}': j for i, j in loss_dict.items()} 62 | 63 | wandb.run.summary.update(lowest_dict) 64 | 65 | @staticmethod 66 | def safe_torch_to_float(val): 67 | if type(val) == torch.Tensor: 68 | return val.detach().cpu().numpy().item(0) 69 | else: 70 | return val 71 | 72 | @staticmethod 73 | def construct_loggable_dict(dataset_mode_to_loss_dict): 74 | wandb_loss_dict = dict() 75 | for dataset_mode, loss_dict in dataset_mode_to_loss_dict.items(): 76 | for key, value in loss_dict.items(): 77 | key = f'{dataset_mode}_{key}' 78 | if type(value) == dict: 79 | for key2, value2 in value.items(): 80 | joint_key = f'{key}_{key2}' 81 | wandb_loss_dict[joint_key] = ( 82 | Logger.safe_torch_to_float(value2)) 83 | else: 84 | wandb_loss_dict[key] = Logger.safe_torch_to_float(value) 85 | 86 | return wandb_loss_dict 87 | 88 | @staticmethod 89 | def print_loss_dict(loss_dict): 90 | train_keys = [] 91 | val_keys = [] 92 | test_keys = [] 93 | summary_keys = [] 94 | 95 | for key in loss_dict.keys(): 96 | if 'train' in key: 97 | train_keys.append(key) 98 | elif 'val' in key: 99 | val_keys.append(key) 100 | elif 'test' in key: 101 | test_keys.append(key) 102 | else: 103 | summary_keys.append(key) 104 | 105 | line = '' 
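# Assemble one stdout block: summary entries (step, epoch, lr, ...) first, then separate train / val / test metric lines.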
106 | for key in summary_keys: 107 | line += f'{key} {loss_dict[key]} | ' 108 | line += f'\nTrain Stats\n' 109 | for key in train_keys: 110 | line += f'{key} {loss_dict[key]:.3f} | ' 111 | line += f'\nVal Stats\n' 112 | for key in val_keys: 113 | line += f'{key} {loss_dict[key]:.3f} | ' 114 | line += f'\nTest Stats\n' 115 | for key in test_keys: 116 | line += f'{key} {loss_dict[key]:.3f} | ' 117 | line += '\n' 118 | print(line) 119 | 120 | def intermediate_log(self, loss_dict, num_steps, batch_index, epoch): 121 | """Log during mini-batches.""" 122 | 123 | tb = 'train_batch' 124 | ld = loss_dict 125 | 126 | wandb_dict = dict( 127 | batch_index=batch_index, 128 | epoch=epoch) 129 | 130 | losses = dict() 131 | 132 | losses.update({ 133 | f'{tb}_total_loss': 134 | ld['total_loss']}) 135 | 136 | if tl := ld['label'].get('total_loss', False): 137 | losses.update({ 138 | f'{tb}_label_total_loss': tl}) 139 | 140 | if tl := ld['augmentation'].get('total_loss', False): 141 | losses.update({ 142 | f'{tb}_augmentation_total_loss': tl}) 143 | 144 | if val := ld['label'].get('cat_accuracy', False): 145 | losses.update({f'{tb}_label_accuracy': val}) 146 | if val := ld['label'].get('num_mse_loss', False): 147 | losses.update({f'{tb}_label_num_mse': val}) 148 | 149 | losses = {i: j.detach().cpu().item() for i, j in losses.items()} 150 | wandb_dict.update(losses) 151 | 152 | print(f'step: {num_steps}, {wandb_dict}') 153 | 154 | if self.gpu is None or self.gpu == 0: 155 | wandb.log(wandb_dict, step=num_steps) 156 | -------------------------------------------------------------------------------- /npt/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | 5 | 6 | def get_size(obj, seen=None): 7 | """Recursively finds size of objects""" 8 | size = sys.getsizeof(obj) 9 | if seen is None: 10 | seen = set() 11 | obj_id = id(obj) 12 | if obj_id in seen: 13 | return 0 14 | # Important mark as seen *before* entering recursion to gracefully handle 15 | # self-referential objects 16 | seen.add(obj_id) 17 | if isinstance(obj, dict): 18 | size += sum([get_size(v, seen) for v in obj.values()]) 19 | size += sum([get_size(k, seen) for k in obj.keys()]) 20 | elif hasattr(obj, '__dict__'): 21 | size += get_size(obj.__dict__, seen) 22 | elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): 23 | size += sum([get_size(i, seen) for i in obj]) 24 | return size 25 | -------------------------------------------------------------------------------- /npt/utils/model_init_utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from torch.cuda.amp import GradScaler 3 | from torch.nn.parallel import DistributedDataParallel as DDP 4 | 5 | from npt.model.npt import NPTModel 6 | from npt.utils.encode_utils import get_torch_dtype 7 | from npt.utils.train_utils import count_parameters, init_optimizer 8 | 9 | 10 | def init_model_opt_scaler_from_dataset(dataset, c, device=None): 11 | return init_model_opt_scaler( 12 | c, metadata=dataset.metadata, device=device) 13 | 14 | 15 | def init_model_opt_scaler(c, metadata, device=None): 16 | if device is None: 17 | device = c.exp_device 18 | 19 | model = NPTModel( 20 | c, metadata=metadata, device=device) 21 | 22 | model_torch_dtype = get_torch_dtype(dtype_name=c.model_dtype) 23 | model = model.to(device=device).type(model_torch_dtype) 24 | print(f'Model has {count_parameters(model)} parameters,' 25 | f'batch 
size {c.exp_batch_size}.') 26 | 27 | optimizer = init_optimizer( 28 | c=c, model_parameters=model.parameters(), device=device) 29 | print(f'Initialized "{c.exp_optimizer}" optimizer.') 30 | 31 | # Automatic Mixed Precision (AMP) 32 | # If c.model_amp is False, the GradScaler call becomes a no-op 33 | # so we can switch between default/mixed precision without if/else 34 | # statements. 35 | scaler = GradScaler(enabled=c.model_amp) 36 | if c.model_amp: 37 | print(f'Initialized gradient scaler for Automatic Mixed Precision.') 38 | 39 | return model, optimizer, scaler 40 | 41 | 42 | def setup_ddp_model(model, c, device): 43 | if not c.exp_azure_sweep and device == 0: 44 | wandb.watch(model, log="all", log_freq=10) 45 | 46 | # Deal with image patcher issues 47 | if c.model_image_n_patches: 48 | image_patcher = model.image_patcher.to(device=device) 49 | 50 | print(f'DDP with bucket size of {c.mp_bucket_cap_mb} MB.') 51 | 52 | # If we are not using train augmentation, we must "find unused params" 53 | # to avoid synchronizing gradients on the features 54 | find_unused_params = (c.model_augmentation_bert_mask_prob['train'] == 0) 55 | 56 | if find_unused_params: 57 | print('Finding unused params in DDP.') 58 | 59 | # Wrap model 60 | model = DDP( 61 | model, device_ids=[device], bucket_cap_mb=c.mp_bucket_cap_mb, 62 | find_unused_parameters=find_unused_params) 63 | 64 | if c.model_image_n_patches: 65 | model.image_patcher = image_patcher 66 | else: 67 | model.image_patcher = None 68 | 69 | return model 70 | -------------------------------------------------------------------------------- /npt/utils/optim_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 3 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 4 | 5 | Lamb optimizer from 6 | https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py. 7 | Paper: 8 | `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes` - 9 | https://arxiv.org/abs/1904.00962 10 | 11 | Lookahead optimizer wrapper from 12 | https://github.com/mgrankin/over9000/blob/master/lookahead.py. 
13 | Paper: 14 | `Lookahead Optimizer: k steps forward, 1 step back` - 15 | https://arxiv.org/abs/1907.08610 16 | """ 17 | 18 | import collections 19 | from collections import defaultdict 20 | 21 | import torch 22 | from torch.optim import Adam 23 | from torch.optim.optimizer import Optimizer 24 | from torch.utils.tensorboard import SummaryWriter 25 | 26 | 27 | class Lookahead(Optimizer): 28 | def __init__(self, base_optimizer, alpha=0.5, k=6): 29 | if not 0.0 <= alpha <= 1.0: 30 | raise ValueError(f'Invalid slow update rate: {alpha}') 31 | if not 1 <= k: 32 | raise ValueError(f'Invalid lookahead steps: {k}') 33 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 34 | self.base_optimizer = base_optimizer 35 | self.param_groups = self.base_optimizer.param_groups 36 | self.defaults = base_optimizer.defaults 37 | self.defaults.update(defaults) 38 | self.state = defaultdict(dict) 39 | # manually add our defaults to the param groups 40 | for name, default in defaults.items(): 41 | for group in self.param_groups: 42 | group.setdefault(name, default) 43 | 44 | def update_slow(self, group): 45 | for fast_p in group["params"]: 46 | if fast_p.grad is None: 47 | continue 48 | param_state = self.state[fast_p] 49 | if 'slow_buffer' not in param_state: 50 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 51 | param_state['slow_buffer'].copy_(fast_p.data) 52 | slow = param_state['slow_buffer'] 53 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 54 | fast_p.data.copy_(slow) 55 | 56 | def sync_lookahead(self): 57 | for group in self.param_groups: 58 | self.update_slow(group) 59 | 60 | def step(self, closure=None): 61 | loss = self.base_optimizer.step(closure) 62 | for group in self.param_groups: 63 | group['lookahead_step'] += 1 64 | if group['lookahead_step'] % group['lookahead_k'] == 0: 65 | self.update_slow(group) 66 | return loss 67 | 68 | def state_dict(self): 69 | fast_state_dict = self.base_optimizer.state_dict() 70 | slow_state = { 71 | (id(k) if isinstance(k, torch.Tensor) else k): v 72 | for k, v in self.state.items() 73 | } 74 | fast_state = fast_state_dict['state'] 75 | param_groups = fast_state_dict['param_groups'] 76 | return { 77 | 'state': fast_state, 78 | 'slow_state': slow_state, 79 | 'param_groups': param_groups, 80 | } 81 | 82 | def load_state_dict(self, state_dict): 83 | fast_state_dict = { 84 | 'state': state_dict['state'], 85 | 'param_groups': state_dict['param_groups'], 86 | } 87 | self.base_optimizer.load_state_dict(fast_state_dict) 88 | 89 | # We want to restore the slow state, but share param_groups reference 90 | # with base_optimizer. 
This is a bit redundant but least code 91 | slow_state_new = False 92 | if 'slow_state' not in state_dict: 93 | print('Loading state_dict from optimizer without Lookahead applied.') 94 | state_dict['slow_state'] = defaultdict(dict) 95 | slow_state_new = True 96 | slow_state_dict = { 97 | 'state': state_dict['slow_state'], 98 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 99 | } 100 | super(Lookahead, self).load_state_dict(slow_state_dict) 101 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 102 | if slow_state_new: 103 | # reapply defaults to catch missing lookahead specific ones 104 | for name, default in self.defaults.items(): 105 | for group in self.param_groups: 106 | group.setdefault(name, default) 107 | 108 | 109 | def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs): 110 | adam = Adam(params, *args, **kwargs) 111 | return Lookahead(adam, alpha, k) 112 | 113 | 114 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 115 | """Log a histogram of trust ratio scalars in across layers.""" 116 | results = collections.defaultdict(list) 117 | for group in optimizer.param_groups: 118 | for p in group['params']: 119 | state = optimizer.state[p] 120 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 121 | if i in state: 122 | results[i].append(state[i]) 123 | 124 | for k, v in results.items(): 125 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 126 | 127 | 128 | class Lamb(Optimizer): 129 | r"""Implements Lamb algorithm. 130 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 131 | Arguments: 132 | params (iterable): iterable of parameters to optimize or dicts defining 133 | parameter groups 134 | lr (float, optional): learning rate (default: 1e-3) 135 | betas (Tuple[float, float], optional): coefficients used for computing 136 | running averages of gradient and its square (default: (0.9, 0.999)) 137 | eps (float, optional): term added to the denominator to improve 138 | numerical stability (default: 1e-8) 139 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 140 | adam (bool, optional): always use trust ratio = 1, which turns this into 141 | Adam. Useful for comparison purposes. 142 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 143 | https://arxiv.org/abs/1904.00962 144 | """ 145 | 146 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 147 | weight_decay=0, adam=False): 148 | if not 0.0 <= lr: 149 | raise ValueError("Invalid learning rate: {}".format(lr)) 150 | if not 0.0 <= eps: 151 | raise ValueError("Invalid epsilon value: {}".format(eps)) 152 | if not 0.0 <= betas[0] < 1.0: 153 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 154 | if not 0.0 <= betas[1] < 1.0: 155 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 156 | defaults = dict(lr=lr, betas=betas, eps=eps, 157 | weight_decay=weight_decay) 158 | self.adam = adam 159 | super(Lamb, self).__init__(params, defaults) 160 | 161 | # @profile 162 | def step(self, closure=None): 163 | """Performs a single optimization step. 164 | Arguments: 165 | closure (callable, optional): A closure that reevaluates the model 166 | and returns the loss. 
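        Example (illustrative sketch, not from the original source; the
        repository builds its optimizer via init_optimizer in
        npt/utils/train_utils.py). A bare training step combining the Lamb
        and Lookahead classes defined in this file might look like:

            base = Lamb(model.parameters(), lr=1e-3, eps=1e-6)
            optimizer = Lookahead(base, alpha=0.5, k=6)
            loss = compute_loss(model, batch)  # hypothetical helper
            loss.backward()
            optimizer.step()  # Lamb update; slow weights sync every k steps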
167 | """ 168 | loss = None 169 | if closure is not None: 170 | loss = closure() 171 | 172 | for group in self.param_groups: 173 | for p in group['params']: 174 | if p.grad is None: 175 | continue 176 | grad = p.grad.data 177 | if grad.is_sparse: 178 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 179 | 180 | state = self.state[p] 181 | 182 | # State initialization 183 | if len(state) == 0: 184 | state['step'] = 0 185 | # Exponential moving average of gradient values 186 | state['exp_avg'] = torch.zeros_like(p.data) 187 | # Exponential moving average of squared gradient values 188 | state['exp_avg_sq'] = torch.zeros_like(p.data) 189 | 190 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 191 | beta1, beta2 = group['betas'] 192 | 193 | state['step'] += 1 194 | 195 | # Decay the first and second moment running average coefficient 196 | # m_t 197 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 198 | # v_t 199 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 200 | 201 | # Paper v3 does not use debiasing. 202 | # bias_correction1 = 1 - beta1 ** state['step'] 203 | # bias_correction2 = 1 - beta2 ** state['step'] 204 | # Apply bias to lr to avoid broadcast. 205 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 206 | 207 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 208 | 209 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 210 | if group['weight_decay'] != 0: 211 | adam_step.add_(p.data, alpha=group['weight_decay']) 212 | 213 | adam_norm = adam_step.pow(2).sum().sqrt() 214 | if weight_norm == 0 or adam_norm == 0: 215 | trust_ratio = 1 216 | else: 217 | trust_ratio = weight_norm / adam_norm 218 | state['weight_norm'] = weight_norm 219 | state['adam_norm'] = adam_norm 220 | state['trust_ratio'] = trust_ratio 221 | if self.adam: 222 | trust_ratio = 1 223 | 224 | p.data.add_(adam_step, alpha=-step_size * trust_ratio) 225 | 226 | return loss 227 | -------------------------------------------------------------------------------- /npt/utils/plotting.py: -------------------------------------------------------------------------------- 1 | from matplotlib import rcParams 2 | from matplotlib import rc 3 | import seaborn as sns 4 | 5 | """"Set some global params and plotting settings.""" 6 | 7 | # Font options 8 | rc('text', usetex=True) 9 | # rcParams['pdf.fonttype'] = 42 10 | # rcParams['ps.fonttype'] = 42 11 | fs = 9 12 | label_fs = fs - 1 13 | family = 'serif' 14 | rcParams['font.family'] = 'serif' 15 | rcParams['font.sans-serif'] = ['Times'] 16 | rcParams['font.size'] = fs 17 | 18 | prop = dict(size=fs) 19 | legend_kwargs = dict(frameon=True, prop=prop) 20 | new_kwargs = dict(prop=dict(size=fs-4)) 21 | 22 | # Styling 23 | c = 'black' 24 | rcParams.update({'axes.edgecolor': c, 'xtick.color': c, 'ytick.color': c}) 25 | rcParams.update({'axes.linewidth': 0.5}) 26 | linewidth = 3.25063 # in inches 27 | textwidth = 6.75133 28 | 29 | # Global Names (Sadly not always used) 30 | acquisition_step_label = 'Acquired Points' 31 | LABEL_ACQUIRED_DOUBLE = 'Acquired Points' 32 | LABEL_ACQUIRED_FULL = 'Number of Acquired Test Points' 33 | diff_to_empircal_label = r'Difference to Full Test Loss' 34 | std_diff_to_empirical_label = 'Standard Deviation of Estimator Error' 35 | sample_efficiency_label = 'Efficiency' 36 | LABEL_RANDOM = 'I.I.D. 
Acquisition' 37 | LABEL_STD = 'Median \n Squared Error' 38 | LABEL_RELATIVE_COST = 'Relative Labeling Cost' 39 | LABEL_MEAN_LOG = 'Mean Log Squared Error' 40 | # Color palette 41 | CB_color_cycle = ['#377eb8', '#ff7f00', '#4daf4a', 42 | '#f781bf', '#a65628', '#984ea3', 43 | '#999999', '#e41a1c', '#dede00'] 44 | CB_color_cycle = [CB_color_cycle[i] for i in [0, 1, 2, -2, 5, 4, 3, -3, -1]] 45 | cbpal = sns.palettes.color_palette(palette=CB_color_cycle) 46 | pal = sns.color_palette('colorblind') 47 | pal[5], pal[6], pal[-2] = cbpal[5], cbpal[6], cbpal[-1] -------------------------------------------------------------------------------- /npt/utils/preprocess_utils.py: -------------------------------------------------------------------------------- 1 | """Generic data utils invoked by dataset loaders.""" 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | from sklearn.pipeline import Pipeline 6 | from sklearn.preprocessing import OneHotEncoder, KBinsDiscretizer 7 | 8 | 9 | # We create the preprocessing pipelines for both numeric and categorical data. 10 | numeric_transformer = Pipeline(steps=[ 11 | ('k-bin-discretize', KBinsDiscretizer(n_bins=10, strategy='quantile'))]) 12 | 13 | categorical_transformer = Pipeline(steps=[ 14 | ('onehot', OneHotEncoder(handle_unknown='ignore'))]) 15 | 16 | 17 | """Preprocessing Functions: Classification and Regression Datasets.""" 18 | 19 | 20 | def get_dense_from_dok(dok_matrix): 21 | return np.array( 22 | [list(key) for key in dok_matrix.keys()]) 23 | 24 | 25 | def get_matrix_from_rows(rows, cols, N, D): 26 | """ 27 | Constructs dense matrix with True in all locations where a label is. 28 | 29 | Labels occur in the specified rows, for each col in cols. 30 | """ 31 | matrix = np.zeros((N, D), dtype=np.bool_) 32 | for col in cols: 33 | matrix[rows, col] = True 34 | 35 | return matrix 36 | 37 | 38 | def get_entries_from_rows(rows, col, D): 39 | """ 40 | Given a list of rows, return an array of [row, col] entries, where col is 41 | repeated over the elements of the list. 42 | """ 43 | if type(col) != int: 44 | raise NotImplementedError 45 | 46 | N = len(rows) 47 | entries = np.stack([rows, col * np.ones(N)], axis=-1) 48 | return entries 49 | 50 | 51 | def indices_to_matrix_entries(indices, n_cols): 52 | """Convert list of 1D indices to 2D matrix indices. 53 | 54 | 1D indices enumerate all positions in matrix, while 2D indices enumerate 55 | the rows and columns separately. 56 | Input: 57 | indices (np.array, N*n_cols): List of 1D indices. 58 | n_cols (int): Number of columns in target matrix. 59 | Returns: 60 | matrix_entries (np.array, (N, 2)): Matrix entries. Equal to a 61 | sparse representation. 62 | 63 | """ 64 | if type(indices) == list: 65 | indices = np.array(indices) 66 | rows = indices // n_cols 67 | cols = indices % n_cols 68 | matrix_entries = np.stack([rows, cols], 1) 69 | return matrix_entries 70 | 71 | 72 | def entries_to_dense(entries, N, D): 73 | """Convert array of binary masking entries to dense matrix. 74 | 75 | Input: 76 | entries (np.array, Mx2): List of sparse [row, col] positions. 77 | N: Number of rows. 78 | D: Number of cols. 79 | 80 | Returns: 81 | dense_matrix (np.array, NxD): Dense matrix with 1 for all entries, 82 | else 0.
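    Example (illustrative; not part of the original docstring):
        entries = np.array([[0, 1], [2, 0]])
        entries_to_dense(entries, N=3, D=2)
        # -> 3x2 boolean array that is True only at (0, 1) and (2, 0)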
83 | """ 84 | # check for empty 85 | if entries.size == 0: 86 | return np.zeros((N, D)) 87 | 88 | data = np.ones(entries.shape[0]) 89 | sparse_matrix = sparse.csr_matrix( 90 | (data, (entries[:, 0], entries[:, 1])), shape=(N, D), dtype=np.bool_) 91 | dense_matrix = sparse_matrix.toarray().astype(dtype=np.bool_) 92 | 93 | assert set(np.where(dense_matrix == 1)[0]) == set(entries[:, 0]) 94 | assert set(np.where(dense_matrix == 1)[1]) == set(entries[:, 1]) 95 | 96 | return dense_matrix 97 | -------------------------------------------------------------------------------- /npt/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for model/optimizer initialization and training.""" 2 | import pprint 3 | 4 | from torch import optim 5 | 6 | from npt.utils.optim_utils import Lookahead, Lamb 7 | 8 | 9 | def init_optimizer(c, model_parameters, device): 10 | if 'default' in c.exp_optimizer: 11 | optimizer = optim.Adam(params=model_parameters, lr=c.exp_lr) 12 | elif 'lamb' in c.exp_optimizer: 13 | lamb = Lamb 14 | optimizer = lamb( 15 | model_parameters, lr=c.exp_lr, betas=(0.9, 0.999), 16 | weight_decay=c.exp_weight_decay, eps=1e-6) 17 | else: 18 | raise NotImplementedError 19 | 20 | if c.exp_optimizer.startswith('lookahead_'): 21 | optimizer = Lookahead(optimizer, k=c.exp_lookahead_update_cadence) 22 | 23 | return optimizer 24 | 25 | 26 | def get_sorted_params(model): 27 | param_count_and_name = [] 28 | for n,p in model.named_parameters(): 29 | if p.requires_grad: 30 | param_count_and_name.append((p.numel(), n)) 31 | 32 | pprint.pprint(sorted(param_count_and_name, reverse=True)) 33 | 34 | 35 | def count_parameters(model): 36 | r""" 37 | Due to Federico Baldassarre 38 | https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7 39 | """ 40 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 41 | -------------------------------------------------------------------------------- /npt/utils/viz_att_maps.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import seaborn as sns 6 | from matplotlib.patches import Rectangle 7 | 8 | from npt.utils.eval_checkpoint_utils import EarlyStopCounter 9 | 10 | 11 | def plot_grid_query_pix(width, ax=None): 12 | if ax is None: 13 | plt.figure() 14 | ax = plt.gca() 15 | 16 | ax.set_xticks(np.arange(-width / 2, width / 2)) # , minor=True) 17 | ax.set_aspect(1) 18 | ax.set_yticks(np.arange(-width / 2, width / 2)) # , minor=True) 19 | ax.tick_params( 20 | axis="both", 21 | which="both", 22 | bottom=False, 23 | top=False, 24 | left=False, 25 | labelbottom=False, 26 | labelleft=False, 27 | ) 28 | ax.grid(True, alpha=0.5) 29 | 30 | # query pixel 31 | querry_pix = Rectangle(xy=(-0.5,-0.5), 32 | width=1, 33 | height=1, 34 | edgecolor="black", 35 | fc='None', 36 | lw=2) 37 | 38 | ax.add_patch(querry_pix); 39 | 40 | ax.set_xlim(-width / 2, width / 2) 41 | ax.set_ylim(-width / 2, width / 2) 42 | ax.set_aspect("equal") 43 | 44 | 45 | def plot_attention_layer(attention_probs, axes): 46 | """Plot the 2D attention probabilities for a particular MAB attention map.""" 47 | 48 | contours = np.array([0.9, 0.5]) 49 | linestyles = [":", "-"] 50 | flat_colors = ["#3498db", "#f1c40f", "#2ecc71", "#e74c3c", "#e67e22", "#9b59b6", "#34495e", "#1abc9c", "#95a5a6"] 51 | 52 | shape = attention_probs.shape 53 | num_heads, height, width = shape 54 | # attention_probs = 
attention_probs.reshape(width, height, num_heads) 55 | 56 | try: 57 | ax = axes[0] 58 | except: 59 | attention_prob_head = attention_probs[0].detach().cpu().numpy() 60 | sns.heatmap(attention_prob_head, ax=axes, square=True) 61 | axes.set_title(f'Head 1') 62 | return axes 63 | 64 | for head_index in range(num_heads): 65 | attention_prob_head = attention_probs[head_index].detach().cpu().numpy() 66 | sns.heatmap(attention_prob_head, ax=axes[head_index], square=True) 67 | axes[head_index].set_title(f'Head {head_index}') 68 | 69 | return axes 70 | 71 | 72 | def viz_att_maps(c, dataset): 73 | early_stop_counter = EarlyStopCounter( 74 | c=c, data_cache_prefix=dataset.model_cache_path, 75 | metadata=dataset.metadata, 76 | device=c.exp_device) 77 | 78 | # Initialize from checkpoint, if available 79 | num_steps = 0 80 | 81 | checkpoint = early_stop_counter.get_most_recent_checkpoint() 82 | if checkpoint is not None: 83 | checkpoint_epoch, ( 84 | model, optimizer, num_steps) = checkpoint 85 | else: 86 | raise Exception('Could not find a checkpoint!') 87 | 88 | dataset.set_mode(mode='test', epoch=num_steps) 89 | batch_dataset = dataset.cv_dataset 90 | batch_dict = next(batch_dataset) 91 | 92 | from npt.utils import debug 93 | 94 | if c.debug_row_interactions: 95 | print('Detected debug mode.' 96 | 'Modifying batch input to duplicate rows.') 97 | batch_dict = debug.modify_data(c, batch_dict, 'test', 0) 98 | 99 | # Run a forward pass 100 | masked_tensors = batch_dict['masked_tensors'] 101 | masked_tensors = [ 102 | masked_arr.to(device=c.exp_device) 103 | for masked_arr in masked_tensors] 104 | model.eval() 105 | model(masked_tensors) 106 | 107 | # Grab attention maps from SaveAttMaps modules 108 | # Collect metadata as we go 109 | layers = [] 110 | att_maps = [] 111 | 112 | for name, param in model.named_parameters(): 113 | if 'curr_att_maps' not in name: 114 | continue 115 | 116 | _, layer, _, _, _ = name.split('.') 117 | layers.append(int(layer)) 118 | att_maps.append(param) 119 | 120 | n_heads = c.model_num_heads 121 | 122 | from tensorboardX import SummaryWriter 123 | 124 | # create tensorboard writer 125 | # adapted from https://github.com/epfml/attention-cnn 126 | 127 | if not c.model_checkpoint_key: 128 | raise NotImplementedError 129 | 130 | save_path = os.path.join(c.viz_att_maps_save_path, c.model_checkpoint_key) 131 | tensorboard_writer = SummaryWriter( 132 | logdir=save_path, max_queue=100, flush_secs=10) 133 | print(f"Tensorboard logs saved in '{save_path}'") 134 | 135 | for i in range(len(att_maps)): 136 | layer_index = layers[i] 137 | att_map = att_maps[i] 138 | 139 | # If n_heads != att_map.size(0), we have attention over the 140 | # columns, which is applied to every 141 | # one of the batch dimension axes independently 142 | # e.g. 
we will have an attention map of shape (n_heads * N, D, D) 143 | # Just subsample a row for each head 144 | att_map_first_dim_size = att_map.size(0) 145 | if n_heads != att_map_first_dim_size: 146 | print('Subsampling attention over the columns.') 147 | print(f'Original size: {att_map.size()}') 148 | n_rows = att_map_first_dim_size // n_heads 149 | row_subsample_indices = [] 150 | for row_index in range(0, att_map_first_dim_size, n_rows): 151 | row_subsample_indices.append(row_index) 152 | 153 | att_map = att_map[row_subsample_indices, :, :] 154 | print(f'Final size: {att_map.size()}') 155 | 156 | fig, axes = plt.subplots(ncols=n_heads, figsize=(15 * n_heads, 15)) 157 | 158 | plot_attention_layer( 159 | att_map, axes=axes) 160 | if tensorboard_writer: 161 | tensorboard_writer.add_figure( 162 | f"attention/layer{layer_index}", fig, global_step=1) 163 | plt.close(fig) 164 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """Load model, data and corresponding configs. Trigger training.""" 2 | import os 3 | import pathlib 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | import torch.multiprocessing as mp 9 | import wandb 10 | 11 | from baselines.sklearn_tune import run_sklearn_hypertuning 12 | from npt.column_encoding_dataset import ColumnEncodingDataset 13 | from npt.configs import build_parser 14 | from npt.distribution import distributed_train_wrapper 15 | from npt.train import Trainer 16 | from npt.utils.model_init_utils import init_model_opt_scaler_from_dataset 17 | from npt.utils.viz_att_maps import viz_att_maps 18 | 19 | 20 | def main(args): 21 | """Load model, data, configs, start training.""" 22 | args, wandb_args = setup_args(args) 23 | run_cv(args=args, wandb_args=wandb_args) 24 | 25 | 26 | def setup_args(args): 27 | print('Configuring arguments...') 28 | 29 | if args.exp_azure_sweep: 30 | print('Removing old logs.') 31 | os.system('rm -r wandb') 32 | 33 | if args.np_seed == -1: 34 | args.np_seed = np.random.randint(0, 1000) 35 | if args.torch_seed == -1: 36 | args.torch_seed = np.random.randint(0, 1000) 37 | if args.exp_name is None: 38 | args.exp_name = f'{wandb.util.generate_id()}' 39 | if (args.exp_group is None) and (args.exp_n_runs > 1): 40 | # Assuming you want to do CV, group runs together. 41 | args.exp_group = f'{wandb.util.generate_id()}' 42 | print(f"Doing k-FOLD CV. Assigning group name {args.exp_group}.") 43 | 44 | if args.exp_azure_sweep: 45 | print("Azure sweep run!") 46 | # Our configs may run oom. That's okay. 47 | os.environ['WANDB_AGENT_DISABLE_FLAPPING'] = 'true' 48 | 49 | if not isinstance(args.model_augmentation_bert_mask_prob, dict): 50 | print('Reading dict for model_augmentation_bert_mask_prob.') 51 | # Well, this is ugly. But I blame it on argparse. 52 | # There is just no good way to parse dicts as arguments. 53 | # Good thing, I don't care about code security. 
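        # Illustrative aside (not part of the original code): ast.literal_eval
        # would parse these values without executing arbitrary strings, but it
        # only accepts literals such as "{'train': 0.15, 'val': 0, 'test': 0}",
        # whereas the scripts in this repo pass "dict(train=..., ...)" strings.
        # A hypothetical safer variant, assuming `import ast` at the top:
        #     args.model_augmentation_bert_mask_prob = ast.literal_eval(
        #         args.model_augmentation_bert_mask_prob)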
54 | exec( 55 | f'args.model_augmentation_bert_mask_prob = ' 56 | f'{args.model_augmentation_bert_mask_prob}') 57 | 58 | if not isinstance(args.model_label_bert_mask_prob, dict): 59 | print('Reading dict for model_augmentation_bert_mask_prob.') 60 | exec( 61 | f'args.model_label_bert_mask_prob = ' 62 | f'{args.model_label_bert_mask_prob}') 63 | 64 | if not args.model_bert_augmentation: 65 | for value in args.model_augmentation_bert_mask_prob.values(): 66 | assert value == 0 67 | for value in args.model_label_bert_mask_prob.values(): 68 | assert value == 1 69 | 70 | if (args.model_class == 'sklearn-baselines' and 71 | args.sklearn_model == 'TabNet' and not args.data_force_reload): 72 | raise ValueError('For TabNet, user must specify data_force_reload ' 73 | 'to encode data in a TabNet-compatible manner.') 74 | 75 | pathlib.Path(args.wandb_dir).mkdir(parents=True, exist_ok=True) 76 | 77 | # Set seeds 78 | np.random.seed(args.np_seed) 79 | 80 | # Resolve CUDA device(s) 81 | if args.exp_use_cuda and torch.cuda.is_available(): 82 | if args.exp_device is not None: 83 | print(f'Running model with CUDA on device {args.exp_device}.') 84 | exp_device = args.exp_device 85 | else: 86 | print(f'Running model with CUDA') 87 | exp_device = 'cuda:0' 88 | else: 89 | print('Running model on CPU.') 90 | exp_device = 'cpu' 91 | 92 | args.exp_device = exp_device 93 | 94 | wandb_args = dict( 95 | project=args.project, 96 | entity=args.entity, 97 | dir=args.wandb_dir, 98 | reinit=True, 99 | name=args.exp_name, 100 | group=args.exp_group) 101 | 102 | return args, wandb_args 103 | 104 | 105 | def run_cv(args, wandb_args): 106 | 107 | if args.mp_distributed: 108 | wandb_run = None 109 | c = args 110 | else: 111 | wandb_run = wandb.init(**wandb_args) 112 | args.cv_index = 0 113 | wandb.config.update(args, allow_val_change=True) 114 | c = wandb.config 115 | 116 | if c.model_class == 'NPT': 117 | run_cv_splits(wandb_args, args, c, wandb_run) 118 | elif c.model_class == 'sklearn-baselines': 119 | run_sklearn_hypertuning( 120 | ColumnEncodingDataset(c), wandb_args, args, c, wandb_run) 121 | 122 | 123 | def run_cv_splits(wandb_args, args, c, wandb_run): 124 | 125 | dataset = ColumnEncodingDataset(c) 126 | 127 | ####################################################################### 128 | # Distributed Setting 129 | if c.mp_distributed: 130 | torch.manual_seed(c.torch_seed) 131 | 132 | # Fix from 133 | # https://github.com/facebookresearch/maskrcnn-benchmark/issues/103 134 | # torch.multiprocessing.set_sharing_strategy('file_system') 135 | 136 | dataset.load_next_cv_split() 137 | dataset.dataset_gen = None 138 | args = {'dataset': dataset, 'c': c, 'wandb_args': wandb_args} 139 | os.environ['MASTER_ADDR'] = 'localhost' 140 | os.environ['MASTER_PORT'] = '8888' 141 | mp.spawn( 142 | distributed_train_wrapper, nprocs=c.mp_gpus, args=(args,), 143 | join=True) 144 | mp.set_start_method('fork') 145 | return 146 | 147 | starting_cv_index = 0 148 | total_n_cv_splits = min(dataset.n_cv_splits, c.exp_n_runs) 149 | 150 | # Since we're doing CV by default, model init is in a loop. 151 | for cv_index in range(starting_cv_index, total_n_cv_splits): 152 | print(f'CV Index: {cv_index}') 153 | 154 | print(f'Train-test Split {cv_index + 1}/{dataset.n_cv_splits}') 155 | 156 | if c.exp_n_runs < dataset.n_cv_splits: 157 | print( 158 | f'c.exp_n_runs = {c.exp_n_runs}. 
' 159 | f'Stopping at {c.exp_n_runs} splits.') 160 | 161 | # New wandb logger for each run 162 | if cv_index > 0: 163 | wandb_args['name'] = f'{wandb.util.generate_id()}' 164 | args.exp_name = wandb_args['name'] 165 | args.cv_index = cv_index 166 | wandb_run = wandb.init(**wandb_args) 167 | wandb.config.update(args, allow_val_change=True) 168 | 169 | ####################################################################### 170 | # Load New CV Split 171 | dataset.load_next_cv_split() 172 | 173 | if c.viz_att_maps: 174 | print('Attempting to visualize attention maps.') 175 | return viz_att_maps(c, dataset) 176 | 177 | if c.model_class == 'DKL': 178 | print(f'Running DKL on dataset {c.data_set}.') 179 | from baselines.models.dkl_run import main 180 | return main(c, dataset) 181 | 182 | ####################################################################### 183 | # Initialise Model 184 | model, optimizer, scaler = init_model_opt_scaler_from_dataset( 185 | dataset=dataset, c=c, device=c.exp_device) 186 | 187 | # if not c.exp_azure_sweep: 188 | # wandb.watch(model, log="all", log_freq=10) 189 | 190 | ####################################################################### 191 | # Run training 192 | trainer = Trainer( 193 | model=model, optimizer=optimizer, scaler=scaler, 194 | c=c, wandb_run=wandb_run, cv_index=cv_index, dataset=dataset) 195 | trainer.train_and_eval() 196 | 197 | wandb_run.finish() 198 | 199 | 200 | if __name__ == '__main__': 201 | parser = build_parser() 202 | args = parser.parse_args() 203 | 204 | main(args) 205 | -------------------------------------------------------------------------------- /scripts/image_data.sh: -------------------------------------------------------------------------------- 1 | # MNIST 2 | python run.py --data_set mnist --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.3 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc False --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 500000 --exp_gradient_clipping 1 --exp_batch_size 512 --model_dim_hidden 16 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_checkpoint_key mnist__dim_hidden_16__feature_mask__label_mask__linear --model_image_n_patches 49 --model_image_patch_type linear --model_image_n_channels 1 --model_label_bert_mask_prob "dict(train=0.15, val=0.0, test=0.0)" 3 | # CIFAR-10 4 | python run.py --data_set cifar10 --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.7 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc False --exp_print_every_nth_forward 100 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 1000000 --exp_gradient_clipping 1 --exp_batch_size 512 --model_dim_hidden 64 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_checkpoint_key cifar10__bs_512__feature_mask__linear__flip_crop --data_force_reload True --model_image_n_patches 64 --model_image_patch_type linear --model_image_n_channels 3 --model_image_n_classes 10 --model_image_random_crop_and_flip True 5 | -------------------------------------------------------------------------------- /scripts/row_corruption_tests.sh: -------------------------------------------------------------------------------- 1 | ### UCI Class Reg 2 | 3 | ## Small 
Data 4 | # Breast Cancer 5 | python run.py --data_set breast-cancer --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.5 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc True --model_num_heads 8 --model_stacking_depth 8 --exp_lr 5e-4 --exp_scheduler flat_and_anneal --exp_num_total_steps 20000 --exp_gradient_clipping 1 --exp_batch_size -1 --model_dim_hidden 32 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_label_bert_mask_prob "dict(train=0.1, val=0.0, test=0.0)" --model_checkpoint_key breast_cancer__dim_hidden_32__bs_-1__feature_mask__label_mask_10 --exp_test_perc 0.1 --exp_val_perc 0.2 --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 6 | # Boston 7 | python run.py --data_set boston-housing --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.5 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc False --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 4000 --exp_gradient_clipping 1 --exp_batch_size -1 --model_dim_hidden 128 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_label_bert_mask_prob "dict(train=0.5, val=0.0, test=0.0)" --model_checkpoint_key boston__bs_-1__feature_mask__label_mask_50 --exp_test_perc 0.1 --exp_val_perc 0.2 --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 8 | # Concrete 9 | python run.py --data_set concrete --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.5 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc False --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 10000 --exp_gradient_clipping 1 --exp_batch_size -1 --model_dim_hidden 128 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_label_bert_mask_prob "dict(train=0.5, val=0.0, test=0.0)" --model_checkpoint_key concrete__bs_-1__feature_mask__label_mask_50 --exp_test_perc 0.1 --exp_val_perc 0.2 --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 10 | # Yacht 11 | python run.py --data_set yacht --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.5 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc False --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 20000 --exp_gradient_clipping 1 --exp_batch_size -1 --model_dim_hidden 128 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_label_bert_mask_prob "dict(train=0.1, val=0.0, test=0.0)" --model_checkpoint_key yacht__bs_-1__feature_mask__label_mask_10 --exp_test_perc 0.1 --exp_val_perc 0.2 --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 12 | 13 | ## Medium and Large Data 14 | 15 | # Protein 16 | python run.py --data_set protein --data_loader_nprocs=4 --exp_eval_every_n 5 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.7 --exp_optimizer lookahead_lamb 
--exp_azure_sweep True --metrics_auroc False --exp_print_every_nth_forward 100 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 400000 --exp_gradient_clipping 1 --exp_batch_size 2048 --model_dim_hidden 128 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_label_bert_mask_prob "dict(train=0.15, val=0.0, test=0.0)" --model_checkpoint_key protein__bs_2048__feature_mask__label_mask --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 17 | # Kick 18 | python run.py --data_set kick --data_loader_nprocs=4 --exp_eval_every_n 5 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.7 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc True --exp_print_every_nth_forward 100 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 250000 --exp_gradient_clipping 1 --exp_batch_size 4096 --model_dim_hidden 64 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_label_bert_mask_prob "dict(train=0.5, val=0.0, test=0.0)" --model_checkpoint_key kick__bs_4096__feature_mask__label_mask --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 19 | # Poker Hand 20 | python run.py --data_set poker-hand --exp_eval_every_n 5 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.7 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc False --exp_print_every_nth_forward 100 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 200000 --exp_gradient_clipping 1 --exp_batch_size 4096 --model_dim_hidden 64 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_label_bert_mask_prob "dict(train=0.5, val=0.0, test=0.0)" --model_checkpoint_key poker_hand__bs_4096__feature_mask__label_mask_0.5__class_balance --exp_batch_class_balancing True --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 21 | # Income 22 | python run.py --data_set income --data_loader_nprocs=2 --exp_eval_every_n 5 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.7 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc True --exp_print_every_nth_forward 100 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 2000000 --exp_gradient_clipping 1 --exp_batch_size 2048 --model_dim_hidden 64 --model_augmentation_bert_mask_prob "dict(train=0.0, val=0.0, test=0.0)" --exp_tradeoff 0 --exp_tradeoff_annealing constant --model_label_bert_mask_prob "dict(train=0.15, val=0.0, test=0.0)" --model_checkpoint_key income__bs_2048__label_mask --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 23 | # Forest Cover 24 | python run.py --data_set forest-cover --exp_eval_every_n 5 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.01 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc True --exp_print_every_nth_forward 1000 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 5e-3 --exp_scheduler polynomial_decay_warmup --exp_num_total_steps 
800000 --exp_gradient_clipping 1 --exp_batch_size 1800 --model_dim_hidden 64 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_label_bert_mask_prob "dict(train=0.5, val=0.0, test=0.0)" --model_checkpoint_key forest_cover__bs_2048__feature_mask__label_mask_0.5__poly_decay__class_balance --exp_batch_class_balancing True --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 25 | # Higgs 26 | python run.py --data_set higgs --data_loader_nprocs=4 --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.7 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc True --exp_print_every_nth_forward 100 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 500000 --exp_gradient_clipping 1 --exp_batch_size 4096 --model_dim_hidden 64 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_checkpoint_key higgs__bs_4096__feature_mask --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 27 | 28 | ## Image Data 29 | # MNIST 30 | python run.py --data_set mnist --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.3 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc False --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 500000 --exp_gradient_clipping 1 --exp_batch_size 512 --model_dim_hidden 16 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_checkpoint_key mnist__dim_hidden_16__feature_mask__label_mask__linear --model_image_n_patches 49 --model_image_patch_type linear --model_image_n_channels 1 --model_label_bert_mask_prob "dict(train=0.15, val=0.0, test=0.0)" --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 31 | # CIFAR-10 32 | python run.py --data_set cifar10 --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.7 --exp_optimizer lookahead_lamb --exp_azure_sweep True --metrics_auroc False --exp_print_every_nth_forward 100 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 1000000 --exp_gradient_clipping 1 --exp_batch_size 512 --model_dim_hidden 64 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_checkpoint_key cifar10__bs_512__feature_mask__linear__flip_crop --data_force_reload True --model_image_n_patches 64 --model_image_patch_type linear --model_image_n_channels 3 --model_image_n_classes 10 --model_image_random_crop_and_flip True --debug_eval_row_interactions True --exp_load_from_checkpoint True --debug_eval_row_interactions_timer 1000000 33 | --------------------------------------------------------------------------------