├── .gitattributes ├── LICENSE ├── README.md ├── args.py ├── data └── credit_card │ ├── README.md │ └── transactions.tgz ├── dataset ├── __init__.py ├── card.py ├── datacollator.py ├── prsa.py └── vocab.py ├── main.py ├── misc ├── __init__.py ├── cc_trans_dataset.png └── utils.py ├── models ├── __init__.py ├── custom_criterion.py ├── hierarchical.py ├── modules.py ├── tabformer_bert.py ├── tabformer_gpt2.py └── tabformer_tokenizer.py └── setup.yml /.gitattributes: -------------------------------------------------------------------------------- 1 | *.tgz filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tabular Transformers for Modeling Multivariate Time Series 2 | 3 | This repository provides the pytorch source code, and data for tabular transformers (TabFormer). Details are described in the paper [Tabular Transformers for Modeling Multivariate Time Series](http://arxiv.org/abs/2011.01843 ), to be presented at ICASSP 2021. 
4 | 5 | #### Summary 6 | * Modules for hierarchical transformers for tabular data 7 | * A synthetic credit card transaction dataset 8 | * Modified Adaptive Softmax for handling masking 9 | * Modified _DataCollatorForLanguageModeling_ for tabular data 10 | * The modules are built on top of the HuggingFace 🤗 Transformers library. (HuggingFace is ❤️) 11 | --- 12 | ### Requirements 13 | * Python (3.7) 14 | * PyTorch (1.6.0) 15 | * HuggingFace Transformers (3.2.0) 16 | * scikit-learn (0.23.2) 17 | * Pandas (1.1.2) 18 | 19 | (X) denotes the version the code was tested with. 20 | 21 | These can be installed from the provided YAML file by running: 22 | ``` 23 | conda env create -f setup.yml 24 | ``` 25 | --- 26 | 27 | ### Credit Card Transaction Dataset 28 | 29 | The synthetic credit card transaction dataset is provided in [./data/credit_card](/data/credit_card/). There are 24M records with 12 fields. 30 | You will need git-lfs to access the data. If you are facing issues related to LFS bandwidth, you can use this [direct link](https://ibm.box.com/v/tabformer-data) to access the data. You can then skip git-lfs files by prefixing `GIT_LFS_SKIP_SMUDGE=1` to the `git clone ..` command. 31 | 32 | ![figure](./misc/cc_trans_dataset.png) 33 | 34 | --- 35 | 36 | ### PRSA Dataset 37 | For the PRSA dataset, one has to download the data from [Kaggle](https://www.kaggle.com/sid321axn/beijing-multisite-airquality-data-set) and place the CSV files in the [./data/prsa](/data/prsa/) directory (the default `data_root` of `PRSADataset`). 38 | 39 | --- 40 | 41 | ### Tabular BERT 42 | To train a tabular BERT model on the credit card transaction or PRSA dataset, run: 43 | ``` 44 | $ python main.py --do_train --mlm --field_ce --lm_type bert \ 45 | --field_hs 64 --data_type [prsa/card] \ 46 | --output_dir [output_dir] 47 | ``` 48 | 49 | 50 | ### Tabular GPT2 51 | To train a tabular GPT2 model on credit card transactions for a particular _user-id_, run: 52 | ``` 53 | 54 | $ python main.py --do_train --lm_type gpt2 --field_ce --flatten --data_type card \ 55 | --data_root [path_to_data] --user_ids [user-id] \ 56 | --output_dir [output_dir] 57 | 58 | ``` 59 | 60 | Description of some options (more can be found in _`args.py`_): 61 | * `--data_type` choices are `prsa` and `card` for the Beijing PM2.5 dataset and the credit-card transaction dataset, respectively. 62 | * `--mlm` enables the masked-language-model loss; required for BERT 63 | * `--field_hs` hidden size for the field-level transformer 64 | * `--lm_type` choices are `bert` and `gpt2` 65 | * `--user_ids` option to select transactions only for particular user ids.
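
The dataset and vocabulary objects that `main.py` builds can also be constructed directly in Python. Below is a minimal sketch, not an additional entry point; paths and the CSV filename follow the defaults in `args.py` and are assumptions about your local extraction of `transactions.tgz`:

```
# Minimal sketch: build the credit-card dataset + vocab the same way main.py does.
# Paths/filenames follow the defaults in args.py; adjust to your local setup.
from dataset.card import TransactionDataset

dataset = TransactionDataset(root="./data/credit_card/",
                             fname="card_transaction.v1",   # CSV name assumed per args.py default
                             vocab_dir="checkpoints",        # where vocab.nb is written
                             mlm=True,                       # BERT-style masked LM; False for GPT2
                             flatten=False,                  # hierarchical (row-wise) inputs
                             return_labels=False)

print(len(dataset), "windows,", dataset.ncols, "columns,", len(dataset.vocab), "vocab tokens")
```

On the first run this pre-processes the raw CSV into `<root>/preprocessed/` and writes `vocab.nb` into `vocab_dir`, the same side effects the training script relies on.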
66 | --- 67 | 68 | ### Citation 69 | 70 | ``` 71 | @inproceedings{padhi2021tabular, 72 | title={Tabular transformers for modeling multivariate time series}, 73 | author={Padhi, Inkit and Schiff, Yair and Melnyk, Igor and Rigotti, Mattia and Mroueh, Youssef and Dognin, Pierre and Ross, Jerret and Nair, Ravi and Altman, Erik}, 74 | booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 75 | pages={3565--3569}, 76 | year={2021}, 77 | organization={IEEE}, 78 | url={https://ieeexplore.ieee.org/document/9414142} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def define_main_parser(parser=None): 5 | if parser is None: 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument("--jid", type=int, 9 | default=1, 10 | help="job id: 1[default] used for job queue") 11 | parser.add_argument("--seed", type=int, 12 | default=9, 13 | help="seed to use: 9[default]") 14 | 15 | parser.add_argument("--lm_type", default='bert', choices=['bert', 'gpt2'], 16 | help="gpt or bert choice.") 17 | parser.add_argument("--flatten", action='store_true', 18 | help="enable flattened input, no hierarchical") 19 | parser.add_argument("--field_ce", action='store_true', 20 | help="enable field wise CE") 21 | parser.add_argument("--mlm", action='store_true', 22 | help="masked lm loss; pass it for BERT") 23 | parser.add_argument("--mlm_prob", type=float, 24 | default=0.15, 25 | help="mask mlm_probability") 26 | 27 | parser.add_argument("--data_type", type=str, 28 | default="card", choices=['card', 'prsa'], 29 | help='root directory for files') 30 | parser.add_argument("--data_root", type=str, 31 | default="./data/credit_card/", 32 | help='root directory for files') 33 | parser.add_argument("--data_fname", type=str, 34 | default="card_transaction.v1", 35 | help='file name of transaction') 36 | parser.add_argument("--data_extension", type=str, 37 | default="", 38 | help="file name extension to add to cache") 39 | parser.add_argument("--vocab_file", type=str, 40 | default='vocab.nb', 41 | help="cached vocab file") 42 | parser.add_argument('--user_ids', nargs='+', 43 | default=None, 44 | help='pass list of user ids to filter data by') 45 | parser.add_argument("--cached", action='store_true', 46 | help='use cached data files') 47 | parser.add_argument("--nrows", type=int, 48 | default=None, 49 | help="no of transactions to use") 50 | 51 | parser.add_argument("--output_dir", type=str, 52 | default='checkpoints', 53 | help="path to model dump") 54 | parser.add_argument("--checkpoint", type=int, 55 | default=0, 56 | help='set to continue training from checkpoint') 57 | parser.add_argument("--do_train", action='store_true', 58 | help="enable training flag") 59 | parser.add_argument("--do_eval", action='store_true', 60 | help="enable evaluation flag") 61 | parser.add_argument("--save_steps", type=int, 62 | default=500, 63 | help="set checkpointing") 64 | parser.add_argument("--num_train_epochs", type=int, 65 | default=3, 66 | help="number of training epochs") 67 | parser.add_argument("--stride", type=int, 68 | default=5, 69 | help="stride for transaction sliding window") 70 | 71 | parser.add_argument("--field_hs", type=int, 72 | default=768, 73 | help="hidden size for transaction transformer") 74 | parser.add_argument("--skip_user", action='store_true', 75 | help="if user field to be skipped or added 
(default add)") 76 | 77 | return parser 78 | -------------------------------------------------------------------------------- /data/credit_card/README.md: -------------------------------------------------------------------------------- 1 | This **transactions.tgz** file is also available from Box: https://ibm.box.com/v/tabformer-data 2 | -------------------------------------------------------------------------------- /data/credit_card/transactions.tgz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e9f589a0958f40d60f81b1a2e8428db86e00c05755caf44fb055827976c0efa2 3 | size 278576638 4 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .datacollator import TransDataCollatorForLanguageModeling 2 | -------------------------------------------------------------------------------- /dataset/card.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import pandas as pd 4 | import numpy as np 5 | import math 6 | import tqdm 7 | import pickle 8 | import logging 9 | 10 | from sklearn.preprocessing import LabelEncoder 11 | from sklearn.preprocessing import MinMaxScaler 12 | 13 | import torch 14 | from torch.utils.data.dataset import Dataset 15 | 16 | from misc.utils import divide_chunks 17 | from dataset.vocab import Vocabulary 18 | 19 | logger = logging.getLogger(__name__) 20 | log = logger 21 | 22 | 23 | class TransactionDataset(Dataset): 24 | def __init__(self, 25 | mlm, 26 | user_ids=None, 27 | seq_len=10, 28 | num_bins=10, 29 | cached=True, 30 | root="./data/card/", 31 | fname="card_trans", 32 | vocab_dir="checkpoints", 33 | fextension="", 34 | nrows=None, 35 | flatten=False, 36 | stride=5, 37 | adap_thres=10 ** 8, 38 | return_labels=False, 39 | skip_user=False): 40 | 41 | self.root = root 42 | self.fname = fname 43 | self.nrows = nrows 44 | self.fextension = f'_{fextension}' if fextension else '' 45 | self.cached = cached 46 | self.user_ids = user_ids 47 | self.return_labels = return_labels 48 | self.skip_user = skip_user 49 | 50 | self.mlm = mlm 51 | self.trans_stride = stride 52 | 53 | self.flatten = flatten 54 | 55 | self.vocab = Vocabulary(adap_thres) 56 | self.seq_len = seq_len 57 | self.encoder_fit = {} 58 | 59 | self.trans_table = None 60 | self.data = [] 61 | self.labels = [] 62 | self.window_label = [] 63 | 64 | self.ncols = None 65 | self.num_bins = num_bins 66 | self.encode_data() 67 | self.init_vocab() 68 | self.prepare_samples() 69 | self.save_vocab(vocab_dir) 70 | 71 | def __getitem__(self, index): 72 | if self.flatten: 73 | return_data = torch.tensor(self.data[index], dtype=torch.long) 74 | else: 75 | return_data = torch.tensor(self.data[index], dtype=torch.long).reshape(self.seq_len, -1) 76 | 77 | if self.return_labels: 78 | return_data = (return_data, torch.tensor(self.labels[index], dtype=torch.long)) 79 | 80 | return return_data 81 | 82 | def __len__(self): 83 | return len(self.data) 84 | 85 | def save_vocab(self, vocab_dir): 86 | file_name = path.join(vocab_dir, f'vocab{self.fextension}.nb') 87 | log.info(f"saving vocab at {file_name}") 88 | self.vocab.save_vocab(file_name) 89 | 90 | @staticmethod 91 | def label_fit_transform(column, enc_type="label"): 92 | if enc_type == "label": 93 | mfit = LabelEncoder() 94 | else: 95 | mfit = MinMaxScaler() 96 | mfit.fit(column) 97 | 
98 | return mfit, mfit.transform(column) 99 | 100 | @staticmethod 101 | def timeEncoder(X): 102 | X_hm = X['Time'].str.split(':', expand=True) 103 | d = pd.to_datetime(dict(year=X['Year'], month=X['Month'], day=X['Day'], hour=X_hm[0], minute=X_hm[1])).astype( 104 | int) 105 | return pd.DataFrame(d) 106 | 107 | @staticmethod 108 | def amountEncoder(X): 109 | amt = X.apply(lambda x: x[1:]).astype(float).apply(lambda amt: max(1, amt)).apply(math.log) 110 | return pd.DataFrame(amt) 111 | 112 | @staticmethod 113 | def fraudEncoder(X): 114 | fraud = (X == 'Yes').astype(int) 115 | return pd.DataFrame(fraud) 116 | 117 | @staticmethod 118 | def nanNone(X): 119 | return X.where(pd.notnull(X), 'None') 120 | 121 | @staticmethod 122 | def nanZero(X): 123 | return X.where(pd.notnull(X), 0) 124 | 125 | def _quantization_binning(self, data): 126 | qtls = np.arange(0.0, 1.0 + 1 / self.num_bins, 1 / self.num_bins) 127 | bin_edges = np.quantile(data, qtls, axis=0) # (num_bins + 1, num_features) 128 | bin_widths = np.diff(bin_edges, axis=0) 129 | bin_centers = bin_edges[:-1] + bin_widths / 2 # () 130 | return bin_edges, bin_centers, bin_widths 131 | 132 | def _quantize(self, inputs, bin_edges): 133 | quant_inputs = np.zeros(inputs.shape[0]) 134 | for i, x in enumerate(inputs): 135 | quant_inputs[i] = np.digitize(x, bin_edges) 136 | quant_inputs = quant_inputs.clip(1, self.num_bins) - 1 # Clip edges 137 | return quant_inputs 138 | 139 | def user_level_data(self): 140 | fname = path.join(self.root, f"preprocessed/{self.fname}.user{self.fextension}.pkl") 141 | trans_data, trans_labels = [], [] 142 | 143 | if self.cached and path.isfile(fname): 144 | log.info(f"loading cached user level data from {fname}") 145 | cached_data = pickle.load(open(fname, "rb")) 146 | trans_data = cached_data["trans"] 147 | trans_labels = cached_data["labels"] 148 | columns_names = cached_data["columns"] 149 | 150 | else: 151 | unique_users = self.trans_table["User"].unique() 152 | columns_names = list(self.trans_table.columns) 153 | 154 | for user in tqdm.tqdm(unique_users): 155 | user_data = self.trans_table.loc[self.trans_table["User"] == user] 156 | user_trans, user_labels = [], [] 157 | for idx, row in user_data.iterrows(): 158 | row = list(row) 159 | 160 | # assumption that user is first field 161 | skip_idx = 1 if self.skip_user else 0 162 | 163 | user_trans.extend(row[skip_idx:-1]) 164 | user_labels.append(row[-1]) 165 | 166 | trans_data.append(user_trans) 167 | trans_labels.append(user_labels) 168 | 169 | if self.skip_user: 170 | columns_names.remove("User") 171 | 172 | with open(fname, 'wb') as cache_file: 173 | pickle.dump({"trans": trans_data, "labels": trans_labels, "columns": columns_names}, cache_file) 174 | 175 | # convert to str 176 | return trans_data, trans_labels, columns_names 177 | 178 | def format_trans(self, trans_lst, column_names): 179 | trans_lst = list(divide_chunks(trans_lst, len(self.vocab.field_keys) - 2)) # 2 to ignore isFraud and SPECIAL 180 | user_vocab_ids = [] 181 | 182 | sep_id = self.vocab.get_id(self.vocab.sep_token, special_token=True) 183 | 184 | for trans in trans_lst: 185 | vocab_ids = [] 186 | for jdx, field in enumerate(trans): 187 | vocab_id = self.vocab.get_id(field, column_names[jdx]) 188 | vocab_ids.append(vocab_id) 189 | 190 | # TODO : need to handle ncols when sep is not added 191 | if self.mlm: # and self.flatten: # only add [SEP] for BERT + flatten scenario 192 | vocab_ids.append(sep_id) 193 | 194 | user_vocab_ids.append(vocab_ids) 195 | 196 | return user_vocab_ids 197 | 198 | def 
prepare_samples(self): 199 | log.info("preparing user level data...") 200 | trans_data, trans_labels, columns_names = self.user_level_data() 201 | 202 | log.info("creating transaction samples with vocab") 203 | for user_idx in tqdm.tqdm(range(len(trans_data))): 204 | user_row = trans_data[user_idx] 205 | user_row_ids = self.format_trans(user_row, columns_names) 206 | 207 | user_labels = trans_labels[user_idx] 208 | 209 | bos_token = self.vocab.get_id(self.vocab.bos_token, special_token=True) # will be used for GPT2 210 | eos_token = self.vocab.get_id(self.vocab.eos_token, special_token=True) # will be used for GPT2 211 | for jdx in range(0, len(user_row_ids) - self.seq_len + 1, self.trans_stride): 212 | ids = user_row_ids[jdx:(jdx + self.seq_len)] 213 | ids = [idx for ids_lst in ids for idx in ids_lst] # flattening 214 | if not self.mlm and self.flatten: # for GPT2, need to add [BOS] and [EOS] tokens 215 | ids = [bos_token] + ids + [eos_token] 216 | self.data.append(ids) 217 | 218 | for jdx in range(0, len(user_labels) - self.seq_len + 1, self.trans_stride): 219 | ids = user_labels[jdx:(jdx + self.seq_len)] 220 | self.labels.append(ids) 221 | 222 | fraud = 0 223 | if len(np.nonzero(ids)[0]) > 0: 224 | fraud = 1 225 | self.window_label.append(fraud) 226 | 227 | assert len(self.data) == len(self.labels) 228 | 229 | ''' 230 | ncols = total fields - 1 (special tokens) - 1 (label) 231 | if bert: 232 | ncols += 1 (for sep) 233 | ''' 234 | self.ncols = len(self.vocab.field_keys) - 2 + (1 if self.mlm else 0) 235 | log.info(f"ncols: {self.ncols}") 236 | log.info(f"no of samples {len(self.data)}") 237 | 238 | def get_csv(self, fname): 239 | data = pd.read_csv(fname, nrows=self.nrows) 240 | if self.user_ids: 241 | log.info(f'Filtering data by user ids list: {self.user_ids}...') 242 | self.user_ids = map(int, self.user_ids) 243 | data = data[data['User'].isin(self.user_ids)] 244 | 245 | self.nrows = data.shape[0] 246 | log.info(f"read data : {data.shape}") 247 | return data 248 | 249 | def write_csv(self, data, fname): 250 | log.info(f"writing to file {fname}") 251 | data.to_csv(fname, index=False) 252 | 253 | def init_vocab(self): 254 | column_names = list(self.trans_table.columns) 255 | if self.skip_user: 256 | column_names.remove("User") 257 | 258 | self.vocab.set_field_keys(column_names) 259 | 260 | for column in column_names: 261 | unique_values = self.trans_table[column].value_counts(sort=True).to_dict() # returns sorted 262 | for val in unique_values: 263 | self.vocab.set_id(val, column) 264 | 265 | log.info(f"total columns: {list(column_names)}") 266 | log.info(f"total vocabulary size: {len(self.vocab.id2token)}") 267 | 268 | for column in self.vocab.field_keys: 269 | vocab_size = len(self.vocab.token2id[column]) 270 | log.info(f"column : {column}, vocab size : {vocab_size}") 271 | 272 | if vocab_size > self.vocab.adap_thres: 273 | log.info(f"\tsetting {column} for adaptive softmax") 274 | self.vocab.adap_sm_cols.add(column) 275 | 276 | def encode_data(self): 277 | dirname = path.join(self.root, "preprocessed") 278 | fname = f'{self.fname}{self.fextension}.encoded.csv' 279 | data_file = path.join(self.root, f"{self.fname}.csv") 280 | 281 | if self.cached and path.isfile(path.join(dirname, fname)): 282 | log.info(f"cached encoded data is read from {fname}") 283 | self.trans_table = self.get_csv(path.join(dirname, fname)) 284 | encoder_fname = path.join(dirname, f'{self.fname}{self.fextension}.encoder_fit.pkl') 285 | self.encoder_fit = pickle.load(open(encoder_fname, "rb")) 286 | return 287 | 
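        # No cached encoding found: read the raw transactions CSV below, resolve NaNs,
        # label-encode the categorical columns, quantize 'Timestamp' and 'Amount' into
        # num_bins quantile bins, then cache the encoded table and the fitted encoders.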
288 | data = self.get_csv(data_file) 289 | log.info(f"{data_file} is read.") 290 | 291 | log.info("nan resolution.") 292 | data['Errors?'] = self.nanNone(data['Errors?']) 293 | data['Is Fraud?'] = self.fraudEncoder(data['Is Fraud?']) 294 | data['Zip'] = self.nanZero(data['Zip']) 295 | data['Merchant State'] = self.nanNone(data['Merchant State']) 296 | data['Use Chip'] = self.nanNone(data['Use Chip']) 297 | data['Amount'] = self.amountEncoder(data['Amount']) 298 | 299 | sub_columns = ['Errors?', 'MCC', 'Zip', 'Merchant State', 'Merchant City', 'Merchant Name', 'Use Chip'] 300 | 301 | log.info("label-fit-transform.") 302 | for col_name in tqdm.tqdm(sub_columns): 303 | col_data = data[col_name] 304 | col_fit, col_data = self.label_fit_transform(col_data) 305 | self.encoder_fit[col_name] = col_fit 306 | data[col_name] = col_data 307 | 308 | log.info("timestamp fit transform") 309 | timestamp = self.timeEncoder(data[['Year', 'Month', 'Day', 'Time']]) 310 | timestamp_fit, timestamp = self.label_fit_transform(timestamp, enc_type="time") 311 | self.encoder_fit['Timestamp'] = timestamp_fit 312 | data['Timestamp'] = timestamp 313 | 314 | log.info("timestamp quant transform") 315 | coldata = np.array(data['Timestamp']) 316 | bin_edges, bin_centers, bin_widths = self._quantization_binning(coldata) 317 | data['Timestamp'] = self._quantize(coldata, bin_edges) 318 | self.encoder_fit["Timestamp-Quant"] = [bin_edges, bin_centers, bin_widths] 319 | 320 | log.info("amount quant transform") 321 | coldata = np.array(data['Amount']) 322 | bin_edges, bin_centers, bin_widths = self._quantization_binning(coldata) 323 | data['Amount'] = self._quantize(coldata, bin_edges) 324 | self.encoder_fit["Amount-Quant"] = [bin_edges, bin_centers, bin_widths] 325 | 326 | columns_to_select = ['User', 327 | 'Card', 328 | 'Timestamp', 329 | 'Amount', 330 | 'Use Chip', 331 | 'Merchant Name', 332 | 'Merchant City', 333 | 'Merchant State', 334 | 'Zip', 335 | 'MCC', 336 | 'Errors?', 337 | 'Is Fraud?'] 338 | 339 | self.trans_table = data[columns_to_select] 340 | 341 | log.info(f"writing cached csv to {path.join(dirname, fname)}") 342 | if not path.exists(dirname): 343 | os.mkdir(dirname) 344 | self.write_csv(self.trans_table, path.join(dirname, fname)) 345 | 346 | encoder_fname = path.join(dirname, f'{self.fname}{self.fextension}.encoder_fit.pkl') 347 | log.info(f"writing cached encoder fit to {encoder_fname}") 348 | pickle.dump(self.encoder_fit, open(encoder_fname, "wb")) -------------------------------------------------------------------------------- /dataset/datacollator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Union 2 | import torch 3 | from transformers import DataCollatorForLanguageModeling 4 | 5 | 6 | class TransDataCollatorForLanguageModeling(DataCollatorForLanguageModeling): 7 | 8 | def __call__( 9 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 10 | ) -> Dict[str, torch.Tensor]: 11 | batch = self._tensorize_batch(examples) 12 | sz = batch.shape 13 | if self.mlm: 14 | batch = batch.view(sz[0], -1) 15 | inputs, labels = self.mask_tokens(batch) 16 | return {"input_ids": inputs.view(sz), "masked_lm_labels": labels.view(sz)} 17 | else: 18 | labels = batch.clone().detach() 19 | if self.tokenizer.pad_token_id is not None: 20 | labels[labels == self.tokenizer.pad_token_id] = -100 21 | return {"input_ids": batch, "labels": labels} 22 | 23 | def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, 
torch.Tensor]: 24 | """ 25 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 26 | """ 27 | if self.tokenizer.mask_token is None: 28 | raise ValueError( 29 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove " 30 | "the --mlm flag if you want to use this tokenizer. " 31 | ) 32 | labels = inputs.clone() 33 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability 34 | # defaults to 0.15 in Bert/RoBERTa) 35 | probability_matrix = torch.full(labels.shape, self.mlm_probability) 36 | special_tokens_mask = [ 37 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 38 | ] 39 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) 40 | if self.tokenizer._pad_token is not None: 41 | padding_mask = labels.eq(self.tokenizer.pad_token_id) 42 | probability_matrix.masked_fill_(padding_mask, value=0.0) 43 | masked_indices = torch.bernoulli(probability_matrix).bool() 44 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 45 | 46 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 47 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 48 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 49 | 50 | # 10% of the time, we replace masked input tokens with random word 51 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 52 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) 53 | inputs[indices_random] = random_words[indices_random] 54 | 55 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 56 | return inputs, labels 57 | -------------------------------------------------------------------------------- /dataset/prsa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Dataset 3 | import glob 4 | import pandas as pd 5 | import numpy as np 6 | from dataset.vocab import Vocabulary 7 | from sklearn.preprocessing import MinMaxScaler 8 | import tqdm 9 | from os import path 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | log = logger 14 | 15 | 16 | class PRSADataset(Dataset): 17 | def __init__(self, 18 | data_root: str = "./data/prsa/", 19 | seq_len=10, 20 | stride=5, 21 | nbins=50, 22 | vocab_dir="output", 23 | mlm=True, 24 | return_labels=False, 25 | use_station=False, 26 | transform_date=True, 27 | flatten=False): 28 | 29 | self.stride = stride 30 | self.seq_len = seq_len 31 | self.data_root = data_root 32 | self.nbins = nbins 33 | self.vocab_dir = vocab_dir 34 | 35 | self.mlm = mlm 36 | self.return_labels = return_labels 37 | self.use_station = use_station 38 | self.transform_date = transform_date 39 | self.flatten = flatten 40 | 41 | self.vocab = Vocabulary() 42 | self.encoding_fn = {} 43 | self.target_cols = ['PM2.5', 'PM10'] 44 | 45 | self.setup() 46 | 47 | def __getitem__(self, index): 48 | if self.flatten: 49 | return_data = torch.tensor(self.samples[index], dtype=torch.long) 50 | else: 51 | return_data = torch.tensor(self.samples[index], dtype=torch.long).reshape(self.seq_len, -1) 52 | 53 | if self.return_labels: 54 | target = self.targets[index] 55 | return_data = return_data, 
torch.tensor(target, dtype=torch.float32) 56 | 57 | return return_data 58 | 59 | def __len__(self): 60 | return len(self.samples) 61 | 62 | def _quantization_binning(self, data): 63 | qtls = np.arange(0.0, 1.0 + 1 / self.nbins, 1 / self.nbins) 64 | bin_edges = np.quantile(data, qtls, axis=0) 65 | bin_widths = np.diff(bin_edges, axis=0) 66 | bin_centers = bin_edges[:-1] + bin_widths / 2 67 | return bin_edges, bin_centers, bin_widths 68 | 69 | def _quantize(self, inputs, bin_edges): 70 | quant_inputs = np.zeros(inputs.shape[0]) 71 | for i, x in enumerate(inputs): 72 | quant_inputs[i] = np.digitize(x, bin_edges) 73 | quant_inputs = quant_inputs.clip(1, self.nbins) - 1 74 | return quant_inputs 75 | 76 | @staticmethod 77 | def time_fit_transform(column): 78 | mfit = MinMaxScaler() 79 | mfit.fit(column) 80 | return mfit, mfit.transform(column) 81 | 82 | @staticmethod 83 | def timeEncoder(X): 84 | d = pd.to_datetime(dict(year=X['year'], month=X['month'], day=X['day'], hour=X['hour'])).astype(int) 85 | return pd.DataFrame(d) 86 | 87 | def setup(self): 88 | data = self.read_data(self.data_root) 89 | 90 | ''' 91 | year month day hour PM2.5 PM10 SO2 NO2 92 | CO O3 TEMP PRES DEWP RAIN wd WSPM station 93 | ''' 94 | 95 | cols_for_bins = [] 96 | if self.transform_date: 97 | cols_for_bins += ['timestamp'] 98 | 99 | data_cols = ['year', 'month', 'day', 'hour'] 100 | timestamp = self.timeEncoder(data[data_cols]) 101 | timestamp_fit, timestamp = self.time_fit_transform(timestamp) 102 | self.encoding_fn['timestamp'] = timestamp_fit 103 | data['timestamp'] = timestamp 104 | 105 | cols_for_bins += ['SO2', 'NO2', 'CO', 'O3', 'TEMP', 'PRES', 'DEWP', 'RAIN', 'WSPM'] 106 | for col in cols_for_bins: 107 | col_data = np.array(data[col]) 108 | bin_edges, bin_centers, bin_widths = self._quantization_binning(col_data) 109 | data[col] = self._quantize(col_data, bin_edges) 110 | self.encoding_fn[col] = [bin_edges, bin_centers, bin_widths] 111 | 112 | final_cols = cols_for_bins + ['wd', 'station', 'PM2.5', 'PM10'] 113 | 114 | self.data = data[final_cols] 115 | self.init_vocab() 116 | self.prepare_samples() 117 | self.save_vocab(self.vocab_dir) 118 | 119 | def prepare_samples(self): 120 | self.samples, self.targets = [], [] 121 | sep_id = self.vocab.get_id(self.vocab.sep_token, special_token=True) 122 | 123 | groups = self.data.groupby('station') 124 | for group in tqdm.tqdm(groups): 125 | station_name, station_data = group 126 | 127 | nrows = station_data.shape[0] 128 | nrows = nrows - self.seq_len 129 | 130 | log.info(f"{station_name} : {nrows}") 131 | for sample_id in range(0, nrows, self.stride): 132 | sample, target = [], [] 133 | for tid in range(0, self.seq_len): 134 | row = station_data.iloc[sample_id + tid] 135 | for col_name, col_value in row.iteritems(): 136 | if not self.use_station: 137 | if col_name == "station": 138 | continue 139 | if col_name not in self.target_cols: 140 | vocab_id = self.vocab.get_id(col_value, col_name) 141 | sample.append(vocab_id) 142 | 143 | if self.mlm: 144 | sample.append(sep_id) 145 | target.append(row[self.target_cols].tolist()) 146 | 147 | self.samples.append(sample) 148 | self.targets.append(target) 149 | 150 | assert len(self.samples) == len(self.targets) 151 | log.info(f"total samples {len(self.samples)}") 152 | 153 | self.ncols = len(self.vocab.field_keys) 154 | 155 | def init_vocab(self): 156 | cols = list(self.data.columns) 157 | 158 | if not self.use_station: 159 | cols.remove('station') 160 | 161 | for col in self.target_cols: 162 | cols.remove(col) 163 | 164 | 
self.vocab.set_field_keys(cols) 165 | 166 | for column in cols: 167 | unique_values = self.data[column].value_counts(sort=True).to_dict() # returns sorted 168 | for val in unique_values: 169 | self.vocab.set_id(val, column) 170 | 171 | print(f"columns used for vocab: {list(cols)}") 172 | print(f"total vocabulary size: {len(self.vocab.id2token)}") 173 | 174 | for column in cols: 175 | vocab_size = len(self.vocab.token2id[column]) 176 | print(f"column : {column}, vocab size : {vocab_size}") 177 | 178 | def read_data(self, root): 179 | all_stations = None 180 | fnames = glob.glob(f"{root}/*.csv") 181 | for fname in fnames: 182 | station_data = pd.read_csv(fname) 183 | 184 | if all_stations is None: 185 | all_stations = station_data 186 | else: 187 | all_stations = all_stations.append(station_data, ignore_index=True) 188 | 189 | all_stations.drop(columns=['No'], inplace=True, axis=1) 190 | log.info(f"shape (original) : {all_stations.shape}") 191 | all_stations = all_stations.dropna() 192 | log.info(f"shape (after nan removed): {all_stations.shape}") 193 | return all_stations 194 | 195 | def save_vocab(self, vocab_dir): 196 | file_name = path.join(vocab_dir, f'vocab.nb') 197 | log.info(f"saving vocab at {file_name}") 198 | self.vocab.save_vocab(file_name) -------------------------------------------------------------------------------- /dataset/vocab.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | class Vocabulary: 12 | def __init__(self, adap_thres=10000, target_column_name="Is Fraud?"): 13 | self.unk_token = "[UNK]" 14 | self.sep_token = "[SEP]" 15 | self.pad_token = "[PAD]" 16 | self.cls_token = "[CLS]" 17 | self.mask_token = "[MASK]" 18 | self.bos_token = "[BOS]" 19 | self.eos_token = "[EOS]" 20 | 21 | self.adap_thres = adap_thres 22 | self.adap_sm_cols = set() 23 | 24 | self.target_column_name = target_column_name 25 | self.special_field_tag = "SPECIAL" 26 | 27 | self.special_tokens = [self.unk_token, self.sep_token, self.pad_token, 28 | self.cls_token, self.mask_token, self.bos_token, self.eos_token] 29 | 30 | self.token2id = OrderedDict() # {field: {token: id}, ...} 31 | self.id2token = OrderedDict() # {id : [token,field]} 32 | self.field_keys = OrderedDict() 33 | self.token2id[self.special_field_tag] = OrderedDict() 34 | 35 | self.filename = '' # this field is set in the `save_vocab` method 36 | 37 | for token in self.special_tokens: 38 | global_id = len(self.id2token) 39 | local_id = len(self.token2id[self.special_field_tag]) 40 | 41 | self.token2id[self.special_field_tag][token] = [global_id, local_id] 42 | self.id2token[global_id] = [token, self.special_field_tag, local_id] 43 | 44 | def set_id(self, token, field_name, return_local=False): 45 | global_id, local_id = None, None 46 | 47 | if token not in self.token2id[field_name]: 48 | global_id = len(self.id2token) 49 | local_id = len(self.token2id[field_name]) 50 | 51 | self.token2id[field_name][token] = [global_id, local_id] 52 | self.id2token[global_id] = [token, field_name, local_id] 53 | else: 54 | global_id, local_id = self.token2id[field_name][token] 55 | 56 | if return_local: 57 | return local_id 58 | 59 | return global_id 60 | 61 | def get_id(self, token, field_name="", special_token=False, return_local=False): 62 | global_id, local_id = None, None 63 | if 
special_token: 64 | field_name = self.special_field_tag 65 | 66 | if token in self.token2id[field_name]: 67 | global_id, local_id = self.token2id[field_name][token] 68 | 69 | else: 70 | raise Exception(f"token {token} not found in field: {field_name}") 71 | 72 | if return_local: 73 | return local_id 74 | 75 | return global_id 76 | 77 | def set_field_keys(self, keys): 78 | 79 | for key in keys: 80 | self.token2id[key] = OrderedDict() 81 | self.field_keys[key] = None 82 | 83 | self.field_keys[self.special_field_tag] = None # retain the order of columns 84 | 85 | def get_field_ids(self, field_name, return_local=False): 86 | if field_name in self.token2id: 87 | ids = self.token2id[field_name] 88 | else: 89 | raise Exception(f"field name {field_name} is invalid.") 90 | 91 | selected_idx = 0 92 | if return_local: 93 | selected_idx = 1 94 | return [ids[idx][selected_idx] for idx in ids] 95 | 96 | def get_from_global_ids(self, global_ids, what_to_get='local_ids'): 97 | device = global_ids.device 98 | 99 | def map_global_ids_to_local_ids(gid): 100 | return self.id2token[gid][2] if gid != -100 else -100 101 | 102 | def map_global_ids_to_tokens(gid): 103 | return f'{self.id2token[gid][1]}_{self.id2token[gid][0]}' if gid != -100 else '-' 104 | 105 | if what_to_get == 'local_ids': 106 | return global_ids.cpu().apply_(map_global_ids_to_local_ids).to(device) 107 | elif what_to_get == 'tokens': 108 | vectorized_token_map = np.vectorize(map_global_ids_to_tokens) 109 | new_array_for_tokens = global_ids.detach().clone().cpu().numpy() 110 | return vectorized_token_map(new_array_for_tokens) 111 | else: 112 | raise ValueError("Only 'local_ids' or 'tokens' can be passed as value of the 'what_to_get' parameter.") 113 | 114 | def save_vocab(self, fname): 115 | self.filename = fname 116 | with open(fname, "w") as fout: 117 | for idx in self.id2token: 118 | token, field, _ = self.id2token[idx] 119 | token = "%s_%s" % (field, token) 120 | fout.write("%s\n" % token) 121 | 122 | def get_field_keys(self, remove_target=True, ignore_special=False): 123 | keys = list(self.field_keys.keys()) 124 | 125 | if remove_target and self.target_column_name in keys: 126 | keys.remove(self.target_column_name) 127 | if ignore_special: 128 | keys.remove(self.special_field_tag) 129 | return keys 130 | 131 | def get_special_tokens(self): 132 | special_tokens_map = {} 133 | # TODO : remove the dependency of re-initializing here. 
retrieve from field_key = SPECIAL 134 | keys = ["unk_token", "sep_token", "pad_token", "cls_token", "mask_token", "bos_token", "eos_token"] 135 | for key, token in zip(keys, self.special_tokens): 136 | token = "%s_%s" % (self.special_field_tag, token) 137 | special_tokens_map[key] = token 138 | 139 | return AttrDict(special_tokens_map) 140 | 141 | def __len__(self): 142 | return len(self.id2token) 143 | 144 | def __str__(self): 145 | str_ = 'vocab: [{} tokens] [field_keys={}]'.format(len(self), self.field_keys) 146 | return str_ 147 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | import logging 4 | import numpy as np 5 | import torch 6 | import random 7 | from args import define_main_parser 8 | 9 | from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments 10 | 11 | from dataset.prsa import PRSADataset 12 | from dataset.card import TransactionDataset 13 | from models.modules import TabFormerBertLM, TabFormerGPT2 14 | from misc.utils import random_split_dataset 15 | from dataset.datacollator import TransDataCollatorForLanguageModeling 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | log = logger 20 | logging.basicConfig( 21 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 22 | datefmt="%m/%d/%Y %H:%M:%S", 23 | level=logging.INFO 24 | ) 25 | 26 | 27 | def main(args): 28 | # random seeds 29 | seed = args.seed 30 | random.seed(seed) # python 31 | np.random.seed(seed) # numpy 32 | torch.manual_seed(seed) # torch 33 | if torch.cuda.is_available(): 34 | torch.cuda.manual_seed_all(seed) # torch.cuda 35 | 36 | if args.data_type == 'card': 37 | dataset = TransactionDataset(root=args.data_root, 38 | fname=args.data_fname, 39 | fextension=args.data_extension, 40 | vocab_dir=args.output_dir, 41 | nrows=args.nrows, 42 | user_ids=args.user_ids, 43 | mlm=args.mlm, 44 | cached=args.cached, 45 | stride=args.stride, 46 | flatten=args.flatten, 47 | return_labels=False, 48 | skip_user=args.skip_user) 49 | elif args.data_type == 'prsa': 50 | dataset = PRSADataset(stride=args.stride, 51 | mlm=args.mlm, 52 | return_labels=False, 53 | use_station=False, 54 | flatten=args.flatten, 55 | vocab_dir=args.output_dir) 56 | 57 | else: 58 | raise Exception(f"data type '{args.data_type}' not defined") 59 | 60 | vocab = dataset.vocab 61 | custom_special_tokens = vocab.get_special_tokens() 62 | 63 | # split dataset into train, val, test [0.6. 
0.2, 0.2] 64 | totalN = len(dataset) 65 | trainN = int(0.6 * totalN) 66 | 67 | valtestN = totalN - trainN 68 | valN = int(valtestN * 0.5) 69 | testN = valtestN - valN 70 | 71 | assert totalN == trainN + valN + testN 72 | 73 | lengths = [trainN, valN, testN] 74 | 75 | log.info(f"# lengths: train [{trainN}] valid [{valN}] test [{testN}]") 76 | log.info("# lengths: train [{:.2f}] valid [{:.2f}] test [{:.2f}]".format(trainN / totalN, valN / totalN, 77 | testN / totalN)) 78 | 79 | train_dataset, eval_dataset, test_dataset = random_split_dataset(dataset, lengths) 80 | 81 | if args.lm_type == "bert": 82 | tab_net = TabFormerBertLM(custom_special_tokens, 83 | vocab=vocab, 84 | field_ce=args.field_ce, 85 | flatten=args.flatten, 86 | ncols=dataset.ncols, 87 | field_hidden_size=args.field_hs 88 | ) 89 | else: 90 | tab_net = TabFormerGPT2(custom_special_tokens, 91 | vocab=vocab, 92 | field_ce=args.field_ce, 93 | flatten=args.flatten, 94 | ) 95 | 96 | log.info(f"model initiated: {tab_net.model.__class__}") 97 | 98 | if args.flatten: 99 | collactor_cls = "DataCollatorForLanguageModeling" 100 | else: 101 | collactor_cls = "TransDataCollatorForLanguageModeling" 102 | 103 | log.info(f"collactor class: {collactor_cls}") 104 | data_collator = eval(collactor_cls)( 105 | tokenizer=tab_net.tokenizer, mlm=args.mlm, mlm_probability=args.mlm_prob 106 | ) 107 | 108 | training_args = TrainingArguments( 109 | output_dir=args.output_dir, # output directory 110 | num_train_epochs=args.num_train_epochs, # total number of training epochs 111 | logging_dir=args.log_dir, # directory for storing logs 112 | save_steps=args.save_steps, 113 | do_train=args.do_train, 114 | # do_eval=args.do_eval, 115 | # evaluation_strategy="epoch", 116 | prediction_loss_only=True, 117 | overwrite_output_dir=True, 118 | # eval_steps=10000 119 | ) 120 | 121 | trainer = Trainer( 122 | model=tab_net.model, 123 | args=training_args, 124 | data_collator=data_collator, 125 | train_dataset=train_dataset, 126 | eval_dataset=eval_dataset, 127 | ) 128 | 129 | if args.checkpoint: 130 | model_path = join(args.output_dir, f'checkpoint-{args.checkpoint}') 131 | else: 132 | model_path = args.output_dir 133 | 134 | trainer.train(model_path=model_path) 135 | 136 | 137 | if __name__ == "__main__": 138 | 139 | parser = define_main_parser() 140 | opts = parser.parse_args() 141 | 142 | opts.log_dir = join(opts.output_dir, "logs") 143 | makedirs(opts.output_dir, exist_ok=True) 144 | makedirs(opts.log_dir, exist_ok=True) 145 | 146 | if opts.mlm and opts.lm_type == "gpt2": 147 | raise Exception("Error: GPT2 doesn't need '--mlm' option. Please re-run with this flag removed.") 148 | 149 | if not opts.mlm and opts.lm_type == "bert": 150 | raise Exception("Error: Bert needs '--mlm' option. 
Please re-run with this flag included.") 151 | 152 | main(opts) 153 | -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/TabFormer/ebb7cd68ee1897599568107740bc452104bbbaf8/misc/__init__.py -------------------------------------------------------------------------------- /misc/cc_trans_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/TabFormer/ebb7cd68ee1897599568107740bc452104bbbaf8/misc/cc_trans_dataset.png -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class ddict(object): 7 | def __init__(self, **kwargs): 8 | self.__dict__.update(kwargs) 9 | 10 | 11 | def random_split_dataset(dataset, lengths, random_seed=20200706): 12 | # state snapshot 13 | state = {} 14 | state['seeds'] = { 15 | 'python_state': random.getstate(), 16 | 'numpy_state': np.random.get_state(), 17 | 'torch_state': torch.get_rng_state(), 18 | 'cuda_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None 19 | } 20 | 21 | # seed 22 | random.seed(random_seed) # python 23 | np.random.seed(random_seed) # numpy 24 | torch.manual_seed(random_seed) # torch 25 | if torch.cuda.is_available(): 26 | torch.cuda.manual_seed_all(random_seed) # torch.cuda 27 | 28 | train_dataset, eval_dataset, test_dataset = torch.utils.data.dataset.random_split(dataset, lengths) 29 | 30 | # reinstate state 31 | random.setstate(state['seeds']['python_state']) 32 | np.random.set_state(state['seeds']['numpy_state']) 33 | torch.set_rng_state(state['seeds']['torch_state']) 34 | if torch.cuda.is_available(): 35 | torch.cuda.set_rng_state(state['seeds']['cuda_state']) 36 | 37 | return train_dataset, eval_dataset, test_dataset 38 | 39 | 40 | def divide_chunks(l, n): 41 | for i in range(0, len(l), n): 42 | yield l[i:i + n] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/TabFormer/ebb7cd68ee1897599568107740bc452104bbbaf8/models/__init__.py -------------------------------------------------------------------------------- /models/custom_criterion.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from torch.nn import CrossEntropyLoss, AdaptiveLogSoftmaxWithLoss 3 | from torch.nn.functional import log_softmax 4 | 5 | 6 | class CustomAdaptiveLogSoftmax(AdaptiveLogSoftmaxWithLoss): 7 | def __init__( 8 | self, ignore_index=-100, 9 | **kwargs 10 | ) -> None: 11 | super().__init__(**kwargs) 12 | self.ignore_index = ignore_index 13 | 14 | def forward(self, input: Tensor, target: Tensor): 15 | if input.size(0) != target.size(0): 16 | raise RuntimeError('Input and target should have the same size ' 17 | 'in the batch dimension.') 18 | 19 | ''' 20 | handles ignore index = -100; 21 | removes all targets which are masked from input and target 22 | ''' 23 | consider_indices = (target != self.ignore_index) 24 | input = input[consider_indices, :] 25 | target = target[consider_indices] 26 | 27 | used_rows = 0 28 | batch_size = target.size(0) 29 | 30 | output = input.new_zeros(batch_size) 
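        # 'output' accumulates, for each (unmasked) example, the log-probability assigned
        # to its target token; 'gather_inds' records what to gather from the head softmax:
        # the target id itself for shortlist tokens, or the index of its tail cluster.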
31 | gather_inds = target.new_empty(batch_size) 32 | 33 | cutoff_values = [0] + self.cutoffs 34 | for i in range(len(cutoff_values) - 1): 35 | 36 | low_idx = cutoff_values[i] 37 | high_idx = cutoff_values[i + 1] 38 | 39 | target_mask = (target >= low_idx) & (target < high_idx) 40 | row_indices = target_mask.nonzero().squeeze() 41 | 42 | if row_indices.numel() == 0: 43 | continue 44 | 45 | if i == 0: 46 | gather_inds.index_copy_(0, row_indices, target[target_mask]) 47 | 48 | else: 49 | relative_target = target[target_mask] - low_idx 50 | input_subset = input.index_select(0, row_indices) 51 | 52 | cluster_output = self.tail[i - 1](input_subset) 53 | cluster_index = self.shortlist_size + i - 1 54 | 55 | gather_inds.index_fill_(0, row_indices, cluster_index) 56 | 57 | cluster_logprob = log_softmax(cluster_output, dim=1) 58 | local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)) 59 | output.index_copy_(0, row_indices, local_logprob.squeeze(1)) 60 | 61 | used_rows += row_indices.numel() 62 | 63 | if used_rows != batch_size: 64 | raise RuntimeError("Target values should be in [0, {}], " 65 | "but values in range [{}, {}] " 66 | "were found. ".format(self.n_classes - 1, 67 | target.min().item(), 68 | target.max().item())) 69 | 70 | head_output = self.head(input) 71 | head_logprob = log_softmax(head_output, dim=1) 72 | output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze() 73 | loss = (-output).mean() 74 | 75 | return loss 76 | -------------------------------------------------------------------------------- /models/hierarchical.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class TabFormerConcatEmbeddings(nn.Module): 5 | """TabFormerConcatEmbeddings: Embeds tabular data of categorical variables 6 | 7 | Notes: - All column entries must be integer indices in a vocabolary that is common across columns 8 | - `sparse=True` in `nn.Embedding` speeds up gradient computation for large vocabs 9 | 10 | Args: 11 | config.ncols 12 | config.vocab_size 13 | config.hidden_size 14 | 15 | Inputs: 16 | - **input_ids** (batch, seq_len, ncols): tensor of batch of sequences of rows 17 | 18 | Outputs: 19 | - **output'**: (batch, seq_len, hidden_size): tensor of embedded rows 20 | """ 21 | 22 | def __init__(self, config): 23 | super().__init__() 24 | self.word_embeddings = nn.Embedding(config.vocab_size, config.field_hidden_size, 25 | padding_idx=getattr(config, 'pad_token_id', 0), sparse=False) 26 | self.lin_proj = nn.Linear(config.field_hidden_size * config.ncols, config.hidden_size) 27 | 28 | self.hidden_size = config.hidden_size 29 | self.field_hidden_size = config.field_hidden_size 30 | 31 | def forward(self, input_ids): 32 | input_shape = input_ids.size() 33 | 34 | embeds_sz = list(input_shape[:-1]) + [input_shape[-1] * self.field_hidden_size] 35 | inputs_embeds = self.lin_proj(self.word_embeddings(input_ids).view(embeds_sz)) 36 | 37 | return inputs_embeds 38 | 39 | 40 | class TabFormerEmbeddings(nn.Module): 41 | """TabFormerEmbeddings: Embeds tabular data of categorical variables 42 | 43 | Notes: - All column entries must be integer indices in a vocabolary that is common across columns 44 | 45 | Args: 46 | config.ncols 47 | config.num_layers (int): Number of transformer layers 48 | config.vocab_size 49 | config.hidden_size 50 | config.field_hidden_size 51 | 52 | Inputs: 53 | - **input** (batch, seq_len, ncols): tensor of batch of sequences of rows 54 | 55 | Outputs: 56 | - **output**: (batch, seq_len, 
hidden_size): tensor of embedded rows 57 | """ 58 | 59 | def __init__(self, config): 60 | super().__init__() 61 | 62 | if not hasattr(config, 'num_layers'): 63 | config.num_layers = 1 64 | if not hasattr(config, 'nhead'): 65 | config.nhead = 8 66 | 67 | self.word_embeddings = nn.Embedding(config.vocab_size, config.field_hidden_size, 68 | padding_idx=getattr(config, 'pad_token_id', 0), sparse=False) 69 | 70 | encoder_layer = nn.TransformerEncoderLayer(d_model=config.field_hidden_size, nhead=config.nhead, 71 | dim_feedforward=config.field_hidden_size) 72 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers) 73 | 74 | self.lin_proj = nn.Linear(config.field_hidden_size * config.ncols, config.hidden_size) 75 | 76 | def forward(self, input_ids): 77 | inputs_embeds = self.word_embeddings(input_ids) 78 | embeds_shape = list(inputs_embeds.size()) 79 | 80 | inputs_embeds = inputs_embeds.view([-1] + embeds_shape[-2:]) 81 | inputs_embeds = inputs_embeds.permute(1, 0, 2) 82 | inputs_embeds = self.transformer_encoder(inputs_embeds) 83 | inputs_embeds = inputs_embeds.permute(1, 0, 2) 84 | inputs_embeds = inputs_embeds.contiguous().view(embeds_shape[0:2]+[-1]) 85 | 86 | inputs_embeds = self.lin_proj(inputs_embeds) 87 | 88 | return inputs_embeds 89 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | from misc.utils import ddict 2 | 3 | from transformers.modeling_utils import PreTrainedModel 4 | from transformers import ( 5 | BertTokenizer, 6 | BertForMaskedLM, 7 | GPT2Config, 8 | GPT2LMHeadModel 9 | ) 10 | 11 | from models.tabformer_tokenizer import TabFormerTokenizer 12 | from models.hierarchical import TabFormerEmbeddings 13 | from models.tabformer_bert import TabFormerBertForMaskedLM, TabFormerBertConfig 14 | from models.tabformer_gpt2 import TabFormerGPT2LMHeadModel 15 | 16 | 17 | class TabFormerBaseModel(PreTrainedModel): 18 | def __init__(self, hf_model, tab_embeddings, config): 19 | super().__init__(config) 20 | 21 | self.model = hf_model 22 | self.tab_embeddings = tab_embeddings 23 | 24 | def forward(self, input_ids, **input_args): 25 | inputs_embeds = self.tab_embeddings(input_ids) 26 | return self.model(inputs_embeds=inputs_embeds, **input_args) 27 | 28 | 29 | class TabFormerHierarchicalLM(PreTrainedModel): 30 | base_model_prefix = "bert" 31 | 32 | def __init__(self, config, vocab): 33 | super().__init__(config) 34 | 35 | self.config = config 36 | 37 | self.tab_embeddings = TabFormerEmbeddings(self.config) 38 | self.tb_model = TabFormerBertForMaskedLM(self.config, vocab) 39 | 40 | def forward(self, input_ids, **input_args): 41 | inputs_embeds = self.tab_embeddings(input_ids) 42 | return self.tb_model(inputs_embeds=inputs_embeds, **input_args) 43 | 44 | 45 | class TabFormerBertLM: 46 | def __init__(self, special_tokens, vocab, field_ce=False, flatten=False, ncols=None, field_hidden_size=768): 47 | 48 | self.ncols = ncols 49 | self.vocab = vocab 50 | vocab_file = self.vocab.filename 51 | hidden_size = field_hidden_size if flatten else (field_hidden_size * self.ncols) 52 | 53 | self.config = TabFormerBertConfig(vocab_size=len(self.vocab), 54 | ncols=self.ncols, 55 | hidden_size=hidden_size, 56 | field_hidden_size=field_hidden_size, 57 | flatten=flatten, 58 | num_attention_heads=self.ncols) 59 | 60 | self.tokenizer = BertTokenizer(vocab_file, 61 | do_lower_case=False, 62 | **special_tokens) 63 | self.model = 
self.get_model(field_ce, flatten) 64 | 65 | def get_model(self, field_ce, flatten): 66 | 67 | if flatten and not field_ce: 68 | # flattened vanilla BERT 69 | model = BertForMaskedLM(self.config) 70 | elif flatten and field_ce: 71 | # flattened field CE BERT 72 | model = TabFormerBertForMaskedLM(self.config, self.vocab) 73 | else: 74 | # hierarchical field CE BERT 75 | model = TabFormerHierarchicalLM(self.config, self.vocab) 76 | 77 | return model 78 | 79 | 80 | class TabFormerGPT2: 81 | def __init__(self, special_tokens, vocab, field_ce=False, flatten=False): 82 | 83 | self.vocab = vocab 84 | self.config = GPT2Config(vocab_size=len(self.vocab)) 85 | 86 | self.tokenizer = TabFormerTokenizer( 87 | unk_token=special_tokens.unk_token, 88 | bos_token=special_tokens.bos_token, 89 | eos_token=special_tokens.eos_token 90 | ) 91 | 92 | self.model = self.get_model(field_ce, flatten) 93 | 94 | def get_model(self, field_ce, flatten): 95 | if field_ce: 96 | model = TabFormerGPT2LMHeadModel(self.config, self.vocab) 97 | else: 98 | model = GPT2LMHeadModel(self.config) 99 | if not flatten: 100 | tab_emb_config = ddict(vocab_size=len(self.vocab), hidden_size=self.config.hidden_size) 101 | model = TabFormerBaseModel(model, TabFormerEmbeddings(tab_emb_config)) 102 | 103 | return model 104 | -------------------------------------------------------------------------------- /models/tabformer_bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import CrossEntropyLoss 4 | 5 | from transformers.modeling_bert import ACT2FN, BertLayerNorm 6 | from transformers.modeling_bert import BertForMaskedLM 7 | from transformers.configuration_bert import BertConfig 8 | from models.custom_criterion import CustomAdaptiveLogSoftmax 9 | 10 | 11 | class TabFormerBertConfig(BertConfig): 12 | def __init__( 13 | self, 14 | flatten=True, 15 | ncols=12, 16 | vocab_size=30522, 17 | field_hidden_size=64, 18 | hidden_size=768, 19 | num_attention_heads=12, 20 | pad_token_id=0, 21 | **kwargs 22 | ): 23 | super().__init__(pad_token_id=pad_token_id, **kwargs) 24 | 25 | self.ncols = ncols 26 | self.field_hidden_size = field_hidden_size 27 | self.hidden_size = hidden_size 28 | self.flatten = flatten 29 | self.vocab_size = vocab_size 30 | self.num_attention_heads=num_attention_heads 31 | 32 | class TabFormerBertPredictionHeadTransform(nn.Module): 33 | def __init__(self, config): 34 | super().__init__() 35 | self.dense = nn.Linear(config.field_hidden_size, config.hidden_size) 36 | if isinstance(config.hidden_act, str): 37 | self.transform_act_fn = ACT2FN[config.hidden_act] 38 | else: 39 | self.transform_act_fn = config.hidden_act 40 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 41 | 42 | def forward(self, hidden_states): 43 | hidden_states = self.dense(hidden_states) 44 | hidden_states = self.transform_act_fn(hidden_states) 45 | hidden_states = self.LayerNorm(hidden_states) 46 | return hidden_states 47 | 48 | class TabFormerBertLMPredictionHead(nn.Module): 49 | def __init__(self, config): 50 | super().__init__() 51 | self.transform = TabFormerBertPredictionHeadTransform(config) 52 | 53 | # The output weights are the same as the input embeddings, but there is 54 | # an output-only bias for each token. 
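# Note: `transform` above maps field-level slices (of size field_hidden_size) of the encoder
# output up to hidden_size, and the decoder below scores them against the full shared vocabulary;
# the restriction of these scores to each column's own vocabulary ids happens later, in
# TabFormerBertForMaskedLM.forward.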
55 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 56 | 57 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 58 | 59 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 60 | self.decoder.bias = self.bias 61 | 62 | def forward(self, hidden_states): 63 | hidden_states = self.transform(hidden_states) 64 | hidden_states = self.decoder(hidden_states) 65 | return hidden_states 66 | 67 | class TabFormerBertOnlyMLMHead(nn.Module): 68 | def __init__(self, config): 69 | super().__init__() 70 | self.predictions = TabFormerBertLMPredictionHead(config) 71 | 72 | def forward(self, sequence_output): 73 | prediction_scores = self.predictions(sequence_output) 74 | return prediction_scores 75 | 76 | class TabFormerBertForMaskedLM(BertForMaskedLM): 77 | def __init__(self, config, vocab): 78 | super().__init__(config) 79 | 80 | self.vocab = vocab 81 | self.cls = TabFormerBertOnlyMLMHead(config) 82 | self.init_weights() 83 | 84 | def forward( 85 | self, 86 | input_ids=None, 87 | attention_mask=None, 88 | token_type_ids=None, 89 | position_ids=None, 90 | head_mask=None, 91 | inputs_embeds=None, 92 | masked_lm_labels=None, 93 | encoder_hidden_states=None, 94 | encoder_attention_mask=None, 95 | lm_labels=None, 96 | ): 97 | outputs = self.bert( 98 | input_ids, 99 | attention_mask=attention_mask, 100 | token_type_ids=token_type_ids, 101 | position_ids=position_ids, 102 | head_mask=head_mask, 103 | inputs_embeds=inputs_embeds, 104 | encoder_hidden_states=encoder_hidden_states, 105 | encoder_attention_mask=encoder_attention_mask, 106 | ) 107 | 108 | sequence_output = outputs[0] # [bsz * seqlen * hidden] 109 | 110 | if not self.config.flatten: 111 | output_sz = list(sequence_output.size()) 112 | expected_sz = [output_sz[0], output_sz[1]*self.config.ncols, -1] 113 | sequence_output = sequence_output.view(expected_sz) 114 | masked_lm_labels = masked_lm_labels.view(expected_sz[0], -1) 115 | 116 | prediction_scores = self.cls(sequence_output) # [bsz * seqlen * vocab_sz] 117 | 118 | outputs = (prediction_scores,) + outputs[2:] 119 | 120 | # prediction_scores : [bsz x seqlen x vsz] 121 | # masked_lm_labels : [bsz x seqlen] 122 | 123 | total_masked_lm_loss = 0 124 | 125 | seq_len = prediction_scores.size(1) 126 | # TODO : remove_target is True for card 127 | field_names = self.vocab.get_field_keys(remove_target=True, ignore_special=False) 128 | for field_idx, field_name in enumerate(field_names): 129 | col_ids = list(range(field_idx, seq_len, len(field_names))) 130 | 131 | global_ids_field = self.vocab.get_field_ids(field_name) 132 | 133 | prediction_scores_field = prediction_scores[:, col_ids, :][:, :, global_ids_field] # bsz * 10 * K 134 | masked_lm_labels_field = masked_lm_labels[:, col_ids] 135 | masked_lm_labels_field_local = self.vocab.get_from_global_ids(global_ids=masked_lm_labels_field, 136 | what_to_get='local_ids') 137 | 138 | nfeas = len(global_ids_field) 139 | loss_fct = self.get_criterion(field_name, nfeas, prediction_scores.device) 140 | 141 | masked_lm_loss_field = loss_fct(prediction_scores_field.view(-1, len(global_ids_field)), 142 | masked_lm_labels_field_local.view(-1)) 143 | 144 | total_masked_lm_loss += masked_lm_loss_field 145 | 146 | return (total_masked_lm_loss,) + outputs 147 | 148 | def get_criterion(self, fname, vs, device, cutoffs=False, div_value=4.0): 149 | 150 | if fname in self.vocab.adap_sm_cols: 151 | if not cutoffs: 152 | cutoffs = [int(vs/15), 3*int(vs/15), 6*int(vs/15)] 153 | 154 | criteria 
= CustomAdaptiveLogSoftmax(in_features=vs, n_classes=vs, cutoffs=cutoffs, div_value=div_value) 155 | 156 | return criteria.to(device) 157 | else: 158 | return CrossEntropyLoss() 159 | 160 | class TabFormerBertModel(BertForMaskedLM): 161 | def __init__(self, config): 162 | super().__init__(config) 163 | 164 | self.cls = TabFormerBertOnlyMLMHead(config) 165 | self.init_weights() 166 | 167 | def forward( 168 | self, 169 | input_ids=None, 170 | attention_mask=None, 171 | token_type_ids=None, 172 | position_ids=None, 173 | head_mask=None, 174 | inputs_embeds=None, 175 | masked_lm_labels=None, 176 | encoder_hidden_states=None, 177 | encoder_attention_mask=None, 178 | lm_labels=None, 179 | ): 180 | outputs = self.bert( 181 | input_ids, 182 | attention_mask=attention_mask, 183 | token_type_ids=token_type_ids, 184 | position_ids=position_ids, 185 | head_mask=head_mask, 186 | inputs_embeds=inputs_embeds, 187 | encoder_hidden_states=encoder_hidden_states, 188 | encoder_attention_mask=encoder_attention_mask, 189 | ) 190 | 191 | sequence_output = outputs[0] # [bsz * seqlen * hidden] 192 | 193 | return sequence_output -------------------------------------------------------------------------------- /models/tabformer_gpt2.py: -------------------------------------------------------------------------------- 1 | from torch.nn import CrossEntropyLoss 2 | 3 | from transformers.modeling_gpt2 import GPT2LMHeadModel 4 | 5 | 6 | class TabFormerGPT2LMHeadModel(GPT2LMHeadModel): 7 | def __init__(self, config, vocab): 8 | super().__init__(config) 9 | self.vocab = vocab 10 | 11 | def forward( 12 | self, 13 | input_ids=None, 14 | past=None, 15 | attention_mask=None, 16 | token_type_ids=None, 17 | position_ids=None, 18 | head_mask=None, 19 | inputs_embeds=None, 20 | labels=None, 21 | use_cache=True, 22 | ): 23 | transformer_outputs = self.transformer( 24 | input_ids, 25 | past=past, 26 | attention_mask=attention_mask, 27 | token_type_ids=token_type_ids, 28 | position_ids=position_ids, 29 | head_mask=head_mask, 30 | inputs_embeds=inputs_embeds, 31 | use_cache=use_cache, 32 | ) 33 | hidden_states = transformer_outputs[0] 34 | lm_logits = self.lm_head(hidden_states) 35 | 36 | # lm_logits : [bsz x seq_len x vsz] 37 | # labels : [bsz x seq_len] 38 | # When flatten is set to True: 39 | # seq_len = num_transactions * (num_columns + 2) --> plus 2 because each transaction has BOS and EOS padding 40 | 41 | outputs = (lm_logits,) + transformer_outputs[1:] 42 | if labels is not None: 43 | # Shift so that tokens < n predict n 44 | shift_labels = labels[..., 1:-1].contiguous() # Remove first and last label: [BOS] and [EOS] tokens 45 | shift_logits = lm_logits[..., :-2, :].contiguous() # Line up logits accordingly 46 | 47 | seq_len = shift_logits.size(1) 48 | total_lm_loss = 0 49 | field_names = self.vocab.get_field_keys(remove_target=True, ignore_special=True) 50 | 51 | for field_idx, field_name in enumerate(field_names): 52 | col_ids = list(range(field_idx, seq_len, len(field_names))) 53 | global_ids_field = self.vocab.get_field_ids(field_name) 54 | lm_logits_field = shift_logits[:, col_ids, :][:, :, global_ids_field] # bsz * 10 * K 55 | lm_labels_field = shift_labels[:, col_ids] 56 | lm_labels_local_field = self.vocab.get_from_global_ids(global_ids=lm_labels_field, 57 | what_to_get='local_ids') 58 | 59 | loss_fct = CrossEntropyLoss() 60 | lm_loss_field = loss_fct(lm_logits_field.view(-1, len(global_ids_field)), 61 | lm_labels_local_field.view(-1)) 62 | total_lm_loss += lm_loss_field 63 | 64 | outputs = (total_lm_loss,) + outputs 65 
| 66 | return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions) 67 | -------------------------------------------------------------------------------- /models/tabformer_tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers.tokenization_utils import PreTrainedTokenizer 2 | 3 | class TabFormerTokenizer(PreTrainedTokenizer): 4 | def __init__( 5 | self, 6 | unk_token="<|endoftext|>", 7 | bos_token="<|endoftext|>", 8 | eos_token="<|endoftext|>", 9 | ): 10 | 11 | super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token) -------------------------------------------------------------------------------- /setup.yml: -------------------------------------------------------------------------------- 1 | name: tabformer 2 | channels: 3 | - anaconda 4 | - pytorch 5 | - huggingface 6 | - conda-forge 7 | dependencies: 8 | - python>=3.8 9 | - pip>=21.0 10 | - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0 11 | - torchvision 12 | - pandas 13 | - scikit-learn 14 | - transformers 15 | - numpy 16 | - libgcc 17 | - pip: 18 | - transformers==3.2.0 19 | --------------------------------------------------------------------------------
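Usage note (not part of the repository): the snippet below is a minimal sketch showing how the hierarchical field encoder from models/hierarchical.py can be exercised on its own. The sizes (vocab_size=500, field_hidden_size=64, hidden_size=256, ncols=8) and the toy inputs are hypothetical choices for illustration only, and the sketch assumes the repository root is on PYTHONPATH.

import torch

from misc.utils import ddict
from models.hierarchical import TabFormerEmbeddings

# hypothetical configuration: 8 columns per row, a shared vocabulary of 500 field-level ids,
# 64-dim per-field embeddings projected to a 256-dim row embedding
config = ddict(vocab_size=500, field_hidden_size=64, hidden_size=256, ncols=8)
field_encoder = TabFormerEmbeddings(config)  # num_layers / nhead default to 1 / 8 inside __init__

# a toy batch of 2 sequences, each a window of 10 rows with 8 categorical field ids per row
input_ids = torch.randint(0, 500, (2, 10, 8))
row_embeddings = field_encoder(input_ids)
print(row_embeddings.shape)  # torch.Size([2, 10, 256])

This is the field-transformer stage of the hierarchical model: TabFormerHierarchicalLM feeds these row embeddings into the sequence-level BERT via inputs_embeds.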