├── .gitignore ├── requirements.txt ├── dofen_default_config.py ├── README.md ├── DOFEN_on_tabular_benchmark.ipynb ├── LICENSE └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | tabular-benchmark 2 | __pycache__ 3 | .ipynb_checkpoints/ 4 | openml_id_keyword_mapping.csv 5 | *.swp 6 | *venv 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | captum==0.5.0 2 | catboost==1.0.6 3 | category_encoders==2.4.0 4 | ConfigArgParse==1.5.3 5 | icecream==2.1.1 6 | einops==0.6.0 7 | lightgbm==3.2.1 8 | matplotlib==3.4.2 9 | numba==0.55.1 10 | numpy==1.21 11 | nvidia_ml_py3==7.352.0 12 | openml==0.12.2 13 | optuna==2.10.0 14 | plotly==5.10.0 15 | protobuf==3.20.* 16 | pynvml==11.4.1 17 | pytomlpp==1.0.10 18 | PyYAML==6.0 19 | requests==2.25.1 20 | scikit_learn==1.1.2 21 | scipy==1.6.2 22 | shap==0.39.0 23 | skorch==0.10.0 24 | torch==1.10.1 25 | tqdm==4.62.3 26 | wandb==0.13.5 27 | xgboost==1.5.2 28 | zero==0.9.1 29 | pandas==2.0.3 -------------------------------------------------------------------------------- /dofen_default_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.metrics import accuracy_score, r2_score 3 | 4 | dofen_config = { 5 | 'm': 16, 6 | 'd': 4, 7 | 'n_head': 1, 8 | 'n_forest': 100, 9 | 'n_hidden': 128, 10 | 'dropout': 0.0, 11 | 12 | 'categorical_optimized': False, 13 | 'fast_mode': 32, 14 | 'use_bagging_loss': False, 15 | 16 | 'device': torch.device('cuda:0'), 17 | 'verbose': True 18 | } 19 | 20 | train_config = { 21 | 'batch_size': 256, 22 | 'n_epochs': 500, 23 | 'early_stopping_patience': -1, 24 | 'save_dir': './', 25 | 'ckpt_name': 'best' 26 | } 27 | 28 | eval_config = { 29 | 'metric': { 30 | 'classification': accuracy_score, 31 | 'regression': r2_score 32 | } 33 | } 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :dolphin: DOFEN: Deep Oblivious Forest Ensemble :dolphin: 2 | This is the official implementation of DOFEN, a novel tree-inspired deep tabular neural network, accepted by NeurIPS 2024 ([openreview link](https://openreview.net/forum?id=umukvCdGI6)) 3 | 4 | ### Installation and building the benchmark dataset 5 | ``` 6 | ### we use python 3.8 ### 7 | pip install -r requirements.txt 8 | source build_tabular_benchmark_data.sh 9 | ``` 10 | 11 | ### How to use DOFEN 12 | The `DOFEN_on_tabular_benchmark.ipynb` notebook provides more detailed usage and settings of DOFEN on the tabular benchmark; here is a quick overview: 13 | ```python 14 | from model import DOFENTrainer 15 | from dofen_default_config import dofen_config, train_config, eval_config 16 | 17 | # prepare your training data 18 | tr_x = ... 19 | tr_y = ... 20 | 21 | # provide dataset-specific information 22 | dofen_config['column_category_count'] = ... # list of int, number of categories for each column, set the value to -1 for numerical columns 23 | dofen_config['n_class'] = ... # int, number of classes in the dataset: set to 2 for binary tasks, set to the number of classes for multiclass tasks, and set to -1 for regression tasks 24 | 25 | # initialize the model; for detailed usage and descriptions of these three configs, see the docstring of DOFENTrainer 26 | dofen_trainer = DOFENTrainer(dofen_config, train_config, eval_config) 27 | dofen_trainer.init() 28 | 29 | # fit DOFEN on training data 30 | dofen_trainer.fit(tr_x, tr_y) 31 | 32 | # prepare your testing data and evaluate 33 | te_x = ... 34 | te_y = ...
35 | dofen_trainer.evaluate(te_x, te_y) 36 | ``` -------------------------------------------------------------------------------- /DOFEN_on_tabular_benchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from model import DOFENTrainer\n", 10 | "from dofen_default_config import dofen_config, train_config, eval_config\n", 11 | "import pickle\n", 12 | "import torch" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "data_dir = './tabular-benchmark/tabular_benchmark_data/'\n", 22 | "data_name = 361111\n", 23 | "data_seed = 0\n", 24 | "\n", 25 | "target_transform = False\n", 26 | "\n", 27 | "with open(f'{data_dir}/{data_name}/{data_seed}.pkl', 'rb') as f:\n", 28 | " data_dict = pickle.load(f)\n", 29 | " \n", 30 | "tr_x = data_dict['x_train']\n", 31 | "va_x = data_dict['x_val']\n", 32 | "te_x = data_dict['x_test']\n", 33 | "\n", 34 | "tr_y = data_dict['y_train' if not target_transform else 'y_train_transform']\n", 35 | "va_y = data_dict['y_val' if not target_transform else 'y_val_transform']\n", 36 | "te_y = data_dict['y_test' if not target_transform else 'y_test_transform']\n", 37 | "\n", 38 | "col_cat_count = data_dict['col_cat_count']\n", 39 | "label_cat_count = data_dict['label_cat_count']" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "dofen_config['column_category_count'] = col_cat_count\n", 49 | "dofen_config['n_class'] = label_cat_count\n", 50 | "dofen_trainer = DOFENTrainer(dofen_config, train_config, eval_config)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "dofen_trainer.init()" 60 | ] 61 | }, 62 | { 63 
| "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "dofen_trainer.fit(tr_x, tr_y, te_x=te_x, te_y=te_y)" 69 | ] 70 | } 71 | ], 72 | "metadata": { 73 | "kernelspec": { 74 | "display_name": "dofen_venv", 75 | "language": "python", 76 | "name": "dofen_venv" 77 | }, 78 | "language_info": { 79 | "codemirror_mode": { 80 | "name": "ipython", 81 | "version": 3 82 | }, 83 | "file_extension": ".py", 84 | "mimetype": "text/x-python", 85 | "name": "python", 86 | "nbconvert_exporter": "python", 87 | "pygments_lexer": "ipython3", 88 | "version": "3.8.10" 89 | } 90 | }, 91 | "nbformat": 4, 92 | "nbformat_minor": 4 93 | } 94 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Sinopac Financial Holdings Co., Ltd. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import time 5 | 6 | import datetime 7 | import pickle 8 | from torch.utils.data import DataLoader, Dataset, TensorDataset 9 | 10 | 11 | class Reshape(nn.Module): 12 | def __init__(self, *args): 13 | super(Reshape, self).__init__() 14 | self.shape = args 15 | 16 | def forward(self, x): 17 | return x.reshape(self.shape) 18 | 19 | 20 | class FastGroupConv1d(nn.Conv1d): 21 | ''' 22 | This class is built to resolve the issue: the slow operation speed of group convolution when number of groups are large. 23 | We found that directly using self-written matrix multiplication can dramatically accelerate operation speed of group convolution. 
24 | The `fast_mode` argument decides when to switch from native operation in `nn.Conv1d` to self-written one by setting the group threshold. 25 | 26 | See following pytorch issues for more detailed description of this bug: 27 | * https://github.com/pytorch/pytorch/issues/18631 28 | * https://github.com/pytorch/pytorch/issues/70954 29 | * https://github.com/pytorch/pytorch/issues/73764 30 | ''' 31 | def __init__(self, *args, **kwargs): 32 | self.fast_mode = kwargs.pop('fast_mode') 33 | nn.Conv1d.__init__(self, *args, **kwargs) 34 | if self.groups > self.fast_mode: 35 | self.weight = nn.Parameter( 36 | self.weight.reshape( 37 | self.groups, self.out_channels // self.groups, self.in_channels // self.groups, 1 38 | ).permute(3, 0, 2, 1) 39 | ) 40 | self.bias = nn.Parameter( 41 | self.bias.unsqueeze(0).unsqueeze(-1) 42 | ) 43 | 44 | def forward(self, x): 45 | if self.groups > self.fast_mode: 46 | x = x.reshape(-1, self.groups, self.in_channels // self.groups, 1) 47 | return (x * self.weight).sum(2, keepdims=True).permute(0, 1, 3, 2).reshape(-1, self.out_channels, 1) + self.bias 48 | else: 49 | return self._conv_forward(x, self.weight, self.bias) 50 | 51 | 52 | class ConditionGeneration(nn.Module): 53 | def __init__(self, column_category_count, n_cond=128, categorical_optimized=False, fast_mode=64, device=torch.device('cpu')): 54 | super(ConditionGeneration, self).__init__() 55 | self.device = device 56 | self.fast_mode = fast_mode 57 | self.categorical_optimized = categorical_optimized 58 | 59 | self.num_index, self.cat_index, self.cat_count, self.cat_offset = self.get_num_cat_index(column_category_count) 60 | self.n_cond = n_cond 61 | self.phi_1 = self.get_phi_1() 62 | 63 | def get_num_cat_index(self, column_category_count): 64 | num_index = [] 65 | cat_index = [] 66 | cat_count = [] 67 | for idx, ele in enumerate(column_category_count): 68 | if ele == -1: 69 | num_index.append(idx) 70 | else: 71 | cat_index.append(idx) 72 | cat_count.append(ele) 73 | cat_offset = 
torch.tensor([0] + np.cumsum(cat_count).tolist()[:-1]).long().to(self.device) 74 | return num_index, cat_index, cat_count, cat_offset 75 | 76 | def get_phi_1(self, ): 77 | phi_1 = nn.ModuleDict() 78 | if len(self.num_index): 79 | phi_1['num'] = nn.Sequential( 80 | # input = (b, n_num_col) 81 | # output = (b, n_num_col, n_cond) 82 | Reshape(-1, len(self.num_index), 1), 83 | FastGroupConv1d(len(self.num_index), len(self.num_index)*self.n_cond, kernel_size=1, groups=len(self.num_index), fast_mode=self.fast_mode), 84 | nn.Sigmoid(), 85 | Reshape(-1, len(self.num_index), self.n_cond) 86 | ) 87 | if len(self.cat_index): 88 | phi_1['cat'] = nn.ModuleDict() 89 | phi_1['cat']['embedder'] = nn.Embedding(sum(self.cat_count), self.n_cond) 90 | phi_1['cat']['mapper'] = nn.Sequential( 91 | # input = (b, n_cat_col, n_cond) 92 | # output = (b, n_cat_col, n_cond) 93 | Reshape(-1, len(self.cat_index) * self.n_cond, 1), 94 | nn.GroupNorm(len(self.cat_index), len(self.cat_index) * self.n_cond), 95 | FastGroupConv1d(len(self.cat_index)*self.n_cond, len(self.cat_index)*self.n_cond, kernel_size=1, groups=len(self.cat_index)*self.n_cond if self.categorical_optimized else len(self.cat_index), fast_mode=self.fast_mode), 96 | nn.Sigmoid(), 97 | Reshape(-1, len(self.cat_index), self.n_cond) 98 | ) 99 | return phi_1 100 | 101 | def forward(self, x): 102 | M = [] 103 | 104 | if len(self.num_index): 105 | num_x = x[:, self.num_index].float() 106 | num_sample_emb = self.phi_1['num'](num_x) 107 | M.append(num_sample_emb) 108 | 109 | if len(self.cat_index): 110 | cat_x = x[:, self.cat_index].long() + self.cat_offset 111 | cat_sample_emb = self.phi_1['cat']['mapper'](self.phi_1['cat']['embedder'](cat_x)) 112 | M.append(cat_sample_emb) 113 | 114 | M = torch.cat(M, dim=1) # (b, n_col, n_cond) 115 | M = M.permute(0, 2, 1) # (b, n_cond, n_col) 116 | return M 117 | 118 | 119 | class rODTConstruction(nn.Module): 120 | def __init__(self, n_cond, n_col): 121 | super().__init__() 122 | self.permutator = 
torch.rand(n_cond * n_col).argsort(-1) 123 | 124 | def forward(self, M): 125 | return M.reshape(M.shape[0], -1, 1)[:, self.permutator, :] 126 | 127 | 128 | class rODTForestConstruction(nn.Module): 129 | def __init__(self, n_col, n_rodt, n_cond, n_estimator, n_head=1, n_hidden=128, n_forest=100, dropout=0.0, fast_mode=64, device=torch.device('cpu')): 130 | super().__init__() 131 | 132 | self.device = device 133 | self.n_estimator = n_estimator 134 | self.n_forest = n_forest 135 | self.n_rodt = n_rodt 136 | self.n_head = n_head 137 | self.n_hidden = n_hidden 138 | 139 | self.phi_2 = nn.Sequential( 140 | nn.GroupNorm(n_rodt, n_cond * n_col), 141 | nn.Dropout(dropout), 142 | FastGroupConv1d(n_cond * n_col, n_cond * n_col, groups=n_rodt, kernel_size=1, fast_mode=fast_mode), 143 | nn.ReLU(), 144 | nn.GroupNorm(n_rodt, n_cond * n_col), 145 | nn.Dropout(dropout), 146 | FastGroupConv1d(n_cond * n_col, n_rodt * n_head, groups=n_rodt, kernel_size=1, fast_mode=fast_mode), 147 | Reshape(-1, n_rodt, n_head) 148 | ) 149 | self.E = nn.Embedding(n_rodt, n_hidden) 150 | self.sample_without_replacement_eval = self.get_sample_without_replacement() 151 | 152 | def get_sample_without_replacement(self, ): 153 | return torch.rand(self.n_forest, self.n_rodt, device=self.device).argsort(-1)[:, :self.n_estimator] 154 | 155 | def forward(self, O): 156 | b = O.shape[0] 157 | w = self.phi_2(O) # (b, n_rodt, n_head) 158 | E = self.E.weight.unsqueeze(0) # (1, n_rodt, n_hidden) 159 | 160 | sample_without_replacement = self.get_sample_without_replacement() if self.training else self.sample_without_replacement_eval 161 | 162 | w_prime = w[:, sample_without_replacement].softmax(-2).unsqueeze(-1) # (b, n_forest, n_rodt, n_head, 1) 163 | E_prime = E[:, sample_without_replacement].reshape( 164 | 1, self.n_forest, self.n_estimator, self.n_head, self.n_hidden // self.n_head 165 | ) # (1, n_forest, n_rodt, n_head, n_hidden // n_head) 166 | F = (w_prime * E_prime).sum(-3).reshape( 167 | b, self.n_forest, 
self.n_hidden 168 | ) # (b, n_forest, n_hidden) 169 | return F 170 | 171 | 172 | class rODTForestBagging(nn.Module): 173 | def __init__(self, n_hidden, dropout, n_class): 174 | super().__init__() 175 | self.phi_3 = nn.Sequential( 176 | nn.LayerNorm(n_hidden), 177 | nn.Dropout(dropout), 178 | nn.Linear(n_hidden, n_hidden), 179 | nn.ReLU(), 180 | nn.LayerNorm(n_hidden), 181 | nn.Dropout(dropout), 182 | nn.Linear(n_hidden, n_class) 183 | ) 184 | 185 | def forward(self, F): 186 | y_hat = self.phi_3(F) # (b, n_forest, n_class) 187 | return y_hat 188 | 189 | 190 | class DOFEN(nn.Module): 191 | def __init__( 192 | self, 193 | column_category_count, 194 | n_class, 195 | 196 | m=16, 197 | d=4, 198 | n_head=1, 199 | n_forest=100, 200 | n_hidden=128, 201 | dropout=0.0, 202 | 203 | ### experimental functionality ### 204 | categorical_optimized=False, 205 | fast_mode=2048, 206 | use_bagging_loss=False, 207 | ### ### 208 | 209 | device=torch.device('cpu'), 210 | verbose=False 211 | ): 212 | super().__init__() 213 | 214 | self.device = device 215 | self.n_class = 1 if n_class == -1 else n_class 216 | self.is_rgr = True if n_class == -1 else False 217 | 218 | self.m = m 219 | self.d = d 220 | self.n_head = n_head 221 | self.n_forest = n_forest 222 | self.n_hidden = n_hidden 223 | self.dropout = dropout 224 | self.use_bagging_loss = use_bagging_loss 225 | 226 | self.n_cond = self.d * self.m 227 | self.n_col = len(column_category_count) 228 | self.n_rodt = self.n_cond * self.n_col // self.d 229 | self.n_estimator = max(2, int(self.n_col ** 0.5)) * self.n_cond // self.d 230 | 231 | self.condition_generation = ConditionGeneration( 232 | column_category_count, 233 | n_cond=self.n_cond, 234 | categorical_optimized=categorical_optimized, 235 | fast_mode=fast_mode, 236 | device=self.device 237 | ) 238 | self.rodt_construction = rODTConstruction( 239 | self.n_cond, 240 | self.n_col 241 | ) 242 | self.rodt_forest_construction = rODTForestConstruction( 243 | self.n_col, 244 | self.n_rodt, 
245 | self.n_cond, 246 | self.n_estimator, 247 | n_head=self.n_head, 248 | n_hidden=self.n_hidden, 249 | n_forest=self.n_forest, 250 | dropout=self.dropout, 251 | fast_mode=fast_mode, 252 | device=self.device 253 | ) 254 | self.rodt_forest_bagging = rODTForestBagging( 255 | self.n_hidden, 256 | self.dropout, 257 | self.n_class 258 | ) 259 | 260 | if verbose: 261 | print('='*20) 262 | print('total condition: ', self.n_cond * self.n_col) 263 | print('n_rodt: ', self.n_rodt) 264 | print('n_estimator: ', self.n_estimator) 265 | print('='*20) 266 | 267 | def calc_loss(self, y_hat, y): 268 | if self.is_rgr: 269 | loss = torch.nn.functional.mse_loss(y_hat.squeeze(-1), y.float()) 270 | else: 271 | loss = torch.nn.functional.cross_entropy(y_hat, y.long()) 272 | return loss 273 | 274 | def timer(self, x): 275 | self.eval() 276 | x = x.to(self.device) 277 | 278 | times = [] 279 | 280 | times.append(time.perf_counter()) 281 | M = self.condition_generation(x) # (b, n_cond, n_col) 282 | times.append(time.perf_counter()) 283 | O = self.rodt_construction(M) # (b, n_rodt, d) 284 | times.append(time.perf_counter()) 285 | F = self.rodt_forest_construction(O) # (b, n_forest, n_hidden) 286 | times.append(time.perf_counter()) 287 | y_hat = self.rodt_forest_bagging(F) # (b, n_forest, n_class) 288 | times.append(time.perf_counter()) 289 | y_hat_final = y_hat.detach().mean(1) # (b, n_class) 290 | times.append(time.perf_counter()) 291 | 292 | self.train() 293 | 294 | times = np.array(times) 295 | return times[1:] - times[:-1] 296 | 297 | def forward(self, x, y=None): 298 | x = x.to(self.device) 299 | 300 | M = self.condition_generation(x) # (b, n_cond, n_col) 301 | O = self.rodt_construction(M) # (b, n_rodt, d) 302 | F = self.rodt_forest_construction(O) # (b, n_forest, n_hidden) 303 | y_hat = self.rodt_forest_bagging(F) # (b, n_forest, n_class) 304 | y_hat_final = y_hat.detach().mean(1) # (b, n_class) 305 | 306 | if y is not None: 307 | y = y.to(self.device) 308 | loss = self.calc_loss( 309 
| y_hat.permute(0, 2, 1) if not self.is_rgr else y_hat, 310 | y.unsqueeze(-1).expand(-1, self.n_forest) 311 | ) 312 | if self.n_forest > 1 and self.training and self.use_bagging_loss: 313 | loss += self.calc_loss(y_hat.mean(1), y) 314 | else: 315 | loss = torch.tensor(0.0) 316 | 317 | return y_hat_final, loss 318 | 319 | 320 | class DOFENTrainer(): 321 | """This is the training and inference interface of DOFEN 322 | 323 | Args: 324 | dofen_config (dict): 325 | ### standard usage ### 326 | - column_category_count (list of int): number of categories for each column, set the value to -1 for numerical columns 327 | - n_class (int): number of class of a dataset, please set to 2 for binary tasks, set to 'number of class' for multiclass tasks, and set to -1 for regression tasks 328 | - m (int): an intermediate parameter ensures that number of rODT is an integer, larger m result in more rODTs, search space = [16, 32, 64] 329 | - d (int): depth of a rODT, search space = [3, 4, 6, 8] 330 | - n_forest (int): number of rODT forest generated for forest ensemble 331 | - n_hidden (int): hidden dimension of rODT embedding 332 | - dropout (float): dropout rate, search space = [0.0, 0.1, 0.2] 333 | - device (torch.device): torch device, Sets the device for tensors initialized in the forward function, accelerating computations by placing them directly on the desired device (e.g. GPU). Note that this setting only affects tensors created in forward, and should match the device used by the DOFEN model. 334 | - verbose (bool): whether to print model configuration when initialize 335 | 336 | ### advanced usage, These functionalities strengthen DOFEN's performance and efficiency, but are not implemented in the paper of NeurIPS 2024 version ### 337 | - n_head (int): A multi-head extension of "Two-level Relaxed Forest Ensemble", increase number of heads (e.g. 4) greatly improves performance on larger datasets (e.g. 
n_samples > 10000), default = 1, search space = [1, 4, 8] 338 | - categorical_optimized (bool): A simpler encoding layer for categorical columns, when set to True, model uses less parameters but improves performance, default = False. 339 | - fast_mode (int): A faster version of group convolution when having large number of groups (i.e. number of rODTs in DOFEN), will start to use the faster version if number of groups is larger than the set value, default = 32. 340 | - use_bagging_loss (bool): DOFEN default calculate loss individually for each tree, when set ot True, an additional loss is calculated on the ensemble prediction, default = False. 341 | 342 | train_config (dict): 343 | - batch_size (int): Number of batch size. 344 | - n_epochs (int): Number of training epochs. 345 | - early_stopping_patience (int): Training is early stopped if model is not improving on validation set for this many of epochs, DOFEN originally does not use early stopping, default = -1, set this number larger than 0 if you want to early stop. 346 | - save_dir (str): Model save path if early stopping is used. 347 | - ckpt_name (str): Model checkpoint name, model will be saved when validation performance improve. 
348 | 349 | eval_config (dict): 350 | - metric (dict): dictionary containing evaluation metrics 351 | - classification (function): Evaluation metric for classification task, default = sklearn.metric.accuracy_score 352 | - regression (function): Evaluation metric for classification task, default = sklearn.metric.r2_score 353 | 354 | Returns: 355 | None 356 | """ 357 | 358 | def __init__(self, dofen_config, train_config, eval_config): 359 | self.dofen_config = dofen_config 360 | self.batch_size = train_config['batch_size'] 361 | self.n_epochs = train_config['n_epochs'] 362 | self.early_stopping_patience = train_config['early_stopping_patience'] 363 | self.save_dir = train_config['save_dir'] 364 | self.ckpt_name = train_config['ckpt_name'] 365 | self.eval_metric = eval_config['metric'] 366 | 367 | def set_seed(self, torch_seed=0, deterministic=True): 368 | torch.manual_seed(torch_seed) 369 | torch.cuda.manual_seed_all(torch_seed) 370 | torch.cuda.manual_seed(torch_seed) 371 | torch.backends.cudnn.deterministic = deterministic 372 | torch.backends.cudnn.benchmark = not deterministic 373 | 374 | def init(self, ): 375 | self.set_seed(0) 376 | self.model = DOFEN(**self.dofen_config).to(self.dofen_config['device']) 377 | self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=0.0) 378 | 379 | def fit(self, tr_x, tr_y, va_x=None, va_y=None, te_x=None, te_y=None): 380 | self.tr_dataloader = DataLoader( 381 | TensorDataset(torch.tensor(tr_x), torch.tensor(tr_y)), 382 | batch_size=self.batch_size, shuffle=True, drop_last=False 383 | ) 384 | 385 | if va_x is not None and va_y is not None: 386 | self.va_dataloader = DataLoader( 387 | TensorDataset(torch.tensor(va_x), torch.tensor(va_y)), 388 | batch_size=self.batch_size, shuffle=False, drop_last=False 389 | ) 390 | best_perf = -np.inf 391 | best_epoch = -1 392 | no_improve = 0 393 | else: 394 | self.va_dataloader = None 395 | 396 | if te_x is not None and te_y is not None: 397 | self.te_dataloader = 
DataLoader( 398 | TensorDataset(torch.tensor(te_x), torch.tensor(te_y)), 399 | batch_size=self.batch_size, shuffle=False, drop_last=False 400 | ) 401 | else: 402 | self.te_dataloader = None 403 | 404 | for epoch in range(self.n_epochs): 405 | self.model.train() 406 | print(f'Epoch: {epoch+1} | {datetime.datetime.now()}') 407 | 408 | for x, y in self.tr_dataloader: 409 | _, loss = self.model(x, y) 410 | loss.backward() 411 | self.optimizer.step() 412 | self.optimizer.zero_grad() 413 | 414 | if self.va_dataloader is not None and self.early_stopping_patience > 0: 415 | perf_name, curr_perf = self.evaluate_with_dataloader(self.va_dataloader) 416 | if curr_perf > best_perf: 417 | best_perf = curr_perf 418 | best_epoch = epoch + 1 419 | no_improve = 0 420 | torch.save(self.model, f'{self.save_dir}/{self.ckpt_name}.ckpt') 421 | else: 422 | no_improve += 1 423 | print(f'Performance not improve for {no_improve} epochs, best epoch is {best_epoch}') 424 | 425 | if no_improve >= self.early_stopping_patience: 426 | break 427 | 428 | if self.te_dataloader is not None: 429 | perf_name, curr_perf = self.evaluate_with_dataloader(self.te_dataloader) 430 | print(f'testing {perf_name}: {curr_perf}') 431 | 432 | def evaluate_with_dataloader(self, dataloader): 433 | self.model.eval() 434 | preds = [] 435 | ys = [] 436 | with torch.no_grad(): 437 | for x, y in dataloader: 438 | pred, loss = self.model(x, y) 439 | preds.append(pred.cpu()) 440 | ys.append(y.cpu()) 441 | preds = torch.cat(preds).numpy() 442 | ys = torch.cat(ys).numpy() 443 | if self.model.is_rgr: 444 | curr_perf = self.eval_metric['regression'](ys, preds.squeeze()) 445 | perf_name = self.eval_metric['regression'].__name__ 446 | else: 447 | curr_perf = self.eval_metric['classification'](ys, preds.argmax(-1)) 448 | perf_name = self.eval_metric['classification'].__name__ 449 | return perf_name, curr_perf 450 | 451 | def predict(self, x, eval_batch_size=256): 452 | x = torch.tensor(x) 453 | self.model.eval() 454 | preds = [] 
455 | with torch.no_grad(): 456 | for chunk_x in x.chunk(max(1, x.shape[0] // eval_batch_size)): 457 | preds.append(self.model(chunk_x)[0].cpu()) 458 | preds = torch.cat(preds).numpy() 459 | return preds 460 | 461 | def evaluate(self, x, y, eval_batch_size=256): 462 | pred = self.predict(x, eval_batch_size=eval_batch_size) 463 | if self.model.is_rgr: 464 | curr_perf = self.eval_metric['regression'](y, pred.squeeze()) 465 | perf_name = self.eval_metric['regression'].__name__ 466 | else: 467 | curr_perf = self.eval_metric['classification'](y, pred.argmax(-1)) 468 | perf_name = self.eval_metric['classification'].__name__ 469 | print(f'{perf_name}: {curr_perf}') 470 | --------------------------------------------------------------------------------