├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── benchmarks ├── README.md ├── covertype.py └── covertype_opt_inf.py ├── data └── test │ └── covtype_sample.csv ├── examples └── train_mnist.py ├── images ├── tabnet.png └── virtual_bs_vs_gbn.png ├── local └── tuner_orig_dataset.py ├── requirements.txt ├── scripts └── install.sh ├── setup.py ├── tabnet ├── __init__.py ├── callbacks │ ├── __init__.py │ ├── lrfinder.py │ └── tensorboard.py ├── datasets │ ├── __init__.py │ └── covertype.py ├── models │ ├── __init__.py │ ├── classify.py │ ├── gbn.py │ ├── model.py │ ├── transformers.py │ └── utils.py ├── schedules │ ├── __init__.py │ └── decay_with_warmup.py └── utils.py └── tests ├── test_classify.py ├── test_custom_bn.py ├── test_dataset.py ├── test_lr_finder.py └── test_tabnet.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.8 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.8 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install black pytest wheel 27 | pip install . 28 | - name: Lint with black 29 | run: | 30 | black . 31 | - name: Test with pytest 32 | run: | 33 | pytest . 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vscode stuff 132 | settings.json 133 | launch.json 134 | 135 | # don't commit notebooks 136 | *.ipynb 137 | 138 | data/ 139 | 140 | .DS_Store 141 | 142 | .logs 143 | .search 144 | .outs 145 | .vscode 146 | .tmp -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 tensorflow-tabnet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | lint: 2 | black . 3 | 4 | test: 5 | pytest -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow TabNet 2 | 3 | A TensorFlow 2.X implementation of the paper [TabNet](https://arxiv.org/abs/1908.07442). 4 | 5 | ![TabNet](images/tabnet.png) 6 | 7 | TabNet is already available for TF2 [here](https://github.com/titu1994/tf-TabNet). However, an error in the batch normalizations of the shared feature transformer blocks makes it unable to learn. In the original code, only the fully connected layers weights are shared, not the batch normalizations which remain unique. When sharing the batch normalizations, training of TabNet on the datasets & hyperparameters of the original paper yields extremely bad results (if able to learn at all). 
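Below is a minimal sketch of what correct sharing looks like; it is not the repo's actual code from `tabnet/models/transformers.py`, and the class and argument names are illustrative only. The fully connected layer of a feature transformer block can be reused across decision steps, while every block always creates its own normalization layer:

```python
import tensorflow as tf


class FeatureBlockSketch(tf.keras.layers.Layer):
    """Illustrative only: shared fully connected weights, per-block normalization."""

    def __init__(self, units, shared_fc=None, bn_momentum=0.7):
        super().__init__()
        # reuse the Dense layer (and therefore its weights) when one is passed in
        self.fc = shared_fc if shared_fc is not None else tf.keras.layers.Dense(units, use_bias=False)
        # the normalization layer is created for every block and is never shared
        self.bn = tf.keras.layers.BatchNormalization(momentum=bn_momentum)

    def call(self, x, training=None):
        # activation (GLU in the paper) left out to keep the sketch short
        return self.bn(self.fc(x), training=training)


# two blocks from two different decision steps: same Dense weights, separate BN statistics
shared = tf.keras.layers.Dense(128, use_bias=False)
block_a = FeatureBlockSketch(128, shared_fc=shared)
block_b = FeatureBlockSketch(128, shared_fc=shared)
```

In this repository the feature transformer blocks additionally use the Ghost Batch Normalization discussed below instead of a plain `BatchNormalization`.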
8 | 
9 | Moreover, the implementation uses `virtual_batch_size` from `tf.keras.layers.BatchNormalization`, which has a couple of major issues:
10 | - Fewer updates of the batch normalization statistics compared to a vanilla Ghost Batch Normalization, which makes training longer and less stable.
11 | - Does not allow a batch size that is not divisible by the virtual batch size used during training (even for inference, which should be independent of the batch size).
12 | 
13 | Below is a plot of the training accuracy for a model trained with a true Ghost Batch Normalization and one trained with the incorrect `virtual_batch_size` argument from Keras:
14 | 
15 | ![Training accuracy with Ghost Batch Normalization vs. the Keras virtual_batch_size argument](images/virtual_bs_vs_gbn.png)
16 | 
17 | Probably because its author was not aware of the issues described above, the implementation proposes to use [Group Normalization](https://arxiv.org/abs/1803.08494) instead of [Ghost Batch Normalization](https://arxiv.org/abs/1705.08741), which does make things better and allows the model to learn. However, no comparison of the results obtained is provided. Are they even close to the original ones? Can the model really generalize and obtain state-of-the-art results?
18 | 
19 | Therefore, a new, correct and tested TF2 implementation of TabNet is proposed. It not only ports the original code but also takes advantage of TF2's modular approach to make it easier to fine-tune and train a TabNet model on different tasks.
20 | 
21 | The project is still in development but can already be used.
22 | 
23 | ## Setup
24 | 
25 | ### Install
26 | Since this project is still under development, it is best to install it directly from the repo as follows:
27 | 
28 | ```bash
29 | pip install git+https://github.com/ostamand/tensorflow-tabnet.git
30 | ```
31 | 
32 | You can then import and train a classifier using:
33 | 
34 | ```python
35 | from tabnet.models.classify import TabNetClassifier
36 | ```
37 | 
38 | ### For development
39 | 
40 | ```bash
41 | python -m venv venv
42 | source venv/bin/activate
43 | chmod +x scripts/install.sh; ./scripts/install.sh
44 | ```
45 | 
46 | ## Dataset
47 | - http://archive.ics.uci.edu/ml//datasets/Covertype
48 | - https://www.kaggle.com/uciml/forest-cover-type-dataset/data#
49 | 
50 | ## Reference
51 | - [TabNet paper](https://arxiv.org/abs/1908.07442)
52 | - [original TF1 code base](https://github.com/google-research/google-research/tree/master/tabnet)
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarks
2 | 
3 | To confirm that the implementation of TabNet is correct, a series of benchmarks are compared against the results of the original paper. Most of the time the exact same hyperparameters are used (it is indicated when that is not the case). Some differences are expected since the dataset splits are different.
4 | 
5 | ## Covertype
6 | 
7 | Classification of forest cover type from cartographic variables.
8 | 
9 | To run the training:
10 | 
11 | ```bash
12 | python3 benchmarks/covertype.py
13 | ```
14 | 
15 | The following hyperparameters are used by default (they can easily be changed through arguments to the `covertype.py` script).
16 | 
17 | | Hyperparameter | Value |
18 | | -------------- | ----- |
19 | | Feature dim. | 64 |
20 | | Output dim. | 64 |
21 | | Sparsity Coeff.| 0.0001|
22 | | Batch Size | 16384 |
23 | | Virtual batch size | 512 |
24 | | Batch Norm. Momentum | 0.7 |
25 | | Number of steps | 5 |
26 | | Relaxation factor | 1.5 |
27 | | Minimum learning rate | 1e-6 |
28 | | Decay steps | 500 |
29 | | Total steps | 21k only (was 130k) |
30 | | Clip norm. | 2.0 |
31 | | Dropout rate | 0.2 |
32 | 
33 | 
34 | Here is a summary of the changes with respect to the paper implementation:
35 | 
36 | - Add warmup, which helps generalization when a large batch size is used [(reference)](https://arxiv.org/abs/1906.03548)
37 | - Add dropout on the classifier head
38 | - Add an option to use inference example weighting by providing an `alpha > 0` during inference [(reference)](https://arxiv.org/abs/1906.03548)
39 | 
40 | To optimize the inference weighting `alpha` on the validation dataset, use:
41 | 
42 | ```bash
43 | python3 benchmarks/covertype_opt_inf.py --model_dir .outs/w200
44 | ```
45 | 
46 | where `--model_dir` is the directory where the model was saved after training.
47 | 
48 | Results obtained on the test set are summarized in the table below (trained for only 21k steps vs. 130k for the reference):
49 | 
50 | | Model | Test Accuracy |
51 | | -- | -- |
52 | | Reference | 96.99 |
53 | | `TabNetClassifier` | 96.59 |
54 | 
55 | 
56 | 
57 | 
--------------------------------------------------------------------------------
/benchmarks/covertype.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Text
4 | from datetime import datetime
5 | import json
6 | import shutil
7 | 
8 | import tensorflow as tf
9 | import numpy as np
10 | 
11 | from tabnet.models.classify import TabNetClassifier
12 | from tabnet.datasets.covertype import get_dataset, get_data
13 | from tabnet.callbacks import TensorBoardWithLR, LRFinder
14 | from tabnet.schedules import DecayWithWarmupSchedule
15 | from tabnet.utils import set_seed
16 | 
17 | 
18 | TMPDIR = ".tmp"
19 | LOGDIR = ".logs"
20 | OUTDIR = ".outs"
21 | DATA_PATH = "data/covtype.csv"
22 | CONFIGS = {
23 |     "feature_dim": 64,
24 |     "output_dim": 64,
25 |     "num_features": 54,
26 |     "sparsity_coefficient": 0.0001,
27 |     "batch_size": 16384,
28 |     "bn_virtual_bs": 512,
29 |     "bn_momentum": 0.7,
30 |     "n_steps": 5,
31 |     "relaxation_factor": 1.5,
32 |     "n_classes": 7,
33 |     "learning_rate": 0.02,
34 |     "min_learning_rate": 1e-6,
35 |     "decay_steps": 500,
36 |     "decay_rate": 0.95,
37 |     "total_steps": 130000,
38 |     "clipnorm": 2.0,
39 |     "dp": 0.2,
40 |     "seed": 42,
41 | }
42 | 
43 | 
44 | def clean_tmp_dir():
45 |     if os.path.exists(TMPDIR):
46 |         shutil.rmtree(TMPDIR)
47 |     os.makedirs(TMPDIR)
48 | 
49 | 
50 | def run_lrfinder(
51 |     ds: tf.data.Dataset, model: tf.keras.Model, optimizer, lossf, steps_per_epoch: int,
52 | ) -> None:
53 |     lrfinder = LRFinder(num_steps=steps_per_epoch, max_lr=1)
54 | 
55 |     _ = model.fit(ds, epochs=1, steps_per_epoch=steps_per_epoch, callbacks=[lrfinder],)
56 | 
57 | 
58 | def train(
59 |     run_name: Text,
60 |     data_path: Text,
61 |     out_dir: Text,
62 |     bn_momentum: float,
63 |     bn_virtual_bs: int,
64 |     clipnorm: float,
65 |     decay_rate: float,
66 |     decay_steps: int,
67 |     learning_rate: float,
68 |     sparsity_coefficient: float,
69 |     epochs: int,
70 |     cleanup: bool,
71 |     warmup: int,
72 |     dp: float,
73 |     seed: int,
74 |     do_lr_finder: bool,
75 | ):
76 |     set_seed(seed)
77 |     clean_tmp_dir()
78 | 
79 |     if cleanup:
80 |         out_dir = os.path.join(out_dir, run_name)
81 |         if os.path.exists(out_dir):
82 |             shutil.rmtree(out_dir)
83 | 
84 |     df_tr, df_val, df_test = get_data(data_path)
85 | 
86 |     ds_tr = get_dataset(
87 |         df_tr, shuffle=True, batch_size=CONFIGS["batch_size"], seed=seed
88 |     )
89 
| ds_val = get_dataset( 90 | df_val, shuffle=False, batch_size=CONFIGS["batch_size"], drop_remainder=False 91 | ) 92 | ds_test = get_dataset( 93 | df_test, shuffle=False, batch_size=CONFIGS["batch_size"], drop_remainder=False 94 | ) 95 | 96 | num_train_steps = np.floor(len(df_tr) / CONFIGS["batch_size"]) 97 | num_valid_steps = np.ceil(len(df_val) / CONFIGS["batch_size"]) 98 | num_test_steps = np.ceil(len(df_test) / CONFIGS["batch_size"]) 99 | 100 | model = TabNetClassifier( 101 | num_features=CONFIGS["num_features"], 102 | feature_dim=CONFIGS["feature_dim"], 103 | output_dim=CONFIGS["output_dim"], 104 | n_classes=CONFIGS["n_classes"], 105 | n_step=CONFIGS["n_steps"], 106 | relaxation_factor=CONFIGS["relaxation_factor"], 107 | sparsity_coefficient=sparsity_coefficient, 108 | bn_momentum=bn_momentum, 109 | bn_virtual_divider=int(CONFIGS["batch_size"] / CONFIGS["bn_virtual_bs"]), 110 | dp=dp if dp > 0 else None, 111 | ) 112 | 113 | model.build((None, CONFIGS["num_features"])) 114 | model.summary() 115 | 116 | if warmup: 117 | lr = DecayWithWarmupSchedule( 118 | learning_rate, CONFIGS["min_learning_rate"], warmup, decay_rate, decay_steps 119 | ) 120 | elif do_lr_finder: 121 | lr = learning_rate 122 | else: 123 | lr = tf.keras.optimizers.schedules.ExponentialDecay( 124 | learning_rate, 125 | decay_steps=decay_steps, 126 | decay_rate=decay_rate, 127 | staircase=False, 128 | ) 129 | 130 | optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=clipnorm) 131 | 132 | lossf = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 133 | 134 | model.compile( 135 | optimizer, 136 | loss=lossf, 137 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")], 138 | ) 139 | 140 | if do_lr_finder: 141 | run_lrfinder(ds_tr, model, optimizer, lossf, num_train_steps) 142 | return 143 | 144 | epochs = ( 145 | int(np.ceil(CONFIGS["total_steps"] / num_train_steps)) 146 | if epochs is None 147 | else epochs 148 | ) 149 | 150 | log_dir = ( 151 | os.path.join(LOGDIR, datetime.strftime(datetime.now(), "%Y-%m-%d-%H-%M-%S")) 152 | if run_name is None 153 | else os.path.join(LOGDIR, run_name) 154 | ) 155 | 156 | if os.path.exists(log_dir): 157 | shutil.rmtree(log_dir) 158 | 159 | checkpoint_path = os.path.join(TMPDIR, "checkpoint") 160 | 161 | callbacks = [ 162 | TensorBoardWithLR(log_dir=log_dir, write_graph=True, profile_batch=0), 163 | tf.keras.callbacks.ModelCheckpoint( 164 | filepath=checkpoint_path, 165 | monitor="val_accuracy", 166 | verbose=1, 167 | mode="max", 168 | save_best_only=True, 169 | ), 170 | ] 171 | 172 | # train 173 | 174 | h = model.fit( 175 | ds_tr, 176 | epochs=epochs, 177 | validation_data=ds_val, 178 | steps_per_epoch=num_train_steps, 179 | validation_steps=num_valid_steps, 180 | callbacks=callbacks, 181 | ) 182 | 183 | model.load_weights(checkpoint_path) 184 | model.save_to_directory(out_dir) 185 | 186 | # evaluate 187 | 188 | metrics = model.evaluate(ds_test, steps=num_test_steps, return_dict=True) 189 | 190 | with open(os.path.join(out_dir, "test_results.json"), "w") as f: 191 | json.dump(metrics, f) 192 | 193 | print(metrics) 194 | 195 | 196 | # example: python benchmarks/covertype.py --run_name w200_dp0.4 --epochs 1500 --warmup 200 --dp 0.4 197 | if __name__ == "__main__": 198 | parser = argparse.ArgumentParser("TabNet Covertype Training") 199 | parser.add_argument("--run_name", default=None, type=str) 200 | parser.add_argument("--data_path", default=DATA_PATH, type=str) 201 | parser.add_argument("--out_dir", default=OUTDIR, type=str) 202 | 
parser.add_argument("--bn_momentum", default=CONFIGS["bn_momentum"], type=float) 203 | parser.add_argument("--bn_virtual_bs", default=CONFIGS["bn_virtual_bs"], type=int) 204 | parser.add_argument("--clipnorm", default=CONFIGS["clipnorm"], type=float) 205 | parser.add_argument("--decay_rate", default=CONFIGS["decay_rate"], type=float) 206 | parser.add_argument("--decay_steps", default=CONFIGS["decay_steps"], type=int) 207 | parser.add_argument("--learning_rate", default=CONFIGS["learning_rate"], type=int) 208 | parser.add_argument("--dp", default=CONFIGS["dp"], type=float) 209 | parser.add_argument("--seed", default=CONFIGS["seed"], type=int) 210 | parser.add_argument( 211 | "--sparsity_coefficient", default=CONFIGS["sparsity_coefficient"], type=float 212 | ) 213 | parser.add_argument("--epochs", default=None, type=int) 214 | parser.add_argument( 215 | "--cleanup", 216 | action="store_true", 217 | help="Cleanup the output folder before starting the training.", 218 | ) 219 | parser.add_argument("--warmup", default=None, type=int) 220 | parser.add_argument( 221 | "--do_lr_finder", action="store_true", help="Runs only the LR finder only" 222 | ) 223 | args = parser.parse_args() 224 | 225 | train( 226 | args.run_name, 227 | args.data_path, 228 | args.out_dir, 229 | args.bn_momentum, 230 | args.bn_virtual_bs, 231 | args.clipnorm, 232 | args.decay_rate, 233 | args.decay_steps, 234 | args.learning_rate, 235 | args.sparsity_coefficient, 236 | args.epochs, 237 | args.cleanup, 238 | args.warmup, 239 | args.dp, 240 | args.seed, 241 | args.do_lr_finder, 242 | ) 243 | -------------------------------------------------------------------------------- /benchmarks/covertype_opt_inf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Text 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | from tabnet.datasets.covertype import get_dataset, get_data 8 | from tabnet.utils import set_seed 9 | from tabnet.models.classify import TabNetClassifier 10 | 11 | 12 | DATA_PATH = "data/covtype.csv" 13 | 14 | 15 | def get_accuracy( 16 | model: tf.keras.Model, 17 | ds: tf.data.Dataset, 18 | alpha: float, 19 | num_steps: int, 20 | size_of_dataset: int, 21 | ): 22 | accuracy = 0 23 | ds_iter = iter(ds) 24 | for _ in range(num_steps): 25 | x, y_true = next(ds_iter) 26 | logits = model(x, training=False, alpha=alpha) 27 | probs = tf.nn.softmax(logits) 28 | y_pred = tf.argmax(probs, axis=-1) 29 | accuracy += tf.reduce_sum(tf.cast(y_pred == y_true, tf.float32)).numpy() 30 | return accuracy / size_of_dataset 31 | 32 | 33 | def optimize_alpha_on_dataset( 34 | model: tf.keras.Model, ds: tf.data.Dataset, size_of_dataset: int 35 | ): 36 | sample, _ = next(iter(ds)) 37 | bs = sample.shape[0] 38 | num_steps = int(np.ceil(size_of_dataset / bs)) 39 | 40 | alphas = [1 / (2 * bs) * i for i in range(10)] 41 | 42 | accuracies = [] 43 | for alpha in alphas: 44 | accuracies.append(get_accuracy(model, ds, alpha, num_steps, size_of_dataset)) 45 | return alphas, accuracies 46 | 47 | 48 | def main(model_dir: Text, data_path: Text, batch_size: int, seed: int): 49 | set_seed(seed) 50 | model = TabNetClassifier.load_from_directory(model_dir) 51 | 52 | _, df_val, df_test = get_data(data_path) 53 | 54 | ds_val = get_dataset( 55 | df_val, shuffle=False, batch_size=batch_size, drop_remainder=False 56 | ) 57 | 58 | ds_test = get_dataset( 59 | df_test, shuffle=False, batch_size=batch_size, drop_remainder=False 60 | ) 61 | 62 | # optimize on validation dataset 63 | alphas, 
accuracies = optimize_alpha_on_dataset(model, ds_val, len(df_val)) 64 | best_alpha = alphas[np.argmax(accuracies)] 65 | 66 | # check on test dataset 67 | size_of_dataset = len(df_test) 68 | num_steps = int(np.ceil(size_of_dataset / batch_size)) 69 | test_accuracy_opt = get_accuracy( 70 | model, ds_test, best_alpha, num_steps, len(df_test) 71 | ) 72 | test_accuracy_orig = get_accuracy(model, ds_test, 0.0, num_steps, len(df_test)) 73 | 74 | print(accuracies) 75 | 76 | print(f"Accuracy: {np.min(accuracies)} (min) {np.max(accuracies)} (max)") 77 | print(f"Alphas: {np.min(alphas)} (min) {np.max(alphas)} (max)") 78 | print( 79 | f"Alpha at Accuracy: {alphas[np.argmin(accuracies)]} (min) {alphas[np.argmax(accuracies)]} (max)" 80 | ) 81 | print(f"Test Accuracy: {test_accuracy_opt} (opt) {test_accuracy_orig} (orig)") 82 | 83 | 84 | # python benchmarks/covertype_opt_inf.py --model_dir .outs/w100_dp0.4 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument("--model_dir", type=str) 88 | parser.add_argument("--data_path", default=DATA_PATH, type=str) 89 | parser.add_argument("--batch_size", default=512, type=int) 90 | parser.add_argument("--seed", default=42, type=int) 91 | args = parser.parse_args() 92 | 93 | main(args.model_dir, args.data_path, args.batch_size, args.seed) 94 | -------------------------------------------------------------------------------- /data/test/covtype_sample.csv: -------------------------------------------------------------------------------- 1 | 2596,51,3,258,0,510,221,232,148,6279,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5 2 | 2590,56,2,212,-6,390,220,235,151,6225,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5 3 | 2804,139,9,268,65,3180,234,238,135,6121,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 4 | 2785,155,18,242,118,3090,238,238,122,6211,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 5 | 2595,45,2,153,-1,391,220,234,150,6172,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5 6 | 2579,132,6,300,-15,67,230,237,140,6031,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 7 | 2606,45,7,270,5,633,222,225,138,6256,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5 8 | 2605,49,4,234,7,573,222,230,144,6228,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5 9 | 2617,45,9,240,56,666,223,221,133,6244,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5 10 | 2612,59,10,247,11,636,228,219,124,6230,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5 11 | 2612,201,4,180,51,735,218,243,161,6222,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 12 | 2886,151,11,371,26,5253,234,240,136,4051,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 13 | 2742,134,22,150,69,3215,248,224,92,6091,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 14 | 2609,214,7,150,46,771,213,247,170,6211,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 15 | 2503,157,4,67,4,674,224,240,151,5600,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 16 | 
2495,51,7,42,2,752,224,225,137,5576,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 17 | 2610,259,1,120,-1,607,216,239,161,6096,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,5 18 | 2517,72,7,85,6,595,228,227,133,5607,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 19 | 2504,0,4,95,5,691,214,232,156,5572,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 20 | 2503,38,5,85,10,741,220,228,144,5555,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 21 | 2501,71,9,60,8,767,230,223,126,5547,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 22 | 2880,209,17,216,30,4986,206,253,179,4323,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 23 | 2768,114,23,192,82,3339,252,209,71,5972,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,5 24 | 2511,54,8,124,0,638,225,222,130,5569,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 25 | 2507,22,9,120,14,732,215,221,143,5534,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 26 | 2492,135,6,0,0,860,229,237,142,5494,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 27 | 2489,163,10,30,-4,849,230,243,145,5486,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 28 | 2962,148,16,323,23,5916,240,236,120,3395,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 29 | 2811,135,1,212,30,3670,220,238,154,5643,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 30 | 2739,117,24,127,53,3281,253,210,71,6033,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,5 31 | 2703,122,30,67,27,3191,254,201,52,6123,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,5 32 | 2522,105,7,120,1,595,233,231,130,5569,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 33 | 2519,102,6,124,4,616,230,233,137,5559,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 34 | 2516,23,6,150,4,658,216,227,147,5541,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 35 | 2515,41,9,162,4,680,221,220,133,5532,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 36 | 2900,45,19,242,20,5199,221,195,100,4115,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 37 | 2709,125,28,67,23,3224,253,207,61,6094,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,5 38 | 2511,92,7,182,18,722,231,229,131,5494,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 39 | 2749,98,30,124,53,3316,252,183,36,6005,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,5 40 | 2686,354,12,0,0,3167,200,219,157,6155,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,5 41 | 2699,347,3,0,0,2096,213,234,159,6853,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 42 | 
2570,346,2,0,0,331,215,235,158,5745,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 43 | 2533,71,9,150,-3,577,230,223,126,5552,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 44 | 2703,330,27,30,17,3141,146,197,184,6186,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 45 | 2678,128,5,95,23,1660,229,236,141,6546,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 46 | 2529,68,8,210,-5,666,228,225,130,5484,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 47 | 2524,94,7,212,-4,684,232,229,130,5474,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 48 | 2536,99,6,234,0,659,230,232,136,5475,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 49 | 2498,66,6,95,7,900,227,227,135,5357,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 50 | 2489,100,7,85,13,810,232,231,131,5334,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 51 | 2713,117,30,60,17,3297,254,198,48,6039,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,5 52 | 2739,323,25,85,43,3118,149,205,192,6219,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1 53 | 2696,72,2,30,0,3271,222,234,149,6071,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1 54 | 2510,79,14,192,19,891,237,215,106,5325,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 55 | 2502,81,7,175,11,912,230,227,129,5316,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 56 | 2722,315,24,30,19,3216,148,212,200,6132,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 57 | 2500,74,11,190,9,930,233,219,116,5279,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 58 | 2486,68,5,180,-4,870,225,230,139,5262,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 59 | 2489,11,4,175,13,840,216,232,153,5254,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 60 | 2489,42,6,162,13,810,221,227,141,5247,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 61 | 2490,75,5,134,17,810,227,230,137,5218,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 62 | 2952,107,11,42,7,5845,239,226,116,3509,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 63 | 2705,90,8,134,22,2023,232,228,129,6615,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 64 | 2507,40,7,153,10,930,221,224,138,5221,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 65 | 2500,49,14,150,27,870,225,210,116,5205,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 66 | 2493,63,10,127,20,840,229,221,124,5197,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 67 | 2509,59,7,134,10,900,226,226,134,5184,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 68 | 
2919,13,13,90,6,5321,207,214,142,4060,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1 69 | 2740,54,6,218,42,2287,224,227,138,6686,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 70 | 2640,80,8,180,-2,1092,231,225,127,5866,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 71 | 2843,166,12,242,53,4434,230,244,144,4956,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 72 | 3008,45,14,277,10,6371,223,208,116,3036,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 73 | 2893,114,16,108,30,5066,245,223,102,4340,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 74 | 2850,6,9,0,0,4858,210,223,151,4548,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 75 | 2628,30,10,240,19,960,217,218,136,5645,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 76 | 2864,118,18,201,74,4567,248,221,93,4849,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 77 | 2827,160,28,134,65,3948,235,233,108,5474,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 78 | 2529,326,5,30,14,1062,207,234,166,5047,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 79 | 2808,99,7,382,95,3107,233,230,130,6341,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 80 | 2840,153,26,134,42,4613,241,231,102,4833,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 81 | 2795,79,10,531,96,2980,233,223,121,6497,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 82 | 2746,143,16,67,22,2440,241,235,119,6597,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 83 | 2847,352,26,150,82,3697,166,187,152,5796,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1 84 | 2840,14,14,216,88,3552,206,210,140,5944,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 85 | 2537,42,7,210,17,1132,222,224,137,4919,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 86 | 2860,358,17,175,98,3705,191,206,151,5800,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1 87 | 2818,332,26,30,17,4526,151,197,181,4978,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 88 | 2801,18,7,560,58,3084,215,226,148,6457,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 89 | 2791,63,10,418,48,2942,229,221,124,6606,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 90 | 2745,306,11,67,24,2416,190,234,184,6428,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 91 | 2514,102,6,272,-5,1082,230,233,137,4811,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 92 | 2788,13,16,30,8,4126,203,206,137,5396,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 93 | 2562,354,12,67,9,1057,200,218,156,5031,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2 94 | 
3073,173,12,108,-3,6836,227,246,149,2735,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 95 | 2978,71,10,426,85,5742,231,221,121,3792,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 96 | 2860,31,10,295,98,3644,218,218,135,5904,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 97 | 3067,164,11,85,7,6811,230,243,144,2774,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,2 98 | 2804,72,5,543,61,3115,225,231,141,6471,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 99 | 2562,59,3,0,0,1116,221,233,148,5091,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 100 | 2567,34,9,190,16,1136,219,221,138,4924,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 101 | -------------------------------------------------------------------------------- /examples/train_mnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Text, List 3 | import pickle 4 | import shutil 5 | import os 6 | 7 | import tensorflow as tf 8 | import tensorflow_datasets as tfds 9 | from kerastuner.tuners import RandomSearch 10 | 11 | from tabnet.models.classify import TabNetClassifier 12 | from tabnet.utils import set_seed 13 | from tabnet.schedules import DecayWithWarmupSchedule 14 | 15 | 16 | SEARCH_DIR = ".search" 17 | SEED = 42 18 | DEFAULTS = {"num_features": 784, "n_classes": 10, "min_learning_rate": 1e-6} # 28x28 19 | 20 | 21 | # because doing a training on MNIST is something I MUST do, no? 22 | # this time let's add a twist & do hyperparameter optimization with kerastuner 23 | 24 | 25 | def build_model(hp): 26 | model = TabNetClassifier( 27 | num_features=DEFAULTS["num_features"], 28 | feature_dim=hp.Choice("feature_dim", values=[16, 32, 64], default=32), 29 | output_dim=hp.Choice("output_dim", values=[16, 32, 64], default=32), 30 | n_classes=DEFAULTS["n_classes"], 31 | n_step=hp.Choice("n_step", values=[2, 4, 5, 6], default=4), 32 | relaxation_factor=hp.Choice( 33 | "relaxation_factor", values=[1.0, 1.25, 1.5, 2.0, 3.0], default=1.5 34 | ), 35 | sparsity_coefficient=hp.Choice( 36 | "sparsity_coefficient", 37 | values=[0.0001, 0.001, 0.01, 0.02, 0.05], 38 | default=0.0001, 39 | ), 40 | bn_momentum=hp.Choice("bn_momentum", values=[0.6, 0.7, 0.9], default=0.7), 41 | bn_virtual_divider=1, # let's not use Ghost Batch Normalization. 
batch sizes are too small 42 | dp=hp.Choice("dp", values=[0.0, 0.1, 0.2, 0.3, 0.4], default=0.0), 43 | ) 44 | lr = DecayWithWarmupSchedule( 45 | hp.Choice( 46 | "learning_rate", values=[0.001, 0.005, 0.01, 0.02, 0.05], default=0.02 47 | ), 48 | DEFAULTS["min_learning_rate"], 49 | hp.Choice("warmup", values=[1, 5, 10, 20], default=5), 50 | hp.Choice("decay_rate", values=[0.8, 0.90, 0.95, 0.99], default=0.95), 51 | hp.Choice("decay_steps", values=[10, 100, 500, 1000], default=500), 52 | ) 53 | 54 | optimizer = tf.keras.optimizers.Adam( 55 | learning_rate=lr, 56 | clipnorm=hp.Choice("clipnorm", values=[1, 2, 5, 10], default=2), 57 | ) 58 | 59 | lossf = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 60 | 61 | model.compile( 62 | optimizer, 63 | loss=lossf, 64 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")], 65 | ) 66 | 67 | return model 68 | 69 | 70 | def prepare_dataset( 71 | ds: tf.data.Dataset, 72 | batch_size: int, 73 | shuffle: bool = False, 74 | drop_remainder: bool = False, 75 | ): 76 | size_of_dataset = ds.reduce(0, lambda x, _: x + 1).numpy() 77 | if shuffle: 78 | ds = ds.shuffle(buffer_size=size_of_dataset, seed=SEED) 79 | ds: tf.data.Dataset = ds.batch(batch_size, drop_remainder=drop_remainder) 80 | 81 | @tf.function 82 | def prepare_data(features): 83 | image = tf.cast(features["image"], tf.float32) 84 | bs = tf.shape(image)[0] 85 | image = tf.reshape(image / 255.0, (bs, -1)) 86 | return image, features["label"] 87 | 88 | autotune = tf.data.experimental.AUTOTUNE 89 | ds = ds.map(prepare_data, num_parallel_calls=autotune).prefetch(autotune) 90 | return ds 91 | 92 | 93 | def search( 94 | epochs: int, 95 | batch_size: int, 96 | n_trials: int, 97 | execution_per_trial: int, 98 | project: Text, 99 | do_cleanup: bool, 100 | ): 101 | set_seed(SEED) 102 | 103 | dir_to_clean = os.path.join(SEARCH_DIR, project) 104 | if do_cleanup and os.path.exists(dir_to_clean): 105 | shutil.rmtree(dir_to_clean) 106 | 107 | # first 80% for train. remaining 20% for val & test dataset for final eval. 108 | ds_tr, ds_val, ds_test = tfds.load( 109 | name="mnist", 110 | split=["train[:80%]", "train[-20%:]", "test"], 111 | data_dir="mnist", 112 | shuffle_files=False, 113 | ) 114 | 115 | ds_tr = prepare_dataset(ds_tr, batch_size, shuffle=True, drop_remainder=True) 116 | ds_val = prepare_dataset(ds_val, batch_size, shuffle=False, drop_remainder=False) 117 | ds_test = prepare_dataset(ds_test, batch_size, shuffle=False, drop_remainder=False) 118 | 119 | tuner = RandomSearch( 120 | build_model, 121 | objective="val_accuracy", 122 | max_trials=n_trials, 123 | executions_per_trial=execution_per_trial, 124 | directory=SEARCH_DIR, 125 | project_name=project, 126 | ) 127 | 128 | # ? 
add callbacks 129 | tuner.search( 130 | ds_tr, epochs=epochs, validation_data=ds_val, 131 | ) 132 | 133 | best_model: tf.keras.Model = tuner.get_best_models(num_models=1)[0] 134 | best_model.build((None, DEFAULTS["num_features"])) 135 | results = best_model.evaluate(ds_test, return_dict=True) 136 | 137 | tuner.results_summary(num_trials=1) 138 | best_hyperparams = tuner.get_best_hyperparameters(num_trials=1) 139 | print(f"Test results: {results}") 140 | 141 | output = {"results": results, "best_hyperparams": best_hyperparams} 142 | with open("search_results.pickle", "wb") as f: 143 | pickle.dump(output, f) 144 | 145 | # best_model.save("tabnet_saved_model") 146 | 147 | 148 | # python3 examples/train_mnist.py --trials 2 --epochs 10 --bs 128 149 | if __name__ == "__main__": 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument("--trials", default=1, type=int) 152 | parser.add_argument("--epochs", default=1, type=int) 153 | parser.add_argument("--bs", default=32, type=int) 154 | parser.add_argument("--exec_per_trial", default=2, type=int) 155 | parser.add_argument("--project", default="test", type=str) 156 | parser.add_argument("--cleanup", action="store_true") 157 | args = parser.parse_args() 158 | 159 | search( 160 | args.epochs, 161 | args.bs, 162 | args.trials, 163 | args.exec_per_trial, 164 | args.project, 165 | args.cleanup, 166 | ) 167 | -------------------------------------------------------------------------------- /images/tabnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ostamand/tensorflow-tabnet/47ef98fd251aea2bc2ac1d4b7b6bfe8517cfeea9/images/tabnet.png -------------------------------------------------------------------------------- /images/virtual_bs_vs_gbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ostamand/tensorflow-tabnet/47ef98fd251aea2bc2ac1d4b7b6bfe8517cfeea9/images/virtual_bs_vs_gbn.png -------------------------------------------------------------------------------- /local/tuner_orig_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from argparse import ArgumentParser 4 | from typing import Text 5 | from datetime import datetime 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | from kerastuner.tuners import BayesianOptimization, RandomSearch 10 | 11 | from tabnet.models import TabNetClassifier 12 | from local.original_dataset import ( 13 | input_fn, 14 | get_columns, 15 | NUM_FEATURES, 16 | N_VAL_SAMPLES, 17 | N_TR_SAMPLES, 18 | ) 19 | 20 | 21 | SEARCH_DIR = ".search" 22 | 23 | 24 | DEFAULTS = { 25 | "N_d": 64, 26 | "N_a": 64, 27 | "sparsity_coefficient": 0.0001, 28 | "batch_size": 16384, 29 | "bn_virtual_bs": 512, # 256 30 | "bn_momentum": 0.7, 31 | "n_steps": 5, 32 | "relaxation_factor": 1.5, 33 | "n_classes": 7, 34 | "learning_rate": 0.02, 35 | "decay_steps": 500, # 20 36 | "decay_rate": 0.95, 37 | "total_steps": 130000, 38 | "gradient_thresh": 2000, 39 | } 40 | 41 | 42 | def build_model(hp): 43 | model = TabNetClassifier( 44 | feature_columns=get_columns(), 45 | num_features=NUM_FEATURES, 46 | feature_dim=DEFAULTS["N_d"], 47 | output_dim=DEFAULTS["N_a"], 48 | n_classes=DEFAULTS["n_classes"], 49 | n_step=DEFAULTS["n_steps"], 50 | relaxation_factor=DEFAULTS["relaxation_factor"], 51 | sparsity_coefficient=hp.Choice( 52 | "sparsity_coefficient", values=[0.0001, 0.001, 0.01], default=0.0001 53 | ), 54 | 
bn_momentum=hp.Choice("bn_momentum", values=[0.6, 0.7, 0.9], default=0.7), 55 | bn_virtual_bs=hp.Choice( 56 | "bn_virtual_bs", values=[128, 256, 512, 1024], default=512 57 | ), 58 | ) 59 | 60 | lr = tf.keras.optimizers.schedules.ExponentialDecay( 61 | hp.Choice( 62 | "learning_rate", values=[0.002, 0.005, 0.01, 0.02, 0.05], default=0.02 63 | ), 64 | decay_steps=hp.Choice("decay_steps", values=[10, 100, 500, 1000], default=500), 65 | decay_rate=hp.Choice("decay_rate", values=[0.90, 0.95, 0.99], default=0.95), 66 | staircase=False, 67 | ) 68 | 69 | optimizer = tf.keras.optimizers.Adam( 70 | learning_rate=lr, 71 | clipnorm=hp.Choice("clipnorm", values=[1, 2, 5, 10], default=2), 72 | ) 73 | 74 | lossf = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 75 | 76 | model.compile( 77 | optimizer, 78 | loss=lossf, 79 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")], 80 | ) 81 | 82 | return model 83 | 84 | 85 | # Total runtime: 158.35 mins = 10 trials 86 | def search( 87 | epochs: int, 88 | n_trials: int, 89 | execution_per_trial: int, 90 | project: Text = "test", 91 | cleanup: bool = False, 92 | ): 93 | start_time = datetime.now() 94 | 95 | results_path = os.path.join(SEARCH_DIR, project) 96 | if cleanup and os.path.exists(results_path): 97 | shutil.rmtree(results_path) 98 | 99 | ds_tr = input_fn( 100 | "data/train_covertype.csv", shuffle=True, batch_size=DEFAULTS["batch_size"] 101 | ) 102 | ds_val = input_fn( 103 | "data/val_covertype.csv", shuffle=False, batch_size=DEFAULTS["batch_size"] 104 | ) 105 | 106 | num_train_steps = np.floor(N_TR_SAMPLES / DEFAULTS["batch_size"]) 107 | num_valid_steps = np.floor(N_VAL_SAMPLES / DEFAULTS["batch_size"]) 108 | 109 | # RandomSearch, BayesianOptimization 110 | tuner = RandomSearch( 111 | build_model, 112 | objective="val_loss", 113 | max_trials=n_trials, 114 | executions_per_trial=execution_per_trial, 115 | directory=SEARCH_DIR, 116 | project_name=project, 117 | ) 118 | 119 | # tuner.search_space_summary() 120 | 121 | tuner.search( 122 | ds_tr, 123 | epochs=epochs, 124 | validation_data=ds_val, 125 | steps_per_epoch=num_train_steps, 126 | validation_steps=num_valid_steps, 127 | ) 128 | 129 | # models = tuner.get_best_models(num_models=1) 130 | 131 | tuner.results_summary(num_trials=2) 132 | 133 | print(f"Total runtime: {(datetime.now() - start_time).seconds / 60:.2f} mins") 134 | 135 | 136 | # python local/tuner_orig_dataset.py --trials 10 --epoch 25 --project test_rnd_10_trials 137 | if __name__ == "__main__": 138 | parser = ArgumentParser() 139 | parser.add_argument("--trials", default=1, type=int) 140 | parser.add_argument("--epochs", default=1, type=int) 141 | parser.add_argument("--exec_per_trial", default=2, type=int) 142 | parser.add_argument("--project", default="test", type=str) 143 | parser.add_argument("--cleanup", action="store_true") 144 | args = parser.parse_args() 145 | 146 | search(args.epochs, args.trials, args.exec_per_trial, args.project, args.cleanup) 147 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.3.0 2 | tensorflow-addons==0.11.1 3 | tensorflow-datasets==3.2.1 4 | keras-tuner==1.0.1 5 | black 6 | ipykernel 7 | pandas 8 | numpy 9 | scikit-learn==0.23.2 10 | pytest 11 | matplotlib==3.3.1 -------------------------------------------------------------------------------- /scripts/install.sh: 
-------------------------------------------------------------------------------- 1 | pip install -r requirements.txt 2 | pip install -e . -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | REQUIREMENTS = [ 4 | "tensorflow>=2.0.0", 5 | "tensorflow-addons>=0.11.1", 6 | "pandas>=1.1.0", 7 | "scikit-learn>=0.23.2", 8 | "matplotlib>=3.3.1" 9 | ] 10 | 11 | setup( 12 | name="tabnet", 13 | version="0.1", 14 | author="Olivier St-Amand", 15 | author_email="olivier.st.amand.1@gmail.com", 16 | description="TensorFlow 2 implementation of TabNet", 17 | license="MIT", 18 | packages=find_packages(exclude=["tests"]), 19 | python_requires=">=3.5.0", 20 | install_requires=REQUIREMENTS, 21 | ) 22 | -------------------------------------------------------------------------------- /tabnet/__init__.py: -------------------------------------------------------------------------------- 1 | from tabnet.models.model import TabNet 2 | -------------------------------------------------------------------------------- /tabnet/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from tabnet.callbacks.lrfinder import LRFinder 2 | from tabnet.callbacks.tensorboard import TensorBoardWithLR 3 | -------------------------------------------------------------------------------- /tabnet/callbacks/lrfinder.py: -------------------------------------------------------------------------------- 1 | from typing import Text 2 | from math import log10 3 | 4 | import tensorflow as tf 5 | import tensorflow.keras.backend as K 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | class LRFinder(tf.keras.callbacks.Callback): 10 | def __init__( 11 | self, 12 | min_lr: float = 1e-6, 13 | max_lr: float = 1e-2, 14 | num_steps: int = 100, 15 | monitor: Text = "loss", 16 | figname: Text = "lrfinder.png", 17 | ): 18 | super(LRFinder, self).__init__() 19 | self.monitor = monitor 20 | self.num_steps = num_steps 21 | self.min_lr, self.max_lr = min_lr, max_lr 22 | self.figname = figname 23 | 24 | # log(y) = m * log(x) + b 25 | # m = log(y2/y1) / log(x2/x1) 26 | # b = log(y2) - m * log(x2) 27 | self.m = log10(max_lr / min_lr) / log10(num_steps - 1) 28 | self.b = log10(max_lr) - self.m * log10(num_steps - 1) 29 | self.__reset() 30 | 31 | def __reset(self): 32 | self.losses = [] 33 | self.lrs = [] 34 | 35 | def set_lr(self, step: int) -> float: 36 | lr = pow(10, self.m * log10(step) + self.b) 37 | K.set_value(self.model.optimizer.lr, lr) 38 | return lr 39 | 40 | def save_fig(self): 41 | plt.semilogx(self.lrs, self.losses) 42 | plt.title("LR Finder") 43 | plt.xlabel("lr") 44 | plt.ylabel("batch loss") 45 | plt.savefig(self.figname) 46 | 47 | def on_train_batch_end(self, batch, logs=None): 48 | it = len(self.losses) + 1 49 | 50 | if len(self.losses) < self.num_steps: 51 | self.losses.append(logs[self.monitor]) 52 | 53 | if len(self.lrs) < self.num_steps: 54 | self.lrs.append(self.set_lr(it)) 55 | else: 56 | self.model.stop_training = True 57 | if self.figname is not None: 58 | self.save_fig() 59 | 60 | def on_train_begin(self, logs=None): 61 | self.__reset() 62 | self.lrs.append(self.set_lr(1)) 63 | -------------------------------------------------------------------------------- /tabnet/callbacks/tensorboard.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import 
tensorflow.keras.backend as K 3 | 4 | 5 | class TensorBoardWithLR(tf.keras.callbacks.TensorBoard): 6 | def on_epoch_end(self, epoch, logs=None): 7 | logs.update({"lr": self.model.optimizer.lr(epoch)}) 8 | return super().on_epoch_end(epoch, logs=logs) 9 | -------------------------------------------------------------------------------- /tabnet/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ostamand/tensorflow-tabnet/47ef98fd251aea2bc2ac1d4b7b6bfe8517cfeea9/tabnet/datasets/__init__.py -------------------------------------------------------------------------------- /tabnet/datasets/covertype.py: -------------------------------------------------------------------------------- 1 | from typing import Text, Tuple 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import tensorflow as tf 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | RANDOM_SEED = 0 10 | 11 | 12 | def get_data(path_to_csv: Text, seed: int = RANDOM_SEED) -> Tuple[pd.DataFrame]: 13 | df = pd.read_csv(path_to_csv) 14 | df_tr, df_test = train_test_split(df, test_size=0.2, random_state=seed) 15 | df_tr, df_val = train_test_split(df_tr, test_size=0.2 / 0.6, random_state=seed) 16 | return df_tr, df_val, df_test 17 | 18 | 19 | def get_dataset( 20 | df: pd.DataFrame, 21 | take: int = None, 22 | shuffle: bool = False, 23 | batch_size: int = 16384, 24 | drop_remainder: bool = True, 25 | seed: int = RANDOM_SEED, 26 | ) -> tf.data.Dataset: 27 | x = df[df.columns[:-1]].values.astype(np.float32) 28 | y = df[df.columns[-1]].values - 1 29 | ds: tf.data.Dataset = tf.data.Dataset.from_tensor_slices((x, y)) 30 | if shuffle: 31 | ds = ds.shuffle(buffer_size=len(x), seed=seed) 32 | ds = ds.batch(batch_size, drop_remainder=drop_remainder) 33 | if take is not None: 34 | ds = ds.take(take) 35 | ds = ds.repeat().prefetch(tf.data.experimental.AUTOTUNE) 36 | return ds 37 | -------------------------------------------------------------------------------- /tabnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from tabnet.models.model import TabNet 2 | from tabnet.models.classify import TabNetClassifier 3 | -------------------------------------------------------------------------------- /tabnet/models/classify.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Text 3 | import pickle 4 | 5 | import tensorflow as tf 6 | from tabnet.models import TabNet 7 | 8 | 9 | class TabNetClassifier(tf.keras.Model): 10 | def __init__( 11 | self, 12 | num_features: int, 13 | feature_dim: int, 14 | output_dim: int, 15 | n_classes: int, 16 | feature_columns: List = None, 17 | n_step: int = 1, 18 | n_total: int = 4, 19 | n_shared: int = 2, 20 | relaxation_factor: float = 1.5, 21 | sparsity_coefficient: float = 1e-5, 22 | bn_epsilon: float = 1e-5, 23 | bn_momentum: float = 0.7, 24 | bn_virtual_divider: int = 32, 25 | dp: float = None, 26 | output_activation: str = None, 27 | **kwargs 28 | ): 29 | super(TabNetClassifier, self).__init__() 30 | 31 | self.configs = { 32 | "num_features": num_features, 33 | "feature_dim": feature_dim, 34 | "output_dim": output_dim, 35 | "n_classes": n_classes, 36 | "feature_columns": feature_columns, 37 | "n_step": n_step, 38 | "n_total": n_total, 39 | "n_shared": n_shared, 40 | "relaxation_factor": relaxation_factor, 41 | "sparsity_coefficient": sparsity_coefficient, 42 | "bn_epsilon": bn_epsilon, 43 | 
"bn_momentum": bn_momentum, 44 | "bn_virtual_divider": bn_virtual_divider, 45 | "dp": dp, 46 | "output_activation": output_activation, 47 | } 48 | for k, v in kwargs.items(): 49 | self.configs[k] = v 50 | 51 | self.sparsity_coefficient = sparsity_coefficient 52 | 53 | self.model = TabNet( 54 | feature_columns=feature_columns, 55 | num_features=num_features, 56 | feature_dim=feature_dim, 57 | output_dim=output_dim, 58 | n_step=n_step, 59 | relaxation_factor=relaxation_factor, 60 | bn_epsilon=bn_epsilon, 61 | bn_momentum=bn_momentum, 62 | bn_virtual_divider=bn_virtual_divider, 63 | ) 64 | self.dp = tf.keras.layers.Dropout(dp) if dp is not None else dp 65 | self.head = tf.keras.layers.Dense(n_classes, activation=output_activation, use_bias=False) 66 | 67 | def call(self, x, training: bool = None, alpha: float = 0.0): 68 | out, sparse_loss, _ = self.model(x, training=training, alpha=alpha) 69 | if self.dp is not None: 70 | out = self.dp(out, training=training) 71 | y = self.head(out, training=training) 72 | 73 | if training: 74 | self.add_loss(-self.sparsity_coefficient * sparse_loss) 75 | 76 | return y 77 | 78 | def get_config(self): 79 | return self.configs 80 | 81 | def save_to_directory(self, path_to_folder: Text): 82 | self.save_weights(os.path.join(path_to_folder, "ckpt"), overwrite=True) 83 | with open(os.path.join(path_to_folder, "configs.pickle"), "wb") as f: 84 | pickle.dump(self.configs, f) 85 | 86 | @classmethod 87 | def load_from_directory(cls, path_to_folder: Text): 88 | with open(os.path.join(path_to_folder, "configs.pickle"), "rb") as f: 89 | configs = pickle.load(f) 90 | model: tf.keras.Model = cls(**configs) 91 | model.build((None, configs["num_features"])) 92 | load_status = model.load_weights(os.path.join(path_to_folder, "ckpt")) 93 | load_status.expect_partial() 94 | return model 95 | -------------------------------------------------------------------------------- /tabnet/models/gbn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class GhostBatchNormalization(tf.keras.Model): 5 | def __init__( 6 | self, virtual_divider: int = 1, momentum: float = 0.9, epsilon: float = 1e-5 7 | ): 8 | super(GhostBatchNormalization, self).__init__() 9 | self.virtual_divider = virtual_divider 10 | self.bn = BatchNormInferenceWeighting(momentum=momentum) 11 | 12 | def call(self, x, training: bool = None, alpha: float = 0.0): 13 | if training: 14 | chunks = tf.split(x, self.virtual_divider) 15 | x = [self.bn(x, training=True) for x in chunks] 16 | return tf.concat(x, 0) 17 | return self.bn(x, training=False, alpha=alpha) 18 | 19 | @property 20 | def moving_mean(self): 21 | return self.bn.moving_mean 22 | 23 | @property 24 | def moving_variance(self): 25 | return self.bn.moving_variance 26 | 27 | 28 | class BatchNormInferenceWeighting(tf.keras.layers.Layer): 29 | def __init__(self, momentum: float = 0.9, epsilon: float = None): 30 | super(BatchNormInferenceWeighting, self).__init__() 31 | self.momentum = momentum 32 | self.epsilon = tf.keras.backend.epsilon() if epsilon is None else epsilon 33 | 34 | def build(self, input_shape): 35 | channels = input_shape[-1] 36 | 37 | self.gamma = tf.Variable( 38 | initial_value=tf.ones((channels,), tf.float32), trainable=True, 39 | ) 40 | self.beta = tf.Variable( 41 | initial_value=tf.zeros((channels,), tf.float32), trainable=True, 42 | ) 43 | 44 | self.moving_mean = tf.Variable( 45 | initial_value=tf.zeros((channels,), tf.float32), trainable=False, 46 | ) 47 | 
self.moving_mean_of_squares = tf.Variable( 48 | initial_value=tf.zeros((channels,), tf.float32), trainable=False, 49 | ) 50 | 51 | def __update_moving(self, var, value): 52 | var.assign(var * self.momentum + (1 - self.momentum) * value) 53 | 54 | def __apply_normalization(self, x, mean, variance): 55 | return self.gamma * (x - mean) / tf.sqrt(variance + self.epsilon) + self.beta 56 | 57 | def call(self, x, training: bool = None, alpha: float = 0.0): 58 | mean = tf.reduce_mean(x, axis=0) 59 | mean_of_squares = tf.reduce_mean(tf.pow(x, 2), axis=0) 60 | 61 | if training: 62 | # update moving stats 63 | self.__update_moving(self.moving_mean, mean) 64 | self.__update_moving(self.moving_mean_of_squares, mean_of_squares) 65 | 66 | variance = mean_of_squares - tf.pow(mean, 2) 67 | x = self.__apply_normalization(x, mean, variance) 68 | else: 69 | mean = alpha * mean + (1 - alpha) * self.moving_mean 70 | variance = ( 71 | alpha * mean_of_squares + (1 - alpha) * self.moving_mean_of_squares 72 | ) - tf.pow(mean, 2) 73 | x = self.__apply_normalization(x, mean, variance) 74 | 75 | return x 76 | -------------------------------------------------------------------------------- /tabnet/models/model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import tensorflow as tf 4 | 5 | from tabnet.models.transformers import ( 6 | FeatureTransformer, 7 | AttentiveTransformer, 8 | ) 9 | 10 | 11 | class TabNet(tf.keras.Model): 12 | def __init__( 13 | self, 14 | num_features: int, 15 | feature_dim: int, 16 | output_dim: int, 17 | feature_columns: List = None, 18 | n_step: int = 1, 19 | n_total: int = 4, 20 | n_shared: int = 2, 21 | relaxation_factor: float = 1.5, 22 | bn_epsilon: float = 1e-5, 23 | bn_momentum: float = 0.7, 24 | bn_virtual_divider: int = 1, 25 | ): 26 | """TabNet 27 | 28 | Will output a vector of size output_dim. 29 | 30 | Args: 31 | num_features (int): Number of features. 32 | feature_dim (int): Embedding feature dimension to use. 33 | output_dim (int): Output dimension. 34 | feature_columns (List, optional): If defined, a DenseFeatures layer is added first. Defaults to None. 35 | n_step (int, optional): Total number of steps. Defaults to 1. 36 | n_total (int, optional): Total number of feature transformer blocks. Defaults to 4. 37 | n_shared (int, optional): Number of shared feature transformer blocks. Defaults to 2. 38 | relaxation_factor (float, optional): Values >1 allow a feature to be used in more than one decision step. Defaults to 1.5. 39 | bn_epsilon (float, optional): Batch normalization, epsilon. Defaults to 1e-5. 40 | bn_momentum (float, optional): Batch normalization, momentum. Defaults to 0.7. 41 | bn_virtual_divider (int, optional): Batch normalization. The full batch is split into this many virtual (ghost) batches during training. 42 | """ 43 | super(TabNet, self).__init__() 44 | self.output_dim, self.num_features = output_dim, num_features 45 | self.n_step, self.relaxation_factor = n_step, relaxation_factor 46 | self.feature_columns = feature_columns 47 | 48 | if feature_columns is not None: 49 | self.input_features = tf.keras.layers.DenseFeatures(feature_columns) 50 | 51 | # ?
Switch to Ghost Batch Normalization 52 | self.bn = tf.keras.layers.BatchNormalization( 53 | momentum=bn_momentum, epsilon=bn_epsilon 54 | ) 55 | 56 | kargs = { 57 | "feature_dim": feature_dim + output_dim, 58 | "n_total": n_total, 59 | "n_shared": n_shared, 60 | "bn_momentum": bn_momentum, 61 | "bn_virtual_divider": bn_virtual_divider, 62 | } 63 | 64 | # first feature transformer block is built first to get the shared blocks 65 | self.feature_transforms: List[FeatureTransformer] = [ 66 | FeatureTransformer(**kargs) 67 | ] 68 | self.attentive_transforms: List[AttentiveTransformer] = [] 69 | for i in range(n_step): 70 | self.feature_transforms.append( 71 | FeatureTransformer(**kargs, fcs=self.feature_transforms[0].shared_fcs) 72 | ) 73 | self.attentive_transforms.append( 74 | AttentiveTransformer(num_features, bn_momentum, bn_virtual_divider) 75 | ) 76 | 77 | def call( 78 | self, features: tf.Tensor, training: bool = None, alpha: float = 0.0 79 | ) -> Tuple[tf.Tensor, tf.Tensor]: 80 | if self.feature_columns is not None: 81 | features = self.input_features(features) 82 | 83 | bs = tf.shape(features)[0] 84 | out_agg = tf.zeros((bs, self.output_dim)) 85 | prior_scales = tf.ones((bs, self.num_features)) 86 | masks = [] 87 | 88 | features = self.bn(features, training=training) 89 | masked_features = features 90 | 91 | total_entropy = 0.0 92 | 93 | for step_i in range(self.n_step + 1): 94 | x = self.feature_transforms[step_i]( 95 | masked_features, training=training, alpha=alpha 96 | ) 97 | 98 | if step_i > 0: 99 | out = tf.keras.activations.relu(x[:, : self.output_dim]) 100 | out_agg += out 101 | 102 | # no need to build the features mask for the last step 103 | if step_i < self.n_step: 104 | x_for_mask = x[:, self.output_dim:] 105 | 106 | mask_values = self.attentive_transforms[step_i]( 107 | [x_for_mask, prior_scales], training=training, alpha=alpha 108 | ) 109 | 110 | # relaxation factor of 1 forces the feature to be only used once. 
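# prior_scales starts at ones and is multiplied by (relaxation_factor - mask_values) after every step, so heavily used features get a smaller prior and a relaxation_factor > 1 lets later steps select them again.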
111 | prior_scales *= self.relaxation_factor - mask_values 112 | 113 | masked_features = tf.multiply(mask_values, features) 114 | 115 | # entropy is used to penalize the amount of sparsity in feature selection 116 | total_entropy += tf.reduce_mean( 117 | tf.reduce_sum( 118 | tf.multiply(mask_values, tf.math.log(mask_values + 1e-15)), 119 | axis=1, 120 | ) 121 | ) 122 | 123 | masks.append(tf.expand_dims(tf.expand_dims(mask_values, 0), 3)) 124 | 125 | loss = total_entropy / self.n_step 126 | 127 | return out_agg, loss, masks 128 | -------------------------------------------------------------------------------- /tabnet/models/transformers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import tensorflow as tf 4 | from tensorflow_addons.activations import sparsemax 5 | 6 | from tabnet.models.utils import glu 7 | from tabnet.models.gbn import GhostBatchNormalization 8 | 9 | 10 | class FeatureBlock(tf.keras.Model): 11 | def __init__( 12 | self, 13 | feature_dim: int, 14 | apply_glu: bool = True, 15 | bn_momentum: float = 0.9, 16 | bn_virtual_divider: int = 32, 17 | fc: tf.keras.layers.Layer = None, 18 | epsilon: float = 1e-5, 19 | ): 20 | super(FeatureBlock, self).__init__() 21 | self.apply_glu = apply_glu 22 | self.feature_dim = feature_dim 23 | units = feature_dim * 2 if apply_glu else feature_dim 24 | 25 | self.fc = tf.keras.layers.Dense(units, use_bias=False) if fc is None else fc 26 | self.bn = GhostBatchNormalization( 27 | virtual_divider=bn_virtual_divider, momentum=bn_momentum, epsilon=epsilon 28 | ) 29 | 30 | def call(self, x, training: bool = None, alpha: float = 0.0): 31 | x = self.fc(x) 32 | x = self.bn(x, training=training, alpha=alpha) 33 | if self.apply_glu: 34 | return glu(x, self.feature_dim) 35 | return x 36 | 37 | 38 | class AttentiveTransformer(tf.keras.Model): 39 | def __init__(self, feature_dim: int, bn_momentum: float, bn_virtual_divider: int): 40 | super(AttentiveTransformer, self).__init__() 41 | self.block = FeatureBlock( 42 | feature_dim, 43 | bn_momentum=bn_momentum, 44 | bn_virtual_divider=bn_virtual_divider, 45 | apply_glu=False, 46 | ) 47 | 48 | def call(self, inputs, training=None, alpha: float = 0.0): 49 | x_for_mask, prior_scales = inputs 50 | x = self.block(x_for_mask, training=training, alpha=alpha) 51 | return sparsemax(x * prior_scales) 52 | 53 | 54 | class FeatureTransformer(tf.keras.Model): 55 | def __init__( 56 | self, 57 | feature_dim: int, 58 | fcs: List[tf.keras.layers.Layer] = [], 59 | n_total: int = 4, 60 | n_shared: int = 2, 61 | bn_momentum: float = 0.9, 62 | bn_virtual_divider: int = 1, 63 | ): 64 | super(FeatureTransformer, self).__init__() 65 | self.n_total, self.n_shared = n_total, n_shared 66 | 67 | kargs = { 68 | "feature_dim": feature_dim, 69 | "bn_momentum": bn_momentum, 70 | "bn_virtual_divider": bn_virtual_divider, 71 | } 72 | 73 | # build blocks 74 | self.blocks: List[FeatureBlock] = [] 75 | for n in range(n_total): 76 | # some shared blocks 77 | if fcs and n < len(fcs): 78 | self.blocks.append(FeatureBlock(**kargs, fc=fcs[n])) 79 | # build new blocks 80 | else: 81 | self.blocks.append(FeatureBlock(**kargs)) 82 | 83 | def call( 84 | self, x: tf.Tensor, training: bool = None, alpha: float = 0.0 85 | ) -> tf.Tensor: 86 | x = self.blocks[0](x, training=training, alpha=alpha) 87 | for n in range(1, self.n_total): 88 | x = x * tf.sqrt(0.5) + self.blocks[n](x, training=training, alpha=alpha) 89 | return x 90 | 91 | @property 92 | def shared_fcs(self): 93 | return [self.blocks[i].fc for i
in range(self.n_shared)] 94 | -------------------------------------------------------------------------------- /tabnet/models/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | # taken from https://github.com/google-research/google-research/blob/master/tabnet/tabnet_model.py 5 | def glu(x, n_units=None): 6 | """Generalized linear unit nonlinear activation.""" 7 | return x[:, :n_units] * tf.nn.sigmoid(x[:, n_units:]) 8 | -------------------------------------------------------------------------------- /tabnet/schedules/__init__.py: -------------------------------------------------------------------------------- 1 | from tabnet.schedules.decay_with_warmup import DecayWithWarmupSchedule 2 | -------------------------------------------------------------------------------- /tabnet/schedules/decay_with_warmup.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.optimizers.schedules import LearningRateSchedule 3 | 4 | 5 | class DecayWithWarmupSchedule(LearningRateSchedule): 6 | def __init__( 7 | self, learning_rate, min_learning_rate, warmup, decay_rate, decay_steps 8 | ): 9 | super(DecayWithWarmupSchedule, self).__init__() 10 | self.learning_rate, self.min_learning_rate = learning_rate, min_learning_rate 11 | self.warmup = warmup 12 | self.decay_rate, self.decay_steps = decay_rate, decay_steps 13 | 14 | self.m = (learning_rate - min_learning_rate) / warmup 15 | self.b = learning_rate - self.m * warmup 16 | 17 | def __call__(self, step): 18 | return tf.cond( 19 | tf.greater(step, self.warmup), 20 | lambda: self.learning_rate 21 | * tf.pow(self.decay_rate, (step / self.decay_steps)), 22 | lambda: self.m * step + self.b, 23 | ) 24 | 25 | def get_config(self): 26 | return { 27 | "learning_rate": self.learning_rate, 28 | "min_learning_rate": self.min_learning_rate, 29 | "warmup": self.warmup, 30 | "decay_rate": self.decay_rate, 31 | "decay_steps": self.decay_steps, 32 | } 33 | -------------------------------------------------------------------------------- /tabnet/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | 7 | def set_seed(seed: int = 42): 8 | # reference: https://github.com/NVIDIA/framework-determinism 9 | os.environ["PYTHONHASHSEED"] = str(seed) 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | tf.random.set_seed(seed) 13 | -------------------------------------------------------------------------------- /tests/test_classify.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | import numpy as np 4 | import shutil 5 | 6 | from tabnet.models import TabNetClassifier 7 | 8 | 9 | CONFIGS = {"num_features": 20, "feature_dim": 32, "output_dim": 64, "n_classes": 10} 10 | 11 | 12 | OUTPUT_FOLDER = ".tmp" 13 | 14 | 15 | @pytest.fixture() 16 | def model(): 17 | net = TabNetClassifier(**CONFIGS) 18 | net.build((None, CONFIGS["num_features"])) 19 | return net 20 | 21 | 22 | @pytest.fixture() 23 | def features(): 24 | return tf.random.uniform((32, CONFIGS["num_features"])) * 2 25 | 26 | 27 | @pytest.fixture() 28 | def output_folder(): 29 | yield OUTPUT_FOLDER 30 | shutil.rmtree(OUTPUT_FOLDER) 31 | 32 | 33 | class TestClassify: 34 | def test_can_save_model(self, model, output_folder, features): 35 | # save to folder 36 | 
model.save_to_directory(output_folder) 37 | out = model(features, training=False, alpha=1) 38 | # load from folder 39 | model_loaded = TabNetClassifier.load_from_directory(output_folder) 40 | out_loaded = model_loaded(features, training=False, alpha=1) 41 | 42 | assert model.configs.keys() == model_loaded.configs.keys() 43 | for k, v in model_loaded.configs.items(): 44 | assert model.configs[k] == v 45 | 46 | assert np.allclose( 47 | model_loaded.head.weights[0].numpy(), model.head.weights[0].numpy() 48 | ) 49 | assert np.allclose(out, out_loaded) 50 | -------------------------------------------------------------------------------- /tests/test_custom_bn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | 4 | from tabnet.models.gbn import BatchNormInferenceWeighting 5 | 6 | 7 | class TestCustomBatchNorm(tf.test.TestCase): 8 | def setUp(self): 9 | self.x = tf.random.uniform(shape=(32, 54), dtype=tf.float32) 10 | self.zeros = tf.zeros(self.x.shape[1]) 11 | self.ones = tf.ones(self.x.shape[1]) 12 | 13 | def test_can_apply_bn_in_training(self): 14 | bn = BatchNormInferenceWeighting() 15 | x_bn = bn(self.x, training=True) 16 | 17 | mean = tf.reduce_mean(x_bn, axis=0) 18 | std = tf.sqrt( 19 | tf.reduce_mean(tf.pow(x_bn, 2), axis=0) 20 | - tf.pow(tf.reduce_mean(x_bn, axis=0), 2) 21 | ) 22 | 23 | self.assertAllClose(mean, self.zeros, rtol=1e-04, atol=1e-04) 24 | self.assertAllClose(std, self.ones, rtol=1e-04, atol=1e-04) 25 | 26 | def test_update_moving_stats_only_in_training(self): 27 | bn = BatchNormInferenceWeighting() 28 | _ = bn(self.x, training=False) 29 | 30 | self.assertAllClose(bn.moving_mean, self.zeros) 31 | self.assertAllClose(bn.moving_mean_of_squares, self.zeros) 32 | 33 | _ = bn(self.x, training=True) 34 | 35 | self.assertNotAllClose(bn.moving_mean, self.zeros) 36 | self.assertNotAllClose(bn.moving_mean_of_squares, self.zeros) 37 | 38 | def test_similar_to_keras(self): 39 | bn = BatchNormInferenceWeighting(momentum=0.9) 40 | bn_keras = tf.keras.layers.BatchNormalization( 41 | momentum=0.9, epsilon=tf.keras.backend.epsilon() 42 | ) 43 | 44 | x_bn = bn(self.x, training=True) 45 | x_bn_keras = bn_keras(self.x, training=True) 46 | 47 | self.assertAllClose(x_bn, x_bn_keras, rtol=1e-4, atol=1e-4) 48 | # TODO check moving mean & std 49 | 50 | 51 | if __name__ == "__main__": 52 | tf.test.main() 53 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tabnet.datasets.covertype import get_data, get_dataset 5 | 6 | 7 | COVTYPE_CSV_PATH = "data/test/covtype_sample.csv" 8 | SEED = 42 9 | 10 | 11 | class TestDataset(tf.test.TestCase): 12 | def test_gets_always_the_same_data(self): 13 | df_tr, df_val, df_test = get_data(COVTYPE_CSV_PATH, seed=SEED) 14 | df2_tr, df2_val, df2_test = get_data(COVTYPE_CSV_PATH, seed=SEED) 15 | 16 | self.assertAllClose( 17 | df_tr.values.astype(np.float32), df2_tr.values.astype(np.float32) 18 | ) 19 | self.assertAllClose( 20 | df_val.values.astype(np.float32), df2_val.values.astype(np.float32) 21 | ) 22 | self.assertAllClose( 23 | df_test.values.astype(np.float32), df2_test.values.astype(np.float32) 24 | ) 25 | 26 | def __get_labels(self, ds: tf.data.Dataset, n_iter: int): 27 | labels = [] 28 | ds_iter = iter(ds) 29 | for i in range(n_iter): 30 | _, label = next(ds_iter) 31 | 
labels.append(label) 32 | return tf.concat(labels, axis=0) 33 | 34 | def test_dataset_is_deterministic(self): 35 | df_tr, _, _ = get_data(COVTYPE_CSV_PATH, seed=SEED) 36 | 37 | ds_tr = get_dataset(df_tr, shuffle=True, batch_size=32, seed=SEED, take=2) 38 | labels1 = self.__get_labels(ds_tr, 20) 39 | 40 | ds_tr = get_dataset(df_tr, shuffle=True, batch_size=32, seed=SEED, take=2) 41 | labels2 = self.__get_labels(ds_tr, 20) 42 | 43 | self.assertAllClose(labels1, labels2) 44 | 45 | 46 | if __name__ == "__main__": 47 | tf.test.main() 48 | -------------------------------------------------------------------------------- /tests/test_lr_finder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from tabnet.callbacks.lrfinder import LRFinder 8 | 9 | 10 | BS = 16 11 | NUM_STEPS = 10 12 | FIGNAME = "test.png" 13 | 14 | 15 | @pytest.fixture() 16 | def dataset() -> tf.data.Dataset: 17 | # generate fake data 18 | size_of_dataset = BS * 10 19 | 20 | x = tf.random.uniform((size_of_dataset, 1), minval=-1, maxval=1) 21 | y = ( 22 | (2 * tf.random.uniform((size_of_dataset, 1), minval=0.9, maxval=1.1)) * x 23 | + 10 24 | + tf.random.uniform((size_of_dataset, 1), minval=-2.0, maxval=2.0) 25 | ) 26 | 27 | ds = tf.data.Dataset.from_tensor_slices((x, y)) 28 | ds = ds.shuffle(x.shape[0]).batch(BS, drop_remainder=False) 29 | return ds 30 | 31 | 32 | @pytest.fixture() 33 | def model() -> tf.keras.Model: 34 | return tf.keras.Sequential([tf.keras.layers.Dense(1)]) 35 | 36 | 37 | @pytest.fixture 38 | def clean_output(): 39 | yield 40 | os.remove(FIGNAME) 41 | 42 | 43 | @pytest.mark.usefixtures("clean_output") 44 | def test_can_run_the_lr_finder(model: tf.keras.Model, dataset: tf.data.Dataset): 45 | min_lr = 1e-6 46 | max_lr = 1e-1 47 | 48 | model.compile( 49 | optimizer=tf.keras.optimizers.SGD(), loss=tf.keras.losses.MeanSquaredError() 50 | ) 51 | 52 | lrfinder = LRFinder(min_lr, max_lr, num_steps=NUM_STEPS, figname=FIGNAME) 53 | 54 | model.fit(dataset, epochs=1, callbacks=[lrfinder]) 55 | 56 | assert len(lrfinder.losses) == NUM_STEPS 57 | assert len(lrfinder.lrs) == NUM_STEPS 58 | assert lrfinder.lrs[0] == min_lr 59 | assert lrfinder.lrs[-1] == max_lr 60 | 61 | # by default should have saved a figure with the results 62 | assert os.path.exists(lrfinder.figname) 63 | -------------------------------------------------------------------------------- /tests/test_tabnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import pytest 4 | import tensorflow as tf 5 | import numpy as np 6 | from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes 7 | 8 | from tabnet.models import TabNet 9 | from tabnet.models.transformers import FeatureBlock 10 | from tabnet.datasets.covertype import get_data, get_dataset 11 | 12 | 13 | COVTYPE_CSV_PATH = "data/test/covtype_sample.csv" 14 | FEATURE_DIM = 50 15 | TMP_DIR = ".tmp" 16 | 17 | 18 | @pytest.fixture() 19 | def features(): 20 | return tf.random.uniform([32, FEATURE_DIM], -1.0, 1.0) 21 | 22 | 23 | @pytest.fixture() 24 | def model(features): 25 | model = TabNet(features.shape[1], feature_dim=16, output_dim=16, n_step=2) 26 | model.build(features.shape) 27 | return model 28 | 29 | 30 | @pytest.fixture() 31 | def saved_model_path(model: tf.keras.Model): 32 | path = os.path.join(TMP_DIR, "saved_model") 33 | model.save_weights(path, overwrite=True) 34 | yield path 35 | 
shutil.rmtree(TMP_DIR) 36 | 37 | 38 | class TestTabNet: 39 | def test_feature_transformer_block(self, features): 40 | block = FeatureBlock(FEATURE_DIM, apply_glu=True, bn_virtual_divider=1) 41 | x = block(features, training=False) 42 | assert x.shape[1] == features.shape[1] 43 | 44 | def test_tabnet_model(self, model, features): 45 | y, _, _ = model(features, training=True) 46 | assert y.shape[0] == features.shape[0] 47 | assert y.shape[1] == 16 48 | 49 | def test_tabnet_with_alpha(self, model, features): 50 | # in training mode alpha should change nothing 51 | y_with_alpha, _, _ = model(features, training=True, alpha=0.5) 52 | y_no_alpha, _, _ = model(features, training=True) 53 | 54 | assert np.allclose(y_with_alpha, y_no_alpha) 55 | 56 | # in inference mode, alpha > 0.0 mixes the batch stats into the moving stats, so the outputs should differ 57 | 58 | y_with_alpha, _, _ = model(features, training=False, alpha=0.5) 59 | y_no_alpha, _, _ = model(features, training=False) 60 | 61 | assert not np.allclose(y_with_alpha, y_no_alpha) 62 | 63 | # @pytest.mark.skip(reason="Saving takes too much time.") 64 | def test_can_infer_with_saved_model( 65 | self, model: tf.keras.Model, features, saved_model_path 66 | ): 67 | model.load_weights(saved_model_path) 68 | out1, _, _ = model(features, training=False, alpha=0.5) 69 | # out2 uses only the bn moving stats, which are still zeros at that point 70 | out2, _, _ = model(features, training=False) 71 | assert not np.allclose(out1.numpy(), out2.numpy()) 72 | --------------------------------------------------------------------------------
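For orientation, the pieces above compose roughly as follows. This is a minimal training sketch, not one of the repository's scripts: the hyperparameters (batch size, schedule values, step counts) are illustrative only, and it points at the small CSV sample bundled under data/test/ rather than the full Forest Covertype dataset; see the scripts under benchmarks/ for the repository's own Covertype setup.

```python
import tensorflow as tf

from tabnet.models import TabNetClassifier
from tabnet.datasets.covertype import get_data, get_dataset
from tabnet.schedules import DecayWithWarmupSchedule
from tabnet.utils import set_seed

set_seed(42)

# Covertype: 54 numeric features, 7 classes (get_dataset shifts the labels to 0..6).
df_tr, df_val, _ = get_data("data/test/covtype_sample.csv")
ds_tr = get_dataset(df_tr, shuffle=True, batch_size=64)  # the dataset repeats indefinitely
ds_val = get_dataset(df_val, batch_size=64)

model = TabNetClassifier(
    num_features=54,
    feature_dim=16,
    output_dim=16,
    n_classes=7,
    n_step=4,
    relaxation_factor=1.5,
    bn_virtual_divider=2,  # virtual (ghost) batch size = 64 / 2 = 32
)

lr = DecayWithWarmupSchedule(
    learning_rate=0.02,
    min_learning_rate=1e-6,
    warmup=20,
    decay_rate=0.95,
    decay_steps=500,
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    # the classification head outputs logits when output_activation is None (the default)
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_tr,
    validation_data=ds_val,
    epochs=5,
    steps_per_epoch=100,
    validation_steps=20,
)
```

Because get_dataset() calls repeat(), the epoch length is set entirely by steps_per_epoch and validation_steps rather than by the size of the underlying DataFrames.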