├── DAN_Task.py ├── Figures └── DAN.jpg ├── LICENSE ├── README.md ├── abstract_model.py ├── config ├── MSLR.yaml ├── cardio.yaml ├── click.yaml ├── default.py ├── epsilon.yaml ├── forest_cover_type.yaml ├── yahoo.yaml └── year.yaml ├── data ├── data_util.py └── dataset.py ├── lib ├── callbacks.py ├── logger.py ├── metrics.py ├── multiclass_utils.py └── utils.py ├── main.py ├── model ├── AcceleratedModule.py ├── DANet.py └── sparsemax.py ├── predict.py └── requirements.txt /DAN_Task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.special import softmax 4 | from lib.utils import PredictDataset 5 | from abstract_model import DANsModel 6 | from lib.multiclass_utils import infer_output_dim, check_output_dim 7 | from torch.utils.data import DataLoader 8 | from torch.nn.functional import cross_entropy, mse_loss 9 | 10 | class DANetClassifier(DANsModel): 11 | def __post_init__(self): 12 | super(DANetClassifier, self).__post_init__() 13 | self._task = 'classification' 14 | self._default_loss = cross_entropy 15 | self._default_metric = 'accuracy' 16 | 17 | def weight_updater(self, weights): 18 | """ 19 | Updates weights dictionary according to target_mapper. 20 | 21 | Parameters 22 | ---------- 23 | weights : bool or dict 24 | Given weights for balancing training. 25 | 26 | Returns 27 | ------- 28 | bool or dict 29 | Same bool if weights are bool, updated dict otherwise. 30 | 31 | """ 32 | if isinstance(weights, int): 33 | return weights 34 | elif isinstance(weights, dict): 35 | return {self.target_mapper[key]: value for key, value in weights.items()} 36 | else: 37 | return weights 38 | 39 | def prepare_target(self, y): 40 | return np.vectorize(self.target_mapper.get)(y) 41 | 42 | def compute_loss(self, y_pred, y_true): 43 | return self.loss_fn(y_pred, y_true.long()) 44 | 45 | def update_fit_params( 46 | self, 47 | X_train, 48 | y_train, 49 | eval_set 50 | ): 51 | output_dim, train_labels = infer_output_dim(y_train) 52 | for X, y in eval_set: 53 | check_output_dim(train_labels, y) 54 | self.output_dim = output_dim 55 | self._default_metric = 'accuracy' 56 | self.classes_ = train_labels 57 | self.target_mapper = {class_label: index for index, class_label in enumerate(self.classes_)} 58 | self.preds_mapper = {str(index): class_label for index, class_label in enumerate(self.classes_)} 59 | 60 | def stack_batches(self, list_y_true, list_y_score): 61 | y_true = np.hstack(list_y_true) 62 | y_score = np.vstack(list_y_score) 63 | y_score = softmax(y_score, axis=1) 64 | return y_true, y_score 65 | 66 | def predict_func(self, outputs): 67 | outputs = np.argmax(outputs, axis=1) 68 | return outputs 69 | 70 | def predict_proba(self, X): 71 | """ 72 | Make predictions for classification on a batch (valid) 73 | 74 | Parameters 75 | ---------- 76 | X : a :tensor: `torch.Tensor` 77 | Input data 78 | 79 | Returns 80 | ------- 81 | res : np.ndarray 82 | 83 | """ 84 | self.network.eval() 85 | 86 | dataloader = DataLoader( 87 | PredictDataset(X), 88 | batch_size=1024, 89 | shuffle=False, 90 | ) 91 | 92 | results = [] 93 | for batch_nb, data in enumerate(dataloader): 94 | data = data.to(self.device).float() 95 | output = self.network(data) 96 | predictions = torch.nn.Softmax(dim=1)(output).cpu().detach().numpy() 97 | results.append(predictions) 98 | res = np.vstack(results) 99 | return res 100 | 101 | 102 | class DANetRegressor(DANsModel): 103 | def __post_init__(self): 104 | super(DANetRegressor, self).__post_init__() 105 | 
self._task = 'regression' 106 | self._default_loss = mse_loss 107 | self._default_metric = 'mse' 108 | 109 | def prepare_target(self, y): 110 | return y 111 | 112 | def compute_loss(self, y_pred, y_true): 113 | return self.loss_fn(y_pred, y_true) 114 | 115 | def update_fit_params( 116 | self, 117 | X_train, 118 | y_train, 119 | eval_set 120 | ): 121 | if len(y_train.shape) != 2: 122 | msg = "Targets should be 2D : (n_samples, n_regression) " + \ 123 | f"but y_train.shape={y_train.shape} given.\n" + \ 124 | "Use reshape(-1, 1) for single regression." 125 | raise ValueError(msg) 126 | self.output_dim = y_train.shape[1] 127 | self.preds_mapper = None 128 | 129 | 130 | def predict_func(self, outputs): 131 | return outputs 132 | 133 | def stack_batches(self, list_y_true, list_y_score): 134 | y_true = np.vstack(list_y_true) 135 | y_score = np.vstack(list_y_score) 136 | return y_true, y_score 137 | -------------------------------------------------------------------------------- /Figures/DAN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WhatAShot/DANet/b007c57121ec9082f6ef19ec7465d9df70767c26/Figures/DAN.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ronnie Rocket 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Abstract Networks 2 | A PyTorch implementation of AAAI-2022 paper **[DANets: Deep Abstract Networks for Tabular Data Classification and Regression](https://arxiv.org/abs/2112.02962)** for reference. 3 | 4 | ## Brief Introduction 5 | Tabular data are ubiquitous in real world applications. Although many commonly-used neural components (e.g., convolution) and extensible neural networks (e.g., ResNet) have been developed by the machine learning community, few of them were effective for tabular data and few designs were adequately tailored for tabular data structures. In this paper, we propose a novel and flexible neural component for tabular data, called Abstract Layer (AbstLay), which learns to explicitly group correlative input features and generate higher-level features for semantics abstraction. 
Also, we design a structure re-parameterization method to compress AbstLay, thus reducing the computational complexity by a clear margin in the inference phase. A special basic block is built using AbstLays, and we construct a family of Deep Abstract Networks (DANets) for tabular data classification and regression by stacking such blocks. In DANets, a special shortcut path is introduced to fetch information from raw tabular features, assisting feature interactions across different levels. Comprehensive experiments on real-world tabular datasets show that our AbstLay and DANets are effective for tabular data classification and regression, and their computational complexity compares favorably with that of competitive methods. 6 | 7 | ## DANets illustration 8 | ![DANets](./Figures/DAN.jpg) 9 | 10 | ## Downloads 11 | ### Dataset 12 | Download the datasets from the following links: 13 | - [Cardiovascular Disease](https://www.kaggle.com/sulianova/cardiovascular-disease-dataset) 14 | - [Click](https://www.kaggle.com/c/kddcup2012-track2/) 15 | - [Epsilon](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html) 16 | - [Forest Cover Type](https://archive.ics.uci.edu/ml/datasets/covertype) 17 | - [Microsoft WEB-10K](https://www.microsoft.com/en-us/research/project/mslr/) 18 | - [Yahoo! Learn to Rank Challenge version 2.0](https://webscope.sandbox.yahoo.com/catalog.php?datatype=c) 19 | - [YearPrediction](https://archive.ics.uci.edu/ml/datasets/yearpredictionmsd) 20 | 21 | (Optional) Before running the programs, you may convert the datasets to the `.pkl` format by using the `svm2pkl()` or `csv2pkl()` functions in `./data/data_util.py`. 22 | 23 | ## How to use 24 | 25 | ### Setting 26 | 1. Clone or download this repository, and `cd` into its root directory. 27 | 2. Build a working Python environment. Python 3.7 works for this repository. 28 | 3. Install the packages listed in `requirements.txt`, e.g., with `pip install -r requirements.txt`. 29 | 30 | ### Training 31 | 1. Set the hyperparameters in the config files (`./config/default.py` or `./config/*.yaml`). 32 | Note that hyperparameters set in a `.yaml` file override those in `default.py`. 33 | 34 | 2. Run `python main.py --c [config_path] --g [gpu_id]`. 35 | - `--c`: The config file path 36 | - `--g`: GPU device ID 37 | 3. The checkpoint models and best models are saved in the `./logs` directory. 38 | 39 | ### Inference 40 | 1. Set `resume_dir` to the directory containing your trained model/weights. 41 | 2. Run `python predict.py -d [dataset_name] -m [model_file_path] -g [gpu_id]`. 42 | - `-d`: Dataset name 43 | - `-m`: Model path for loading 44 | - `-g`: GPU device ID 45 | 46 | ### Config Hyperparameters 47 | #### Normal parameters 48 | - `dataset`: str 49 | The dataset name; it must match one of the names handled in `./data/dataset.py`. 50 | 51 | - `task`: str 52 | Either 'classification' or 'regression'. 53 | 54 | - `resume_dir`: str 55 | The log directory containing the checkpoint models. 56 | 57 | - `logname`: str 58 | The name of the directory under `./logs` where models are saved. 59 | 60 | - `seed`: int 61 | The random seed. 62 | 63 | #### Model parameters 64 | - `layer`: int (default=20) 65 | Number of abstract layers to stack. 66 | 67 | - `k`: int (default=5) 68 | Number of masks. 69 | 70 | - `base_outdim`: int (default=64) 71 | The output feature dimension of each abstract layer. 
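For orientation, the sketch below shows how these fields fit together in a single config file of the same format as those under `./config/`. The values are illustrative only (loosely following `./config/cardio.yaml` and `./config/default.py`); the remaining model parameter `drop_rate` and the fit parameters are described below and can be set in the same way.

```yaml
dataset: 'cardio'        # must match a dataset name handled in ./data/dataset.py
task: 'classification'   # 'classification' or 'regression'
resume_dir: ''           # leave empty to train from scratch
logname: 'layer8'        # runs are written to ./logs/<logname>_<timestamp>
seed: 324

model:
  layer: 8
  base_outdim: 64
  k: 5
  drop_rate: 0.1

fit:
  lr: 0.008
  max_epochs: 500
  patience: 200
  batch_size: 8192
  virtual_batch_size: 256
```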
72 | 73 | - `drop_rate`: float (default=0.1) 74 | Dropout rate in shortcut module 75 | 76 | #### Fit parameters 77 | - `lr`: float (default=0.008) 78 | Learning rate 79 | 80 | - `max_epochs`: int (default=5000) 81 | Maximum number of epochs in training. 82 | 83 | - `patience`: int (default=1500) 84 | Number of consecutive epochs without improvement before performing early stopping. If patience is set to 0, then no early stopping will be performed. 85 | 86 | - `batch_size`: int (default=8192) 87 | Number of examples per batch. 88 | 89 | - `virtual_batch_size`: int (default=256) 90 | Size of the mini batches used for "Ghost Batch Normalization". `virtual_batch_size` must divide `batch_size`. 91 | 92 | ### Citations 93 | ``` 94 | @inproceedings{danets, 95 | title={DANets: Deep Abstract Networks for Tabular Data Classification and Regression}, 96 | author={Chen, Jintai and Liao, Kuanlun and Wan, Yao and Chen, Danny Z and Wu, Jian}, 97 | booktitle={AAAI}, 98 | year={2022} 99 | } 100 | ``` 101 | 102 | 103 | -------------------------------------------------------------------------------- /abstract_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Any, Dict 3 | import torch 4 | import torch.cuda 5 | from torch.nn.utils import clip_grad_norm_ 6 | from torch.nn.parallel import DataParallel 7 | from torch.utils.data import DataLoader 8 | from qhoptim.pyt import QHAdam 9 | import numpy as np 10 | from abc import abstractmethod 11 | from lib.utils import ( 12 | PredictDataset, 13 | validate_eval_set, 14 | create_dataloaders, 15 | define_device, 16 | ) 17 | from lib.callbacks import ( 18 | CallbackContainer, 19 | History, 20 | EarlyStopping, 21 | LRSchedulerCallback, 22 | ) 23 | from lib.logger import Train_Log 24 | from lib.metrics import MetricContainer, check_metrics 25 | from model.DANet import DANet 26 | from model.AcceleratedModule import AcceleratedCreator 27 | from sklearn.base import BaseEstimator 28 | from sklearn.utils import check_array 29 | 30 | @dataclass 31 | class DANsModel(BaseEstimator): 32 | """ Class for DANsModel model. 33 | """ 34 | std: int = None 35 | drop_rate: float = 0.1 36 | layer: int = 32 37 | base_outdim: int = 64 38 | k: int = 5 39 | clip_value: int = 2 40 | seed: int = 1 41 | verbose: int = 1 42 | optimizer_fn: Any = QHAdam 43 | optimizer_params: Dict = field(default_factory=lambda: dict(lr=8e-3, weight_decay=1e-5, nus=(0.8, 1.0))) 44 | scheduler_fn: Any = torch.optim.lr_scheduler.StepLR 45 | scheduler_params: Dict = field(default_factory=lambda: dict(gamma=0.95, step_size=20)) 46 | input_dim: int = None 47 | output_dim: int = None 48 | device_name: str = "auto" 49 | 50 | def __post_init__(self): 51 | torch.cuda.manual_seed_all(self.seed) 52 | torch.manual_seed(self.seed) 53 | np.random.seed(self.seed) 54 | # Defining device 55 | self.device = torch.device(define_device(self.device_name)) 56 | if self.verbose != 0: 57 | print(f"Device used : {self.device}") 58 | 59 | def fit( 60 | self, 61 | X_train, 62 | y_train, 63 | eval_set=None, 64 | eval_name=None, 65 | eval_metric=None, 66 | loss_fn=None, 67 | max_epochs=1000, 68 | patience=500, 69 | batch_size=8192, 70 | virtual_batch_size=256, 71 | callbacks=None, 72 | logname=None, 73 | resume_dir=None, 74 | n_gpu=1 75 | ): 76 | """Train a neural network stored in self.network 77 | Using train_dataloader for training data and 78 | valid_dataloader for validation. 
79 | Parameters 80 | ---------- 81 | X_train : np.ndarray 82 | Train set 83 | y_train : np.array 84 | Train targets 85 | eval_set : list of tuple 86 | List of eval tuple set (X, y). 87 | The last one is used for early stopping 88 | eval_name : list of str 89 | List of eval set names. 90 | eval_metric : list of str 91 | List of evaluation metrics. 92 | The last metric is used for early stopping. 93 | loss_fn : callable or None 94 | a PyTorch loss function 95 | max_epochs : int 96 | Maximum number of epochs during training 97 | patience : int 98 | Number of consecutive non improving epoch before early stopping 99 | batch_size : int 100 | Training batch size 101 | virtual_batch_size : int 102 | Batch size for Ghost Batch Normalization (virtual_batch_size < batch_size) 103 | callbacks : list of callback function 104 | List of custom callbacks 105 | logname: str 106 | Setting log name 107 | resume_dir: str 108 | The resume file directory 109 | gpu_id: str 110 | Single GPU or Multi GPU ID 111 | """ 112 | self.max_epochs = max_epochs 113 | self.patience = patience 114 | self.batch_size = batch_size 115 | self.virtual_batch_size = virtual_batch_size 116 | self.input_dim = X_train.shape[1] 117 | self._stop_training = False 118 | self.log = Train_Log(logname, resume_dir) if (logname or resume_dir) else None 119 | self.n_gpu = n_gpu 120 | eval_set = eval_set if eval_set else [] 121 | 122 | self.loss_fn = self._default_loss if loss_fn is None else loss_fn 123 | check_array(X_train) 124 | 125 | self.update_fit_params(X_train, y_train, eval_set) 126 | # Validate and reformat eval set depending on training data 127 | eval_names, eval_set = validate_eval_set(eval_set, eval_name, X_train, y_train) 128 | train_dataloader, valid_dataloaders = self._construct_loaders(X_train, y_train, eval_set) 129 | 130 | self._set_network() 131 | self._set_metrics(eval_metric, eval_names) 132 | self._set_optimizer() 133 | self._set_callbacks(callbacks) 134 | 135 | if resume_dir: 136 | start_epoch, self.network, self._optimizer, best_value, best_epoch = self.log.load_checkpoint(self._optimizer) 137 | 138 | 139 | # Call method on_train_begin for all callbacks 140 | self._callback_container.on_train_begin() 141 | best_epoch = 1 142 | start_epoch = 1 143 | best_value = -float('inf') if self._task == 'classification' else float('inf') 144 | 145 | print("===> Start training ...") 146 | for epoch_idx in range(start_epoch, self.max_epochs + 1): 147 | self.epoch = epoch_idx 148 | # Call method on_epoch_begin for all callbacks 149 | self._callback_container.on_epoch_begin(epoch_idx) 150 | self._train_epoch(train_dataloader) 151 | 152 | # Apply predict epoch to all eval sets 153 | for eval_name, valid_dataloader in zip(eval_names, valid_dataloaders): 154 | self._predict_epoch(eval_name, valid_dataloader) 155 | 156 | # Call method on_epoch_end for all callbacks 157 | self._callback_container.on_epoch_end(epoch_idx, logs=self.history.epoch_metrics) 158 | 159 | #save checkpoint 160 | self.save_check() 161 | print('LR: ' + str(self._optimizer.param_groups[0]['lr'])) 162 | if self._stop_training: 163 | break 164 | 165 | # Call method on_train_end for all callbacks 166 | self._callback_container.on_train_end() 167 | self.network.eval() 168 | 169 | return best_value 170 | 171 | def predict(self, X): 172 | """ 173 | Make predictions on a batch (valid) 174 | Parameters 175 | ---------- 176 | X : a :tensor: `torch.Tensor` 177 | Input data 178 | Returns 179 | ------- 180 | predictions : np.array 181 | Predictions of the regression problem 182 | 
""" 183 | self.network.eval() 184 | dataloader = DataLoader(PredictDataset(X), batch_size=1024, shuffle=False, pin_memory=True) 185 | results = [] 186 | print('===> Starting test ... ') 187 | for batch_nb, data in enumerate(dataloader): 188 | data = data.to(self.device).float() 189 | with torch.no_grad(): 190 | output = self.network(data) 191 | predictions = output.cpu().detach().numpy() 192 | results.append(predictions) 193 | res = np.vstack(results) 194 | return self.predict_func(res) 195 | 196 | def save_check(self): 197 | save_dict = { 198 | 'epoch': self.epoch, 199 | 'model': self.network, 200 | # 'state_dict': self.network.state_dict(), 201 | 'optimizer': self._optimizer.state_dict(), 202 | 'best_value': self._callback_container.callbacks[1].best_loss, 203 | "best_epoch": self._callback_container.callbacks[1].best_epoch 204 | } 205 | torch.save(save_dict, self.log.log_dir + '/checkpoint.pth') 206 | 207 | 208 | def load_model(self, filepath, input_dim, output_dim, n_gpu=1): 209 | """Load DANet model. 210 | Parameters 211 | ---------- 212 | filepath : str 213 | Path of the model. 214 | """ 215 | self.input_dim = input_dim 216 | self.output_dim = output_dim 217 | self.n_gpu = n_gpu 218 | load_model = torch.load(filepath, map_location=self.device) 219 | self.layer, self.virtual_batch_size = load_model['layer_num'], load_model['virtual_batch_size'] 220 | self.k, self.base_outdim = load_model['k'], load_model['base_outdim'] 221 | self._set_network() 222 | self.network.load_state_dict(load_model['state_dict']) 223 | self.network.eval() 224 | accelerated_module = AcceleratedCreator(self.input_dim, base_out_dim=self.base_outdim, k=self.k) 225 | self.network = accelerated_module(self.network) 226 | return 227 | 228 | def _train_epoch(self, train_loader): 229 | """ 230 | Trains one epoch of the network in self.network 231 | Parameters 232 | ---------- 233 | train_loader : a :class: `torch.utils.data.Dataloader` 234 | DataLoader with train set 235 | """ 236 | self.network.train() 237 | loss = [] 238 | for batch_idx, (X, y) in enumerate(train_loader): 239 | self._callback_container.on_batch_begin(batch_idx) 240 | batch_logs = self._train_batch(X, y) 241 | 242 | self._callback_container.on_batch_end(batch_idx, batch_logs) 243 | loss.append(batch_logs['loss']) 244 | 245 | epoch_logs = {"lr": self._optimizer.param_groups[-1]["lr"], "loss": np.mean(loss)} 246 | 247 | self.history.epoch_metrics.update(epoch_logs) 248 | return 249 | 250 | def _train_batch(self, X, y): 251 | """ 252 | Trains one batch of data 253 | Parameters 254 | ---------- 255 | X : torch.Tensor 256 | Train matrix 257 | y : torch.Tensor 258 | Target matrix 259 | Returns 260 | ------- 261 | batch_outs : dict 262 | Dictionnary with "y": target and "score": prediction scores. 263 | batch_logs : dict 264 | Dictionnary with "batch_size" and "loss". 265 | """ 266 | batch_logs = {"batch_size": X.shape[0]} 267 | 268 | X = X.to(self.device).float() 269 | y = y.to(self.device).float() 270 | 271 | self._optimizer.zero_grad() 272 | output = self.network(X) 273 | loss = self.compute_loss(output, y) 274 | # Perform backward pass and optimization 275 | 276 | loss.backward() 277 | if self.clip_value: 278 | clip_grad_norm_(self.network.parameters(), self.clip_value) 279 | self._optimizer.step() 280 | 281 | batch_logs["loss"] = loss.cpu().detach().numpy().item() 282 | 283 | return batch_logs 284 | 285 | def _predict_epoch(self, name, loader): 286 | """ 287 | Predict an epoch and update metrics. 
288 | Parameters 289 | ---------- 290 | name : str 291 | Name of the validation set 292 | loader : torch.utils.data.Dataloader 293 | DataLoader with validation set 294 | """ 295 | # Setting network on evaluation mode 296 | self.network.eval() 297 | list_y_true = [] 298 | list_y_score = [] 299 | 300 | # Main loop 301 | for batch_idx, (X, y) in enumerate(loader): 302 | scores = self._predict_batch(X) 303 | list_y_true.append(y) 304 | list_y_score.append(scores) 305 | 306 | y_true, scores = self.stack_batches(list_y_true, list_y_score) 307 | 308 | metrics_logs = self._metric_container_dict[name](y_true, scores) 309 | if self._task == 'regression': 310 | for k, v in metrics_logs.items(): 311 | metrics_logs[k] = v * self.std ** 2 312 | self.network.train() 313 | self.history.epoch_metrics.update(metrics_logs) 314 | return 315 | 316 | def _predict_batch(self, X): 317 | """ 318 | Predict one batch of data. 319 | Parameters 320 | ---------- 321 | X : torch.Tensor 322 | Owned products 323 | Returns 324 | ------- 325 | np.array 326 | model scores 327 | """ 328 | X = X.to(self.device).float() 329 | 330 | # compute model output 331 | with torch.no_grad(): 332 | scores = self.network(X) 333 | if isinstance(scores, list): 334 | scores = [x.cpu().detach().numpy() for x in scores] 335 | else: 336 | scores = scores.cpu().detach().numpy() 337 | 338 | return scores 339 | 340 | @abstractmethod 341 | def update_fit_params(self, X_train, y_train, eval_set): 342 | """ 343 | Set attributes relative to fit function. 344 | Parameters 345 | ---------- 346 | X_train : np.ndarray 347 | Train set 348 | y_train : np.array 349 | Train targets 350 | eval_set : list of tuple 351 | List of eval tuple set (X, y). 352 | """ 353 | raise NotImplementedError( 354 | "users must define update_fit_params to use this base class" 355 | ) 356 | 357 | def _set_network(self): 358 | """Setup the network and explain matrix.""" 359 | print("===> Building model ...") 360 | params = {'layer_num': self.layer, 361 | 'base_outdim': self.base_outdim, 362 | 'k': self.k, 363 | 'virtual_batch_size': self.virtual_batch_size, 364 | 'drop_rate': self.drop_rate, 365 | } 366 | 367 | self.network = DANet(self.input_dim, self.output_dim, **params) 368 | if self.n_gpu > 1 and self.device == 'cuda': 369 | self.network = DataParallel(self.network) 370 | self.network = self.network.to(self.device) 371 | 372 | def _set_metrics(self, metrics, eval_names): 373 | """Set attributes relative to the metrics. 374 | Parameters 375 | ---------- 376 | metrics : list of str 377 | List of eval metric names. 378 | eval_names : list of str 379 | List of eval set names. 380 | """ 381 | metrics = metrics or [self._default_metric] 382 | 383 | metrics = check_metrics(metrics) 384 | # Set metric container for each sets 385 | self._metric_container_dict = {} 386 | for name in eval_names: 387 | self._metric_container_dict.update( 388 | {name: MetricContainer(metrics, prefix=f"{name}_")} 389 | ) 390 | 391 | self._metrics = [] 392 | self._metrics_names = [] 393 | for _, metric_container in self._metric_container_dict.items(): 394 | self._metrics.extend(metric_container.metrics) 395 | self._metrics_names.extend(metric_container.names) 396 | 397 | # Early stopping metric is the last eval metric 398 | 399 | self.early_stopping_metric = self._metrics_names[-1] if len(self._metrics_names) > 0 else None 400 | 401 | def _set_callbacks(self, custom_callbacks): 402 | """Setup the callbacks functions. 
403 | Parameters 404 | ---------- 405 | custom_callbacks : list of func 406 | List of callback functions. 407 | """ 408 | # Setup default callbacks history, early stopping and scheduler 409 | callbacks = [] 410 | self.history = History(self, verbose=self.verbose) 411 | callbacks.append(self.history) 412 | if (self.early_stopping_metric is not None) and (self.patience > 0): 413 | early_stopping = EarlyStopping( 414 | early_stopping_metric=self.early_stopping_metric, 415 | is_maximize=self._metrics[-1]._maximize if len(self._metrics) > 0 else None, 416 | patience=self.patience, 417 | ) 418 | callbacks.append(early_stopping) 419 | else: 420 | print("No early stopping will be performed, last training weights will be used.") 421 | 422 | if self.scheduler_fn is not None: 423 | # Add LR Scheduler call_back 424 | is_batch_level = self.scheduler_params.pop("is_batch_level", False) 425 | scheduler = LRSchedulerCallback( 426 | scheduler_fn=self.scheduler_fn, 427 | scheduler_params=self.scheduler_params, 428 | optimizer=self._optimizer, 429 | early_stopping_metric=self.early_stopping_metric, 430 | is_batch_level=is_batch_level, 431 | ) 432 | callbacks.append(scheduler) 433 | 434 | if custom_callbacks: 435 | callbacks.extend(custom_callbacks) 436 | self._callback_container = CallbackContainer(callbacks) 437 | self._callback_container.set_trainer(self) 438 | 439 | def _set_optimizer(self): 440 | """Setup optimizer.""" 441 | self._optimizer = self.optimizer_fn(self.network.parameters(), **self.optimizer_params) 442 | 443 | def _construct_loaders(self, X_train, y_train, eval_set): 444 | """Generate dataloaders for train and eval set. 445 | Parameters 446 | ---------- 447 | X_train : np.array 448 | Train set. 449 | y_train : np.array 450 | Train targets. 451 | eval_set : list of tuple 452 | List of eval tuple set (X, y). 453 | Returns 454 | ------- 455 | train_dataloader : `torch.utils.data.Dataloader` 456 | Training dataloader. 457 | valid_dataloaders : list of `torch.utils.data.Dataloader` 458 | List of validation dataloaders. 459 | """ 460 | # all weights are not allowed for this type of model 461 | y_train_mapped = self.prepare_target(y_train) 462 | for i, (X, y) in enumerate(eval_set): 463 | y_mapped = self.prepare_target(y) 464 | eval_set[i] = (X, y_mapped) 465 | 466 | train_dataloader, valid_dataloaders = create_dataloaders( 467 | X_train, 468 | y_train_mapped, 469 | eval_set, 470 | self.batch_size 471 | ) 472 | return train_dataloader, valid_dataloaders 473 | 474 | 475 | def _update_network_params(self): 476 | self.network.virtual_batch_size = self.virtual_batch_size 477 | 478 | @abstractmethod 479 | def compute_loss(self, y_score, y_true): 480 | """ 481 | Compute the loss. 482 | Parameters 483 | ---------- 484 | y_score : a :tensor: `torch.Tensor` 485 | Score matrix 486 | y_true : a :tensor: `torch.Tensor` 487 | Target matrix 488 | Returns 489 | ------- 490 | float 491 | Loss value 492 | """ 493 | raise NotImplementedError( 494 | "users must define compute_loss to use this base class" 495 | ) 496 | 497 | @abstractmethod 498 | def prepare_target(self, y): 499 | """ 500 | Prepare target before training. 501 | Parameters 502 | ---------- 503 | y : a :tensor: `torch.Tensor` 504 | Target matrix. 505 | Returns 506 | ------- 507 | `torch.Tensor` 508 | Converted target matrix. 
509 | """ 510 | raise NotImplementedError( 511 | "users must define prepare_target to use this base class" 512 | ) 513 | -------------------------------------------------------------------------------- /config/MSLR.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'MSLR' 2 | task: 'regression' 3 | resume_dir: '' 4 | logname: 'layer20' 5 | 6 | fit: 7 | max_epochs: 2000 8 | patience: 500 9 | lr: 0.008 10 | 11 | model: 12 | layer: 20 13 | base_outdim: 64 14 | k: 5 15 | drop_rate: 0.1 -------------------------------------------------------------------------------- /config/cardio.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'cardio' 2 | task: 'classification' 3 | resume_dir: '' 4 | logname: 'layer8' 5 | 6 | fit: 7 | max_epochs: 500 8 | patience: 200 9 | lr: 0.008 10 | 11 | model: 12 | layer: 8 13 | base_outdim: 64 14 | k: 5 15 | drop_rate: 0.1 -------------------------------------------------------------------------------- /config/click.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'click' 2 | task: 'classification' 3 | resume_dir: '' 4 | logname: '' 5 | 6 | fit: 7 | max_epochs: 1700 8 | patience: 1000 9 | lr: 0.008 10 | 11 | model: 12 | layer: 8 13 | base_outdim: 64 14 | k: 5 15 | drop_rate: 0.1 -------------------------------------------------------------------------------- /config/default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as Node 2 | cfg = Node() 3 | cfg.seed = 324 4 | cfg.dataset = 'forest_cover_type' 5 | cfg.task = 'classification' 6 | cfg.resume_dir = '' 7 | cfg.logname = '' 8 | 9 | cfg.model = Node() 10 | cfg.model.base_outdim = 64 11 | cfg.model.k = 5 12 | cfg.model.drop_rate = 0.1 13 | cfg.model.layer = 20 14 | 15 | cfg.fit = Node() 16 | cfg.fit.lr = 0.008 17 | cfg.fit.max_epochs = 4000 18 | cfg.fit.patience = 1500 19 | cfg.fit.batch_size = 8192 20 | cfg.fit.virtual_batch_size = 256 21 | -------------------------------------------------------------------------------- /config/epsilon.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'epsilon' 2 | task: 'classification' 3 | train_ratio: 1.0 4 | resume_dir: '' 5 | logname: 'layer32' 6 | 7 | fit: 8 | max_epochs: 1500 9 | patience: 500 10 | lr: 0.02 11 | 12 | model: 13 | layer: 32 14 | base_outdim: 96 15 | k: 8 16 | drop_rate: 0.1 -------------------------------------------------------------------------------- /config/forest_cover_type.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'forest' 2 | task: 'classification' 3 | resume_dir: '' 4 | logname: 'layer20' 5 | 6 | fit: 7 | max_epochs: 5000 8 | patience: 1500 9 | lr: 0.008 10 | 11 | model: 12 | layer: 20 13 | base_outdim: 64 14 | k: 5 15 | -------------------------------------------------------------------------------- /config/yahoo.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'yahoo' 2 | task: 'regression' 3 | resume_dir: '' 4 | logname: 'layer32' 5 | 6 | fit: 7 | max_epochs: 2000 8 | patience: 500 9 | lr: 0.02 10 | 11 | model: 12 | layer: 32 13 | base_outdim: 96 14 | k: 8 15 | drop_rate: 0.1 -------------------------------------------------------------------------------- /config/year.yaml: -------------------------------------------------------------------------------- 1 | dataset: 'year' 2 
| task: 'regression' 3 | resume_dir: '' 4 | logname: '' 5 | 6 | fit: 7 | max_epochs: 150 8 | patience: 80 9 | lr: 0.008 10 | 11 | model: 12 | layer: 20 13 | base_outdim: 64 14 | k: 5 15 | drop_rate: 0.1 -------------------------------------------------------------------------------- /data/data_util.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | from sklearn.datasets import load_svmlight_file 4 | 5 | def svm2pkl(source, save_path): 6 | X_train, y_train = load_svmlight_file(os.path.join(source, 'train')) 7 | X_valid, y_valid = load_svmlight_file(os.path.join(source, 'vali')) 8 | X_test, y_test = load_svmlight_file(os.path.join(source, 'test')) 9 | 10 | X_train = pd.DataFrame(X_train.todense()) 11 | y_train = pd.Series(y_train) 12 | pd.concat([y_train, X_train], axis=1).T.reset_index(drop=True).T.to_pickle(os.path.join(save_path, 'train.pkl')) 13 | 14 | X_valid = pd.DataFrame(X_valid.todense()) 15 | y_valid = pd.Series(y_valid) 16 | pd.concat([y_valid, X_valid], axis=1).T.reset_index(drop=True).T.to_pickle(os.path.join(save_path, 'valid.pkl')) 17 | 18 | X_test = pd.DataFrame(X_test.todense()) 19 | y_test = pd.Series(y_test) 20 | pd.concat([y_test, X_test], axis=1).T.reset_index(drop=True).T.to_pickle(os.path.join(save_path, 'test.pkl')) 21 | 22 | def csv2pkl(source, save_path): 23 | data = pd.read_csv(source) 24 | data.to_pickle(save_path) 25 | 26 | if __name__ == '__main__': 27 | source = '/data/dataset/MSLR-WEB10K/Fold1' 28 | save_path = './data/MSLR-WEB10K' 29 | svm2pkl(source, save_path) 30 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn.preprocessing import QuantileTransformer 4 | from sklearn.model_selection import train_test_split 5 | from category_encoders import LeaveOneOutEncoder 6 | 7 | def remove_unused_column(data): 8 | unused_list = [] 9 | for col in data.columns: 10 | uni = len(data[col].unique()) 11 | if uni <= 1: 12 | unused_list.append(col) 13 | data.drop(columns=unused_list, inplace=True) 14 | return data 15 | 16 | def split_data(data, target, test_size): 17 | label = data[target] 18 | data = data.drop([target], axis=1) 19 | X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=test_size, random_state=123, shuffle=True) 20 | return X_train, y_train.values, X_test, y_test.values 21 | 22 | 23 | def quantile_transform(X_train, X_valid, X_test): 24 | quantile_train = np.copy(X_train) 25 | qt = QuantileTransformer(random_state=55688, output_distribution='normal').fit(quantile_train) 26 | X_train = qt.transform(X_train) 27 | X_valid = qt.transform(X_valid) 28 | X_test = qt.transform(X_test) 29 | 30 | return X_train, X_valid, X_test 31 | 32 | def forest_cover(): 33 | target = "Covertype" 34 | 35 | bool_columns = [ 36 | "Wilderness_Area1", "Wilderness_Area2", "Wilderness_Area3", 37 | "Wilderness_Area4", "Soil_Type1", "Soil_Type2", "Soil_Type3", "Soil_Type4", 38 | "Soil_Type5", "Soil_Type6", "Soil_Type7", "Soil_Type8", "Soil_Type9", 39 | "Soil_Type10", "Soil_Type11", "Soil_Type12", "Soil_Type13", "Soil_Type14", 40 | "Soil_Type15", "Soil_Type16", "Soil_Type17", "Soil_Type18", "Soil_Type19", 41 | "Soil_Type20", "Soil_Type21", "Soil_Type22", "Soil_Type23", "Soil_Type24", 42 | "Soil_Type25", "Soil_Type26", "Soil_Type27", "Soil_Type28", "Soil_Type29", 43 | "Soil_Type30", "Soil_Type31", 
"Soil_Type32", "Soil_Type33", "Soil_Type34", 44 | "Soil_Type35", "Soil_Type36", "Soil_Type37", "Soil_Type38", "Soil_Type39", 45 | "Soil_Type40" 46 | ] 47 | 48 | int_columns = [ 49 | "Elevation", "Aspect", "Slope", "Horizontal_Distance_To_Hydrology", 50 | "Vertical_Distance_To_Hydrology", "Horizontal_Distance_To_Roadways", 51 | "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm", 52 | "Horizontal_Distance_To_Fire_Points" 53 | ] 54 | feature = int_columns + bool_columns + [target] 55 | data = pd.read_csv('./data/forest_cover_type/forest-cover-type.csv', header=None, names=feature) 56 | train_idx = pd.read_csv('./data/forest_cover_type/train_idx.csv', header=None)[0].values 57 | train = data.iloc[train_idx, :] 58 | valid_idx = pd.read_csv('./data/forest_cover_type/valid_idx.csv', header=None)[0].values 59 | valid = data.iloc[valid_idx, :] 60 | test_idx = pd.read_csv('./data/forest_cover_type/test_idx.csv', header=None)[0].values 61 | test = data.iloc[test_idx, :] 62 | 63 | 64 | y_train = train[target].values 65 | X_train = train.drop([target], axis=1).values 66 | y_valid = valid[target].values 67 | X_valid = valid.drop([target], axis=1).values 68 | y_test = test[target].values 69 | X_test = test.drop([target], axis=1).values 70 | 71 | mean = np.mean(X_train, axis=0) 72 | std = np.std(X_train, axis=0) 73 | X_train = (X_train - mean) / std 74 | X_valid = (X_valid - mean) / std 75 | X_test = (X_test - mean) / std 76 | 77 | return X_train, y_train, X_valid, y_valid, X_test, y_test 78 | 79 | def MSLR(): 80 | target = 0 81 | train = pd.read_pickle('./data/MSLR-WEB10K/train.pkl') 82 | valid = pd.read_pickle('./data/MSLR-WEB10K/valid.pkl') 83 | test = pd.read_pickle('./data/MSLR-WEB10K/test.pkl') 84 | 85 | y_train = train[target].values 86 | y_valid = valid[target].values 87 | y_test = test[target].values 88 | 89 | train.drop([target], axis=1, inplace=True) 90 | valid.drop([target], axis=1, inplace=True) 91 | test.drop([target], axis=1, inplace=True) 92 | X_train, X_valid, X_test = quantile_transform(train, valid, test) 93 | 94 | return X_train, y_train, X_valid, y_valid, X_test, y_test 95 | 96 | def yahoo(): 97 | target = 0 98 | train = pd.read_pickle('./data/yahoo/train.pkl') 99 | valid = pd.read_pickle('./data/yahoo/valid.pkl') 100 | test = pd.read_pickle('./data/yahoo/test.pkl') 101 | train = remove_unused_column(train) 102 | valid = remove_unused_column(valid) 103 | test = remove_unused_column(test) 104 | 105 | y_train = train[target].values 106 | y_valid = valid[target].values 107 | y_test = test[target].values 108 | 109 | train.drop([target], axis=1, inplace=True) 110 | valid.drop([target], axis=1, inplace=True) 111 | test.drop([target], axis=1, inplace=True) 112 | X_train, X_valid, X_test = quantile_transform(train, valid, test) 113 | return X_train, y_train, X_valid, y_valid, X_test, y_test 114 | 115 | def yearpred(): 116 | target = 0 117 | data = pd.read_pickle('./data/yearpred/YearPrediction.pkl') 118 | 119 | train_idx = pd.read_csv('./data/yearpred/train_idx.csv')['0'].values 120 | train = data.iloc[train_idx, :] 121 | valid_idx = pd.read_csv('./data/yearpred/valid_idx.csv')['0'].values 122 | valid = data.iloc[valid_idx, :] 123 | test_idx = pd.read_csv('./data/yearpred/test_idx.csv')['0'].values 124 | test = data.iloc[test_idx, :] 125 | 126 | y_train = train[target].values 127 | X_train = train.drop([target], axis=1).values 128 | y_valid = valid[target].values 129 | X_valid = valid.drop([target], axis=1).values 130 | y_test = test[target].values 131 | X_test = test.drop([target], 
axis=1).values 132 | 133 | X_train, X_valid, X_test = quantile_transform(X_train, X_valid, X_test) 134 | return X_train, y_train, X_valid, y_valid, X_test, y_test 135 | 136 | def epsilon(): 137 | target = 0 138 | train = pd.read_pickle('./data/epsilon/train.pkl') 139 | test = pd.read_pickle('./data/epsilon/test.pkl') 140 | 141 | train_idx = pd.read_csv('./data/epsilon/train_idx.csv')['0'].values 142 | train = train.iloc[train_idx, :] 143 | valid_idx = pd.read_csv('./data/epsilon/valid_idx.csv')['0'].values 144 | valid = train.iloc[valid_idx, :] 145 | 146 | y_train = train[target].values 147 | y_valid = valid[target].values 148 | y_test = test[target].values 149 | X_train = train.drop([target], axis=1).values 150 | X_valid = valid.drop([target], axis=1).values 151 | X_test = test.drop([target], axis=1).values 152 | 153 | 154 | mean = np.mean(X_train, axis=0) 155 | std = np.std(X_train, axis=0) 156 | X_train = (X_train - mean) / std 157 | X_valid = (X_valid - mean) / std 158 | X_test = (X_test - mean) / std 159 | 160 | return X_train, y_train, X_valid, y_valid, X_test, y_test 161 | 162 | def click(): 163 | target = 'target' 164 | data = pd.read_pickle('./data/click/click.pkl') 165 | 166 | train_idx = pd.read_csv('./data/click/train_idx.csv')['0'].values 167 | train = data.iloc[train_idx, :] 168 | valid_idx = pd.read_csv('./data/click/valid_idx.csv')['0'].values 169 | valid = data.iloc[valid_idx, :] 170 | test_idx = pd.read_csv('./data/click/test_idx.csv')['0'].values 171 | test = data.iloc[test_idx, :] 172 | 173 | y_train = train[target].values 174 | X_train = train.drop([target], axis=1) 175 | y_valid = valid[target].values 176 | X_valid = valid.drop([target], axis=1) 177 | y_test = test[target].values 178 | X_test = test.drop([target], axis=1) 179 | cat_features = ['url_hash', 'ad_id', 'advertiser_id', 'query_id', 180 | 'keyword_id', 'title_id', 'description_id', 'user_id'] 181 | 182 | cat_encoder = LeaveOneOutEncoder() 183 | cat_encoder.fit(X_train[cat_features], y_train) 184 | X_train[cat_features] = cat_encoder.transform(X_train[cat_features]) 185 | X_valid[cat_features] = cat_encoder.transform(X_valid[cat_features]) 186 | X_test[cat_features] = cat_encoder.transform(X_test[cat_features]) 187 | 188 | X_train, X_valid, X_test = quantile_transform(X_train.astype(np.float32), X_valid.astype(np.float32), X_test.astype(np.float32)) 189 | return X_train, y_train, X_valid, y_valid, X_test, y_test 190 | 191 | def cardio(): 192 | target = 'cardio' 193 | data = pd.read_csv('./data/cardio/cardiovascular-disease.csv', delimiter=';').drop(['id'], axis=1) 194 | train_idx = pd.read_csv('./data/cardio/train_idx.csv')['0'].values 195 | train = data.iloc[train_idx, :] 196 | valid_idx = pd.read_csv('./data/cardio/valid_idx.csv')['0'].values 197 | valid = data.iloc[valid_idx, :] 198 | test_idx = pd.read_csv('./data/cardio/test_idx.csv')['0'].values 199 | test = data.iloc[test_idx, :] 200 | 201 | y_train = train[target].values 202 | train.drop([target], axis=1, inplace=True) 203 | y_valid = valid[target].values 204 | valid.drop([target], axis=1, inplace=True) 205 | y_test = test[target].values 206 | test.drop([target], axis=1, inplace=True) 207 | 208 | X_train, X_valid, X_test = quantile_transform(train, valid, test) 209 | return X_train, y_train, X_valid, y_valid, X_test, y_test 210 | 211 | 212 | def get_data(datasetname): 213 | if datasetname == 'forest': 214 | return forest_cover() 215 | elif datasetname == 'MSLR': 216 | return MSLR() 217 | elif datasetname == 'year': 218 | return yearpred() 219 | elif 
datasetname == 'cardio': 220 | return cardio() 221 | elif datasetname == 'yahoo': 222 | return yahoo() 223 | elif datasetname == 'epsilon': 224 | return epsilon() 225 | elif datasetname == 'click': 226 | return click() 227 | 228 | if __name__ == '__main__': 229 | forest_cover() 230 | -------------------------------------------------------------------------------- /lib/callbacks.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import copy 4 | import numpy as np 5 | from dataclasses import dataclass, field 6 | from typing import List, Any 7 | from config.default import cfg 8 | class Callback: 9 | """ 10 | Abstract base class used to build new callbacks. 11 | """ 12 | 13 | def __init__(self): 14 | pass 15 | 16 | def set_params(self, params): 17 | self.params = params 18 | 19 | def set_trainer(self, model): 20 | self.trainer = model 21 | 22 | def on_epoch_begin(self, epoch, logs=None): 23 | pass 24 | 25 | def on_epoch_end(self, epoch, logs=None): 26 | pass 27 | 28 | def on_batch_begin(self, batch, logs=None): 29 | pass 30 | 31 | def on_batch_end(self, batch, logs=None): 32 | pass 33 | 34 | def on_train_begin(self, logs=None): 35 | pass 36 | 37 | def on_train_end(self, logs=None): 38 | pass 39 | 40 | 41 | @dataclass 42 | class CallbackContainer: 43 | """ 44 | Container holding a list of callbacks. 45 | """ 46 | 47 | callbacks: List[Callback] = field(default_factory=list) 48 | 49 | def append(self, callback): 50 | self.callbacks.append(callback) 51 | 52 | def set_params(self, params): 53 | for callback in self.callbacks: 54 | callback.set_params(params) 55 | 56 | def set_trainer(self, trainer): 57 | self.trainer = trainer 58 | for callback in self.callbacks: 59 | callback.set_trainer(trainer) 60 | 61 | def on_epoch_begin(self, epoch, logs=None): 62 | logs = logs or {} 63 | for callback in self.callbacks: 64 | callback.on_epoch_begin(epoch, logs) 65 | 66 | def on_epoch_end(self, epoch, logs=None): 67 | logs = logs or {} 68 | for callback in self.callbacks: 69 | callback.on_epoch_end(epoch, logs) 70 | 71 | def on_batch_begin(self, batch, logs=None): 72 | logs = logs or {} 73 | for callback in self.callbacks: 74 | callback.on_batch_begin(batch, logs) 75 | 76 | def on_batch_end(self, batch, logs=None): 77 | logs = logs or {} 78 | for callback in self.callbacks: 79 | callback.on_batch_end(batch, logs) 80 | 81 | def on_train_begin(self, logs=None): 82 | logs = logs or {} 83 | logs["start_time"] = time.time() 84 | for callback in self.callbacks: 85 | callback.on_train_begin(logs) 86 | 87 | def on_train_end(self, logs=None): 88 | logs = logs or {} 89 | for callback in self.callbacks: 90 | callback.on_train_end(logs) 91 | 92 | 93 | @dataclass 94 | class EarlyStopping(Callback): 95 | """EarlyStopping callback to exit the training loop if early_stopping_metric 96 | does not improve by a certain amount for a certain 97 | number of epochs. 98 | 99 | Parameters 100 | --------- 101 | early_stopping_metric : str 102 | Early stopping metric name 103 | is_maximize : bool 104 | Whether to maximize or not early_stopping_metric 105 | tol : float 106 | minimum change in monitored value to qualify as improvement. 107 | This number should be positive. 108 | patience : integer 109 | number of epochs to wait for improvement before terminating. 
110 | the counter be reset after each improvement 111 | 112 | """ 113 | 114 | early_stopping_metric: str 115 | is_maximize: bool 116 | tol: float = 0.0 117 | patience: int = 10 118 | 119 | def __post_init__(self): 120 | self.best_epoch = 0 121 | self.stopped_epoch = 0 122 | self.wait = 0 123 | self.best_weights = None 124 | self.best_loss = np.inf 125 | if self.is_maximize: 126 | self.best_loss = -self.best_loss 127 | super().__init__() 128 | 129 | def on_epoch_end(self, epoch, logs=None): 130 | current_loss = logs.get(self.early_stopping_metric) 131 | if current_loss is None: 132 | return 133 | 134 | loss_change = current_loss - self.best_loss 135 | max_improved = self.is_maximize and loss_change > self.tol 136 | min_improved = (not self.is_maximize) and (-loss_change > self.tol) 137 | if max_improved or min_improved: 138 | self.best_loss = current_loss 139 | self.best_epoch = epoch 140 | self.wait = 1 141 | self.best_weights = copy.deepcopy(self.trainer.network.state_dict()) 142 | self.best_msg = 'Best ' + self.early_stopping_metric + ':{:.5f}'.format(self.best_loss) + ' on epoch ' + str(self.best_epoch) 143 | if self.trainer.log: 144 | best_model = {'layer_num': self.trainer.layer, 145 | 'base_outdim': self.trainer.base_outdim, 146 | 'k': self.trainer.k, 147 | 'virtual_batch_size': self.trainer.virtual_batch_size, 148 | 'state_dict': self.trainer.network.state_dict() 149 | } 150 | self.trainer.log.save_best_model(best_model) 151 | else: 152 | if self.wait >= self.patience: 153 | self.stopped_epoch = epoch 154 | self.trainer._stop_training = True 155 | self.wait += 1 156 | print(self.best_msg) 157 | if self.trainer.log: 158 | self.trainer.log.save_log(self.trainer.history['msg'] + '\n' + self.best_msg) 159 | 160 | def on_train_end(self, logs=None): 161 | self.trainer.best_epoch = self.best_epoch 162 | self.trainer.best_cost = self.best_loss 163 | 164 | if self.best_weights is not None: 165 | self.trainer.network.load_state_dict(self.best_weights) 166 | 167 | if self.stopped_epoch > 0: 168 | msg = f"\nEarly stopping occurred at epoch {self.stopped_epoch}" 169 | msg += ( 170 | f" with best_epoch = {self.best_epoch} and " 171 | + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}" 172 | ) 173 | print(msg) 174 | else: 175 | msg = ( 176 | f"Stop training because you reached max_epochs = {self.trainer.max_epochs}" 177 | + f" with best_epoch = {self.best_epoch} and " 178 | + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}" 179 | ) 180 | print(msg) 181 | print("Best weights from best epoch are automatically used!") 182 | 183 | 184 | @dataclass 185 | class History(Callback): 186 | """Callback that records events into a `History` object. 187 | This callback is automatically applied to 188 | every SuperModule. 
189 | 190 | Parameters 191 | --------- 192 | trainer : DeepRecoModel 193 | Model class to train 194 | verbose : int 195 | Print results every verbose iteration 196 | 197 | """ 198 | 199 | trainer: Any 200 | verbose: int = 1 201 | 202 | def __post_init__(self): 203 | super().__init__() 204 | self.samples_seen = 0.0 205 | self.total_time = 0.0 206 | 207 | def on_train_begin(self, logs=None): 208 | self.history = {"loss": []} 209 | self.history.update({"lr": []}) 210 | self.history.update({name: [] for name in self.trainer._metrics_names}) 211 | self.start_time = logs["start_time"] 212 | self.epoch_loss = 0.0 213 | 214 | def on_epoch_begin(self, epoch, logs=None): 215 | self.epoch_metrics = {"loss": 0.0} 216 | self.samples_seen = 0.0 217 | 218 | def on_epoch_end(self, epoch, logs=None): 219 | self.epoch_metrics["loss"] = self.epoch_loss 220 | for metric_name, metric_value in self.epoch_metrics.items(): 221 | self.history[metric_name].append(metric_value) 222 | if self.verbose == 0: 223 | return 224 | if epoch % self.verbose != 0: 225 | return 226 | msg = f"epoch {epoch:<3}" 227 | for metric_name, metric_value in self.epoch_metrics.items(): 228 | if metric_name != "lr": 229 | msg += f"| {metric_name:<3}: {np.round(metric_value, 5):<8}" 230 | self.total_time = int(time.time() - self.start_time) 231 | msg += f"| {str(datetime.timedelta(seconds=self.total_time)) + 's':<6}" 232 | self.history['msg'] = msg 233 | print(msg) 234 | if self.trainer.log: 235 | self.trainer.log.save_tensorboard(self.epoch_metrics, epoch) 236 | 237 | def on_batch_end(self, batch, logs=None): 238 | batch_size = logs["batch_size"] 239 | self.epoch_loss = ( 240 | self.samples_seen * self.epoch_loss + batch_size * logs["loss"] 241 | ) / (self.samples_seen + batch_size) 242 | self.samples_seen += batch_size 243 | 244 | def __getitem__(self, name): 245 | return self.history[name] 246 | 247 | def __repr__(self): 248 | return str(self.history) 249 | 250 | def __str__(self): 251 | return str(self.history) 252 | 253 | 254 | @dataclass 255 | class LRSchedulerCallback(Callback): 256 | """Wrapper for most torch scheduler functions. 
257 | 258 | Parameters 259 | --------- 260 | scheduler_fn : torch.optim.lr_scheduler 261 | Torch scheduling class 262 | scheduler_params : dict 263 | Dictionnary containing all parameters for the scheduler_fn 264 | is_batch_level : bool (default = False) 265 | If set to False : lr updates will happen at every epoch 266 | If set to True : lr updates happen at every batch 267 | Set this to True for OneCycleLR for example 268 | """ 269 | 270 | scheduler_fn: Any 271 | optimizer: Any 272 | scheduler_params: dict 273 | early_stopping_metric: str 274 | is_batch_level: bool = False 275 | 276 | def __post_init__( 277 | self, 278 | ): 279 | self.is_metric_related = hasattr(self.scheduler_fn, "is_better") 280 | self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params) 281 | super().__init__() 282 | 283 | def on_batch_end(self, batch, logs=None): 284 | if self.is_batch_level: 285 | self.scheduler.step() 286 | else: 287 | pass 288 | 289 | def on_epoch_end(self, epoch, logs=None): 290 | self.scheduler.step() 291 | -------------------------------------------------------------------------------- /lib/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from datetime import datetime 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | class Train_Log(): 7 | def __init__(self, logname, resume_dir=None): 8 | time_str = datetime.now().strftime("%m-%d_%H%M") 9 | if resume_dir: 10 | self.resume_dir = os.path.join('./logs', resume_dir) 11 | self.log_dir = self.resume_dir 12 | 13 | else: 14 | self.log_dir = os.path.join('./logs/', logname + '_' +time_str) 15 | 16 | self.writer = SummaryWriter(self.log_dir) 17 | 18 | if not os.path.exists(self.log_dir): 19 | os.makedirs(self.log_dir) 20 | 21 | def load_checkpoint(self, optimizer): 22 | lastest_out_path = "{}/checkpoint.pth".format(self.resume_dir) 23 | ckpt = torch.load(lastest_out_path) 24 | model = ckpt['model'] 25 | start_epoch = ckpt['epoch'] + 1 26 | # model.load_state_dict(ckpt['state_dict']) 27 | optimizer.load_state_dict(ckpt['optimizer']) 28 | best_value = ckpt['best_value'] 29 | best_epoch = ckpt['best_epoch'] 30 | 31 | print("=> loaded checkpoint '{}' (epoch {})".format(lastest_out_path, ckpt['epoch'])) 32 | 33 | return start_epoch, model, optimizer, best_value, best_epoch 34 | 35 | def save_best_model(self, model): 36 | lastest_out_path = self.log_dir + '/' + 'best' + '.pth' 37 | torch.save(model, lastest_out_path) 38 | print('Save Best model!!') 39 | 40 | def save_log(self, log): 41 | mode = 'a' if os.path.exists(self.log_dir + '/log.txt') else 'w' 42 | logFile = open(self.log_dir + '/log.txt', mode) 43 | logFile.write(log + '\n') 44 | logFile.close() 45 | 46 | 47 | def save_tensorboard(self, info, epoch): 48 | for tag, value in info.items(): 49 | self.writer.add_scalar(tag, value, global_step=epoch) 50 | -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import numpy as np 4 | from sklearn.metrics import ( 5 | roc_auc_score, 6 | mean_squared_error, 7 | mean_absolute_error, 8 | accuracy_score, 9 | log_loss, 10 | balanced_accuracy_score, 11 | mean_squared_log_error, 12 | ) 13 | 14 | @dataclass 15 | class MetricContainer: 16 | """Container holding a list of metrics. 17 | 18 | Parameters 19 | ---------- 20 | metric_names : list of str 21 | List of metric names. 
22 | prefix : str 23 | Prefix of metric names. 24 | 25 | """ 26 | 27 | metric_names: List[str] 28 | prefix: str = "" 29 | 30 | def __post_init__(self): 31 | self.metrics = Metric.get_metrics_by_names(self.metric_names) 32 | self.names = [self.prefix + name for name in self.metric_names] 33 | 34 | def __call__(self, y_true, y_pred): 35 | """Compute all metrics and store into a dict. 36 | 37 | Parameters 38 | ---------- 39 | y_true : np.ndarray 40 | Target matrix or vector 41 | y_pred : np.ndarray 42 | Score matrix or vector 43 | 44 | Returns 45 | ------- 46 | dict 47 | Dict of metrics ({metric_name: metric_value}). 48 | 49 | """ 50 | logs = {} 51 | for metric in self.metrics: 52 | if isinstance(y_pred, list): 53 | res = np.mean( 54 | [metric(y_true[:, i], y_pred[i]) for i in range(len(y_pred))] 55 | ) 56 | else: 57 | res = metric(y_true, y_pred) 58 | logs[self.prefix + metric._name] = res 59 | return logs 60 | 61 | 62 | class Metric: 63 | def __call__(self, y_true, y_pred): 64 | raise NotImplementedError("Custom Metrics must implement this function") 65 | 66 | @classmethod 67 | def get_metrics_by_names(cls, names): 68 | """Get list of metric classes. 69 | 70 | Parameters 71 | ---------- 72 | cls : Metric 73 | Metric class. 74 | names : list 75 | List of metric names. 76 | 77 | Returns 78 | ------- 79 | metrics : list 80 | List of metric classes. 81 | 82 | """ 83 | available_metrics = cls.__subclasses__() 84 | available_names = [metric()._name for metric in available_metrics] 85 | metrics = [] 86 | for name in names: 87 | assert ( 88 | name in available_names 89 | ), f"{name} is not available, choose in {available_names}" 90 | idx = available_names.index(name) 91 | metric = available_metrics[idx]() 92 | metrics.append(metric) 93 | return metrics 94 | 95 | 96 | class AUC(Metric): 97 | """ 98 | AUC. 99 | """ 100 | 101 | def __init__(self): 102 | self._name = "auc" 103 | self._maximize = True 104 | 105 | def __call__(self, y_true, y_score): 106 | """ 107 | Compute AUC of predictions. 108 | 109 | Parameters 110 | ---------- 111 | y_true : np.ndarray 112 | Target matrix or vector 113 | y_score : np.ndarray 114 | Score matrix or vector 115 | 116 | Returns 117 | ------- 118 | float 119 | AUC of predictions vs targets. 120 | """ 121 | return roc_auc_score(y_true, y_score[:, 1]) 122 | 123 | 124 | class Accuracy(Metric): 125 | """ 126 | Accuracy. 127 | """ 128 | 129 | def __init__(self): 130 | self._name = "accuracy" 131 | self._maximize = True 132 | 133 | def __call__(self, y_true, y_score): 134 | """ 135 | Compute Accuracy of predictions. 136 | 137 | Parameters 138 | ---------- 139 | y_true: np.ndarray 140 | Target matrix or vector 141 | y_score: np.ndarray 142 | Score matrix or vector 143 | 144 | Returns 145 | ------- 146 | float 147 | Accuracy of predictions vs targets. 148 | """ 149 | y_pred = np.argmax(y_score, axis=1) 150 | return accuracy_score(y_true, y_pred) 151 | 152 | 153 | class BalancedAccuracy(Metric): 154 | """ 155 | Balanced Accuracy. 156 | """ 157 | 158 | def __init__(self): 159 | self._name = "balanced_accuracy" 160 | self._maximize = True 161 | 162 | def __call__(self, y_true, y_score): 163 | """ 164 | Compute Accuracy of predictions. 165 | 166 | Parameters 167 | ---------- 168 | y_true : np.ndarray 169 | Target matrix or vector 170 | y_score : np.ndarray 171 | Score matrix or vector 172 | 173 | Returns 174 | ------- 175 | float 176 | Accuracy of predictions vs targets. 
177 | """ 178 | y_pred = np.argmax(y_score, axis=1) 179 | return balanced_accuracy_score(y_true, y_pred) 180 | 181 | 182 | class LogLoss(Metric): 183 | """ 184 | LogLoss. 185 | """ 186 | 187 | def __init__(self): 188 | self._name = "logloss" 189 | self._maximize = False 190 | 191 | def __call__(self, y_true, y_score): 192 | """ 193 | Compute LogLoss of predictions. 194 | 195 | Parameters 196 | ---------- 197 | y_true : np.ndarray 198 | Target matrix or vector 199 | y_score : np.ndarray 200 | Score matrix or vector 201 | 202 | Returns 203 | ------- 204 | float 205 | LogLoss of predictions vs targets. 206 | """ 207 | return log_loss(y_true, y_score) 208 | 209 | 210 | class MAE(Metric): 211 | """ 212 | Mean Absolute Error. 213 | """ 214 | 215 | def __init__(self): 216 | self._name = "mae" 217 | self._maximize = False 218 | 219 | def __call__(self, y_true, y_score): 220 | """ 221 | Compute MAE (Mean Absolute Error) of predictions. 222 | 223 | Parameters 224 | ---------- 225 | y_true : np.ndarray 226 | Target matrix or vector 227 | y_score : np.ndarray 228 | Score matrix or vector 229 | 230 | Returns 231 | ------- 232 | float 233 | MAE of predictions vs targets. 234 | """ 235 | return mean_absolute_error(y_true, y_score) 236 | 237 | 238 | class MSE(Metric): 239 | """ 240 | Mean Squared Error. 241 | """ 242 | 243 | def __init__(self): 244 | self._name = "mse" 245 | self._maximize = False 246 | 247 | def __call__(self, y_true, y_score): 248 | """ 249 | Compute MSE (Mean Squared Error) of predictions. 250 | 251 | Parameters 252 | ---------- 253 | y_true : np.ndarray 254 | Target matrix or vector 255 | y_score : np.ndarray 256 | Score matrix or vector 257 | 258 | Returns 259 | ------- 260 | float 261 | MSE of predictions vs targets. 262 | """ 263 | return mean_squared_error(y_true, y_score) 264 | 265 | 266 | class RMSLE(Metric): 267 | """ 268 | Mean squared logarithmic error regression loss. 269 | Scikit-implementation: 270 | https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_log_error.html 271 | Note: In order to avoid error, negative predictions are clipped to 0. 272 | This means that you should clip negative predictions manually after calling predict. 273 | """ 274 | 275 | def __init__(self): 276 | self._name = "rmsle" 277 | self._maximize = False 278 | 279 | def __call__(self, y_true, y_score): 280 | """ 281 | Compute RMSLE of predictions. 282 | 283 | Parameters 284 | ---------- 285 | y_true : np.ndarray 286 | Target matrix or vector 287 | y_score : np.ndarray 288 | Score matrix or vector 289 | 290 | Returns 291 | ------- 292 | float 293 | RMSLE of predictions vs targets. 294 | """ 295 | y_score = np.clip(y_score, a_min=0, a_max=None) 296 | return mean_squared_log_error(y_true, y_score) 297 | 298 | class RMSE(Metric): 299 | """ 300 | Root Mean Squared Error. 301 | """ 302 | 303 | def __init__(self): 304 | self._name = "rmse" 305 | self._maximize = False 306 | 307 | def __call__(self, y_true, y_score): 308 | """ 309 | Compute RMSE (Root Mean Squared Error) of predictions. 310 | 311 | Parameters 312 | ---------- 313 | y_true : np.ndarray 314 | Target matrix or vector 315 | y_score : np.ndarray 316 | Score matrix or vector 317 | 318 | Returns 319 | ------- 320 | float 321 | RMSE of predictions vs targets. 322 | """ 323 | return np.sqrt(mean_squared_error(y_true, y_score)) 324 | 325 | 326 | def check_metrics(metrics): 327 | """Check if custom metrics are provided. 
328 | 329 | Parameters 330 | ---------- 331 | metrics : list of str or classes 332 | List with built-in metrics (str) or custom metrics (classes). 333 | 334 | Returns 335 | ------- 336 | val_metrics : list of str 337 | List of metric names. 338 | 339 | """ 340 | val_metrics = [] 341 | for metric in metrics: 342 | if isinstance(metric, str): 343 | val_metrics.append(metric) 344 | elif issubclass(metric, Metric): 345 | val_metrics.append(metric()._name) 346 | else: 347 | raise TypeError("You need to provide a valid metric format") 348 | return val_metrics 349 | -------------------------------------------------------------------------------- /lib/multiclass_utils.py: -------------------------------------------------------------------------------- 1 | # Author: Arnaud Joly, Joel Nothman, Hamzeh Alsalhi 2 | # 3 | # License: BSD 3 clause 4 | """ 5 | Multi-class / multi-label utility function 6 | ========================================== 7 | 8 | """ 9 | from collections.abc import Sequence 10 | from itertools import chain 11 | 12 | from scipy.sparse import issparse 13 | from scipy.sparse.base import spmatrix 14 | from scipy.sparse import dok_matrix 15 | from scipy.sparse import lil_matrix 16 | 17 | import numpy as np 18 | import pandas as pd 19 | 20 | 21 | def _assert_all_finite(X, allow_nan=False): 22 | """Like assert_all_finite, but only for ndarray.""" 23 | 24 | X = np.asanyarray(X) 25 | # First try an O(n) time, O(1) space solution for the common case that 26 | # everything is finite; fall back to O(n) space np.isfinite to prevent 27 | # false positives from overflow in sum method. The sum is also calculated 28 | # safely to reduce dtype induced overflows. 29 | is_float = X.dtype.kind in "fc" 30 | if is_float and (np.isfinite(np.sum(X))): 31 | pass 32 | elif is_float: 33 | msg_err = "Input contains {} or a value too large for {!r}." 34 | if ( 35 | allow_nan 36 | and np.isinf(X).any() 37 | or not allow_nan 38 | and not np.isfinite(X).all() 39 | ): 40 | type_err = "infinity" if allow_nan else "NaN, infinity" 41 | raise ValueError(msg_err.format(type_err, X.dtype)) 42 | # for object dtype data, we only check for NaNs (GH-13254) 43 | elif X.dtype == np.dtype("object") and not allow_nan: 44 | if np.isnan(X).any(): 45 | raise ValueError("Input contains NaN") 46 | 47 | 48 | def _unique_multiclass(y): 49 | if hasattr(y, "__array__"): 50 | return np.unique(np.asarray(y)) 51 | else: 52 | return set(y) 53 | 54 | 55 | 56 | _FN_UNIQUE_LABELS = { 57 | "binary": _unique_multiclass, 58 | "multiclass": _unique_multiclass, 59 | } 60 | 61 | 62 | def unique_labels(*ys): 63 | """Extract an ordered array of unique labels 64 | 65 | We don't allow: 66 | - mix of multilabel and multiclass (single label) targets 67 | - mix of label indicator matrix and anything else, 68 | because there are no explicit labels) 69 | - mix of label indicator matrices of different sizes 70 | - mix of string and integer labels 71 | 72 | At the moment, we also don't allow "multiclass-multioutput" input type. 73 | 74 | Parameters 75 | ---------- 76 | *ys : array-likes 77 | 78 | Returns 79 | ------- 80 | out : numpy array of shape [n_unique_labels] 81 | An ordered array of unique labels. 
82 | 83 | Examples 84 | -------- 85 | >>> from sklearn.utils.multiclass import unique_labels 86 | >>> unique_labels([3, 5, 5, 5, 7, 7]) 87 | array([3, 5, 7]) 88 | >>> unique_labels([1, 2, 3, 4], [2, 2, 3, 4]) 89 | array([1, 2, 3, 4]) 90 | >>> unique_labels([1, 2, 10], [5, 11]) 91 | array([ 1, 2, 5, 10, 11]) 92 | """ 93 | if not ys: 94 | raise ValueError("No argument has been passed.") 95 | # Check that we don't mix label format 96 | 97 | ys_types = set(type_of_target(x) for x in ys) 98 | if ys_types == {"binary", "multiclass"}: 99 | ys_types = {"multiclass"} 100 | 101 | if len(ys_types) > 1: 102 | raise ValueError("Mix type of y not allowed, got types %s" % ys_types) 103 | 104 | label_type = ys_types.pop() 105 | 106 | # Get the unique set of labels 107 | _unique_labels = _FN_UNIQUE_LABELS.get(label_type, None) 108 | if not _unique_labels: 109 | raise ValueError("Unknown label type: %s" % repr(ys)) 110 | 111 | ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys)) 112 | 113 | # Check that we don't mix string type with number type 114 | if len(set(isinstance(label, str) for label in ys_labels)) > 1: 115 | raise ValueError("Mix of label input types (string and number)") 116 | 117 | return np.array(sorted(ys_labels)) 118 | 119 | 120 | def _is_integral_float(y): 121 | return y.dtype.kind == "f" and np.all(y.astype(int) == y) 122 | 123 | 124 | def is_multilabel(y): 125 | """Check if ``y`` is in a multilabel format. 126 | 127 | Parameters 128 | ---------- 129 | y : numpy array of shape [n_samples] 130 | Target values. 131 | 132 | Returns 133 | ------- 134 | out : bool 135 | Return ``True``, if ``y`` is in a multilabel format, else ```False``. 136 | 137 | Examples 138 | -------- 139 | >>> import numpy as np 140 | >>> from sklearn.utils.multiclass import is_multilabel 141 | >>> is_multilabel([0, 1, 0, 1]) 142 | False 143 | >>> is_multilabel([[1], [0, 2], []]) 144 | False 145 | >>> is_multilabel(np.array([[1, 0], [0, 0]])) 146 | True 147 | >>> is_multilabel(np.array([[1], [0], [0]])) 148 | False 149 | >>> is_multilabel(np.array([[1, 0, 0]])) 150 | True 151 | """ 152 | if hasattr(y, "__array__"): 153 | y = np.asarray(y) 154 | if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1): 155 | return False 156 | 157 | if issparse(y): 158 | if isinstance(y, (dok_matrix, lil_matrix)): 159 | y = y.tocsr() 160 | return ( 161 | len(y.data) == 0 162 | or np.unique(y.data).size == 1 163 | and ( 164 | y.dtype.kind in "biu" 165 | or _is_integral_float(np.unique(y.data)) # bool, int, uint 166 | ) 167 | ) 168 | else: 169 | labels = np.unique(y) 170 | 171 | return len(labels) < 3 and ( 172 | y.dtype.kind in "biu" or _is_integral_float(labels) # bool, int, uint 173 | ) 174 | 175 | 176 | def check_classification_targets(y): 177 | """Ensure that target y is of a non-regression type. 178 | 179 | Only the following target types (as defined in type_of_target) are allowed: 180 | 'binary', 'multiclass', 'multiclass-multioutput' 181 | 182 | Parameters 183 | ---------- 184 | y : array-like 185 | """ 186 | y_type = type_of_target(y) 187 | if y_type not in [ 188 | "binary", 189 | "multiclass", 190 | "multiclass-multioutput", 191 | ]: 192 | raise ValueError("Unknown label type: %r" % y_type) 193 | 194 | 195 | def type_of_target(y): 196 | """Determine the type of data indicated by the target. 197 | 198 | Note that this type is the most specific type that can be inferred. 199 | For example: 200 | 201 | * ``binary`` is more specific but compatible with ``multiclass``. 
202 | * ``multiclass`` of integers is more specific but compatible with 203 | ``continuous``. 204 | 205 | Parameters 206 | ---------- 207 | y : array-like 208 | 209 | Returns 210 | ------- 211 | target_type : string 212 | One of: 213 | 214 | * 'continuous': `y` is an array-like of floats that are not all 215 | integers, and is 1d or a column vector. 216 | * 'continuous-multioutput': `y` is a 2d array of floats that are 217 | not all integers, and both dimensions are of size > 1. 218 | * 'binary': `y` contains <= 2 discrete values and is 1d or a column 219 | vector. 220 | * 'multiclass': `y` contains more than two discrete values, is not a 221 | sequence of sequences, and is 1d or a column vector. 222 | * 'multiclass-multioutput': `y` is a 2d array that contains more 223 | than two discrete values, is not a sequence of sequences, and both 224 | dimensions are of size > 1. 225 | * 'unknown': `y` is array-like but none of the above, such as a 3d 226 | array, sequence of sequences, or an array of non-sequence objects. 227 | 228 | Examples 229 | -------- 230 | >>> import numpy as np 231 | >>> type_of_target([0.1, 0.6]) 232 | 'continuous' 233 | >>> type_of_target([1, -1, -1, 1]) 234 | 'binary' 235 | >>> type_of_target(['a', 'b', 'a']) 236 | 'binary' 237 | >>> type_of_target([1.0, 2.0]) 238 | 'binary' 239 | >>> type_of_target([1, 0, 2]) 240 | 'multiclass' 241 | >>> type_of_target([1.0, 0.0, 3.0]) 242 | 'multiclass' 243 | >>> type_of_target(['a', 'b', 'c']) 244 | 'multiclass' 245 | >>> type_of_target(np.array([[1, 2], [3, 1]])) 246 | 'multiclass-multioutput' 247 | >>> type_of_target([[1, 2]]) 248 | 'multiclass-multioutput' 249 | >>> type_of_target(np.array([[1.5, 2.0], [3.0, 1.6]])) 250 | 'continuous-multioutput' 251 | """ 252 | valid = ( 253 | isinstance(y, (Sequence, spmatrix)) or hasattr(y, "__array__") 254 | ) and not isinstance(y, str) 255 | 256 | if not valid: 257 | raise ValueError( 258 | "Expected array-like (array or non-string sequence), " "got %r" % y 259 | ) 260 | 261 | sparseseries = y.__class__.__name__ == "SparseSeries" 262 | if sparseseries: 263 | raise ValueError("y cannot be class 'SparseSeries'.") 264 | 265 | 266 | try: 267 | y = np.asarray(y) 268 | except ValueError: 269 | # Known to fail in numpy 1.3 for array of arrays 270 | return "unknown" 271 | 272 | # The old sequence of sequences format 273 | try: 274 | if ( 275 | not hasattr(y[0], "__array__") 276 | and isinstance(y[0], Sequence) 277 | and not isinstance(y[0], str) 278 | ): 279 | raise ValueError( 280 | "You appear to be using a legacy multi-label data" 281 | " representation. Sequence of sequences are no" 282 | " longer supported; use a binary array or sparse" 283 | " matrix instead - the MultiLabelBinarizer" 284 | " transformer can convert to this format." 285 | ) 286 | except IndexError: 287 | pass 288 | 289 | # Invalid inputs 290 | if y.ndim > 2 or (y.dtype == object and len(y) and not isinstance(y.flat[0], str)): 291 | return "unknown" # [[[1, 2]]] or [obj_1] and not ["label_1"] 292 | 293 | if y.ndim == 2 and y.shape[1] == 0: 294 | return "unknown" # [[]] 295 | 296 | if y.ndim == 2 and y.shape[1] > 1: 297 | suffix = "-multioutput" # [[1, 2], [1, 2]] 298 | else: 299 | suffix = "" # [1, 2, 3] or [[1], [2], [3]] 300 | 301 | # check float and contains non-integer float values 302 | if y.dtype.kind == "f" and np.any(y != y.astype(int)): 303 | # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.] 
304 | _assert_all_finite(y) 305 | return "continuous" + suffix 306 | 307 | if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1): 308 | return "multiclass" + suffix # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]] 309 | else: 310 | return "binary" # [1, 2] or [["a"], ["b"]] 311 | 312 | 313 | def check_unique_type(y): 314 | target_types = pd.Series(y).map(type).unique() 315 | if len(target_types) != 1: 316 | raise TypeError( 317 | f"Values on the target must have the same type. Target has types {target_types}" 318 | ) 319 | 320 | 321 | def infer_output_dim(y_train): 322 | """ 323 | Infer output_dim from targets 324 | 325 | Parameters 326 | ---------- 327 | y_train : np.array 328 | Training targets 329 | 330 | Returns 331 | ------- 332 | output_dim : int 333 | Number of classes for output 334 | train_labels : list 335 | Sorted list of initial classes 336 | """ 337 | check_unique_type(y_train) 338 | train_labels = unique_labels(y_train) 339 | output_dim = len(train_labels) 340 | 341 | return output_dim, train_labels 342 | 343 | 344 | def check_output_dim(labels, y): 345 | if y is not None: 346 | check_unique_type(y) 347 | valid_labels = unique_labels(y) 348 | if not set(valid_labels).issubset(set(labels)): 349 | raise ValueError( 350 | f"""Valid set -- {set(valid_labels)} -- 351 | contains unkown targets from training -- 352 | {set(labels)}""" 353 | ) 354 | return -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import numpy as np 4 | from sklearn.utils import check_array 5 | 6 | class FastTensorDataLoader: 7 | """ 8 | A DataLoader-like object for a set of tensors that can be much faster than 9 | TensorDataset + DataLoader because dataloader grabs individual indices of 10 | the dataset and calls cat (slow). 11 | Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 12 | """ 13 | def __init__(self, *tensors, batch_size=32, shuffle=False): 14 | """ 15 | Initialize a FastTensorDataLoader. 16 | :param *tensors: tensors to store. Must have the same length @ dim 0. 17 | :param batch_size: batch size to load. 18 | :param shuffle: if True, shuffle the data *in-place* whenever an 19 | iterator is created out of this object. 20 | :returns: A FastTensorDataLoader. 
21 | """ 22 | assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) 23 | self.tensors = tensors 24 | 25 | self.dataset_len = self.tensors[0].shape[0] 26 | self.batch_size = batch_size 27 | self.shuffle = shuffle 28 | 29 | # Calculate # batches 30 | n_batches, remainder = divmod(self.dataset_len, self.batch_size) 31 | if remainder > 0: 32 | n_batches += 1 33 | self.n_batches = n_batches 34 | def __iter__(self): 35 | if self.shuffle: 36 | r = torch.randperm(self.dataset_len) 37 | self.tensors = [t[r] for t in self.tensors] 38 | self.i = 0 39 | return self 40 | 41 | def __next__(self): 42 | if self.i >= self.dataset_len: 43 | raise StopIteration 44 | batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors) 45 | self.i += self.batch_size 46 | return batch 47 | 48 | def __len__(self): 49 | return self.n_batches 50 | 51 | class PredictDataset(Dataset): 52 | """ 53 | Format for numpy array 54 | 55 | Parameters 56 | ---------- 57 | X : 2D array 58 | The input matrix 59 | """ 60 | 61 | def __init__(self, x): 62 | self.x = x 63 | 64 | def __len__(self): 65 | return len(self.x) 66 | 67 | def __getitem__(self, index): 68 | x = self.x[index] 69 | return x 70 | 71 | def create_dataloaders(X_train, y_train, eval_set, batch_size): 72 | """ 73 | Create dataloaders with or without subsampling depending on weights and balanced. 74 | 75 | Parameters 76 | ---------- 77 | X_train : np.ndarray 78 | Training data 79 | y_train : np.array 80 | Mapped Training targets 81 | eval_set : list of tuple 82 | List of eval tuple set (X, y) 83 | batch_size : int 84 | how many samples per batch to load 85 | Returns 86 | ------- 87 | train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader 88 | Training and validation dataloaders 89 | """ 90 | X_train = torch.from_numpy(X_train).float() 91 | y_train = torch.from_numpy(y_train) 92 | train_dataloader = FastTensorDataLoader(X_train, y_train, batch_size=batch_size, shuffle=True) 93 | 94 | valid_dataloaders = [] 95 | for X, y in eval_set: 96 | X = torch.from_numpy(X).float() 97 | y = torch.from_numpy(y) 98 | valid_dataloaders.append(FastTensorDataLoader(X, y, batch_size=batch_size, shuffle=False)) 99 | 100 | return train_dataloader, valid_dataloaders 101 | 102 | def validate_eval_set(eval_set, eval_name, X_train, y_train): 103 | """Check if the shapes of eval_set are compatible with (X_train, y_train). 104 | 105 | Parameters 106 | ---------- 107 | eval_set : list of tuple 108 | List of eval tuple set (X, y). 109 | The last one is used for early stopping 110 | eval_name : list of str 111 | List of eval set names. 112 | X_train : np.ndarray 113 | Train owned products 114 | y_train : np.array 115 | Train targeted products 116 | 117 | Returns 118 | ------- 119 | eval_names : list of str 120 | Validated list of eval_names. 121 | eval_set : list of tuple 122 | Validated list of eval_set. 
123 | 124 | """ 125 | eval_name = eval_name or [f"val_{i}" for i in range(len(eval_set))] 126 | 127 | assert len(eval_set) == len( 128 | eval_name 129 | ), "eval_set and eval_name have not the same length" 130 | if len(eval_set) > 0: 131 | assert all( 132 | len(elem) == 2 for elem in eval_set 133 | ), "Each tuple of eval_set need to have two elements" 134 | for name, (X, y) in zip(eval_name, eval_set): 135 | check_array(X) 136 | msg = ( 137 | f"Dimension mismatch between X_{name} " 138 | + f"{X.shape} and X_train {X_train.shape}" 139 | ) 140 | assert len(X.shape) == len(X_train.shape), msg 141 | 142 | msg = ( 143 | f"Dimension mismatch between y_{name} " 144 | + f"{y.shape} and y_train {y_train.shape}" 145 | ) 146 | assert len(y.shape) == len(y_train.shape), msg 147 | 148 | msg = ( 149 | f"Number of columns is different between X_{name} " 150 | + f"({X.shape[1]}) and X_train ({X_train.shape[1]})" 151 | ) 152 | assert X.shape[1] == X_train.shape[1], msg 153 | 154 | if len(y_train.shape) == 2: 155 | msg = ( 156 | f"Number of columns is different between y_{name} " 157 | + f"({y.shape[1]}) and y_train ({y_train.shape[1]})" 158 | ) 159 | assert y.shape[1] == y_train.shape[1], msg 160 | msg = ( 161 | f"You need the same number of rows between X_{name} " 162 | + f"({X.shape[0]}) and y_{name} ({y.shape[0]})" 163 | ) 164 | assert X.shape[0] == y.shape[0], msg 165 | 166 | return eval_name, eval_set 167 | 168 | def define_device(device_name): 169 | """ 170 | Define the device to use during training and inference. 171 | If auto it will detect automatically whether to use cuda or cpu 172 | 173 | Parameters 174 | ---------- 175 | device_name : str 176 | Either "auto", "cpu" or "cuda" 177 | 178 | Returns 179 | ------- 180 | str 181 | Either "cpu" or "cuda" 182 | """ 183 | if device_name == "auto": 184 | if torch.cuda.is_available(): 185 | return "cuda" 186 | else: 187 | return "cpu" 188 | elif device_name == "cuda" and not torch.cuda.is_available(): 189 | return "cpu" 190 | else: 191 | return device_name 192 | 193 | def normalize_reg_label(label, mu, std): 194 | norm_label = ((label - mu) / std).astype(np.float32) 195 | norm_label = norm_label.reshape(-1, 1) 196 | return norm_label -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from DAN_Task import DANetClassifier, DANetRegressor 2 | import argparse 3 | import os 4 | import torch.distributed 5 | import torch.backends.cudnn 6 | from sklearn.metrics import accuracy_score, mean_squared_error 7 | from data.dataset import get_data 8 | from lib.utils import normalize_reg_label 9 | from qhoptim.pyt import QHAdam 10 | from config.default import cfg 11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser(description='PyTorch v1.4, DANet Task Training') 15 | parser.add_argument('-c', '--config', type=str, required=False, default='config/forest_cover_type.yaml', metavar="FILE", help='Path to config file') 16 | parser.add_argument('-g', '--gpu_id', type=str, default='1', help='GPU ID') 17 | 18 | args = parser.parse_args() 19 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 20 | torch.backends.cudnn.benchmark = True if len(args.gpu_id) < 2 else False 21 | if args.config: 22 | cfg.merge_from_file(args.config) 23 | cfg.freeze() 24 | task = cfg.task 25 | seed = cfg.seed 26 | train_config = {'dataset': cfg.dataset, 'resume_dir': cfg.resume_dir, 'logname': cfg.logname} 27 | 
fit_config = dict(cfg.fit) 28 | model_config = dict(cfg.model) 29 | print('Using config: ', cfg) 30 | 31 | return train_config, fit_config, model_config, task, seed, len(args.gpu_id) 32 | 33 | def set_task_model(task, std=None, seed=1): 34 | if task == 'classification': 35 | clf = DANetClassifier( 36 | optimizer_fn=QHAdam, 37 | optimizer_params=dict(lr=fit_config['lr'], weight_decay=1e-5, nus=(0.8, 1.0)), 38 | scheduler_params=dict(gamma=0.95, step_size=20), 39 | scheduler_fn=torch.optim.lr_scheduler.StepLR, 40 | layer=model_config['layer'], 41 | base_outdim=model_config['base_outdim'], 42 | k=model_config['k'], 43 | drop_rate=model_config['drop_rate'], 44 | seed=seed 45 | ) 46 | eval_metric = ['accuracy'] 47 | 48 | elif task == 'regression': 49 | clf = DANetRegressor( 50 | std=std, 51 | optimizer_fn=QHAdam, 52 | optimizer_params=dict(lr=fit_config['lr'], weight_decay=fit_config['weight_decay'], nus=(0.8, 1.0)), 53 | scheduler_params=dict(gamma=0.95, step_size=fit_config['schedule_step']), 54 | scheduler_fn=torch.optim.lr_scheduler.StepLR, 55 | layer=model_config['layer'], 56 | base_outdim=model_config['base_outdim'], 57 | k=model_config['k'], 58 | seed=seed 59 | ) 60 | eval_metric = ['mse'] 61 | return clf, eval_metric 62 | 63 | if __name__ == '__main__': 64 | 65 | print('===> Setting configuration ...') 66 | train_config, fit_config, model_config, task, seed, n_gpu = get_args() 67 | logname = None if train_config['logname'] == '' else train_config['dataset'] + '/' + train_config['logname'] 68 | print('===> Getting data ...') 69 | X_train, y_train, X_valid, y_valid, X_test, y_test = get_data(train_config['dataset']) 70 | mu, std = None, None 71 | if task == 'regression': 72 | mu, std = y_train.mean(), y_train.std() 73 | print("mean = %.5f, std = %.5f" % (mu, std)) 74 | y_train = normalize_reg_label(y_train, std, mu) 75 | y_valid = normalize_reg_label(y_valid, std, mu) 76 | y_test = normalize_reg_label(y_test, std, mu) 77 | 78 | clf, eval_metric = set_task_model(task, std, seed) 79 | 80 | clf.fit( 81 | X_train=X_train, y_train=y_train, 82 | eval_set=[(X_valid, y_valid)], 83 | eval_name=['valid'], 84 | eval_metric=eval_metric, 85 | max_epochs=fit_config['max_epochs'], patience=fit_config['patience'], 86 | batch_size=fit_config['batch_size'], virtual_batch_size=fit_config['virtual_batch_size'], 87 | logname=logname, 88 | resume_dir=train_config['resume_dir'], 89 | n_gpu=n_gpu 90 | ) 91 | 92 | preds_test = clf.predict(X_test) 93 | 94 | if task == 'classification': 95 | test_acc = accuracy_score(y_pred=preds_test, y_true=y_test) 96 | print(f"FINAL TEST ACCURACY FOR {train_config['dataset']} : {test_acc}") 97 | 98 | elif task == 'regression': 99 | test_mse = mean_squared_error(y_pred=preds_test, y_true=y_test) 100 | print(f"FINAL TEST MSE FOR {train_config['dataset']} : {test_mse}") 101 | -------------------------------------------------------------------------------- /model/AcceleratedModule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class AcceleratedCreator(object): 6 | def __init__(self, input_dim, base_out_dim, k): 7 | super(AcceleratedCreator, self).__init__() 8 | self.input_dim = input_dim 9 | self.base_out_dim = base_out_dim 10 | self.computer = Extractor(k) 11 | 12 | def __call__(self, network): 13 | network.init_layer = self.extract_module(network.init_layer, self.input_dim, self.input_dim) 14 | for i in range(len(network.layer)): 15 | network.layer[i] = 
self.extract_module(network.layer[i], self.base_out_dim, self.input_dim) 16 | return network 17 | 18 | def extract_module(self, basicblock, base_input_dim, fix_input_dim): 19 | basicblock.conv1 = self.computer(basicblock.conv1, base_input_dim, self.base_out_dim // 2) 20 | basicblock.conv2 = self.computer(basicblock.conv2, self.base_out_dim // 2, self.base_out_dim) 21 | basicblock.downsample = self.computer(basicblock.downsample._modules['1'], fix_input_dim, self.base_out_dim) 22 | return basicblock 23 | 24 | 25 | class Extractor(object): 26 | def __init__(self, k): 27 | super(Extractor, self).__init__() 28 | self.k = k 29 | 30 | @staticmethod 31 | def get_parameter(abs_layer): 32 | bn = abs_layer.bn.bn 33 | alpha, beta, eps = bn.weight.data, bn.bias.data, bn.eps # [240] 34 | mu, var = bn.running_mean.data, bn.running_var.data 35 | locality = abs_layer.masker 36 | sparse_weight = locality.smax(locality.weight.data) # 6, 10 37 | 38 | feat_pro = abs_layer.fc 39 | process_weight = feat_pro.weight.data # ([240, 10, 1]) [240] 40 | process_bias = feat_pro.bias.data if feat_pro.bias is not None else None 41 | return alpha, beta, eps, mu, var, sparse_weight, process_weight, process_bias 42 | 43 | @staticmethod 44 | def compute_weights(a, b, eps, mu, var, sw, pw, pb, base_input_dim, base_output_dim, k): 45 | """ 46 | standard shape: [path, output_shape, input_shape, branch] 47 | """ 48 | sw_ = sw[:, None, :, None] 49 | pw_ = pw.view(k, 2, base_output_dim, base_input_dim).permute(0, 2, 3, 1) 50 | if pb is not None: 51 | pb_ = pb.view(k, 2, base_output_dim).permute(0, 2, 1)[:, :, None, :] 52 | a_ = a.view(k, 2, base_output_dim).permute(0, 2, 1)[:, :, None, :] 53 | b_ = b.view(k, 2, base_output_dim).permute(0, 2, 1)[:, :, None, :] 54 | mu_ = mu.view(k, 2, base_output_dim).permute(0, 2, 1)[:, :, None, :] 55 | var_ = var.view(k, 2, base_output_dim).permute(0, 2, 1)[:, :, None, :] 56 | 57 | W = sw_ * pw_ 58 | if pb is not None: 59 | mu_ = mu_ - pb_ 60 | W = a_ / (var_ + eps).sqrt() * W 61 | B = b_ - a_ / (var_ + eps).sqrt() * mu_ 62 | 63 | W_att = W[..., 0] 64 | B_att = B[..., 0] 65 | 66 | W_fc = W[..., 1] 67 | B_fc = B[..., 1] 68 | 69 | return W_att, W_fc, B_att.squeeze(), B_fc.squeeze() 70 | 71 | def __call__(self, abslayer, input_dim, base_out_dim): 72 | (a, b, e, m, v, s, pw, pb) = self.get_parameter(abslayer) 73 | wa, wf, ba, bf = self.compute_weights(a, b, e, m, v, s, pw, pb, input_dim, base_out_dim, self.k) 74 | return CompressAbstractLayer(wa, wf, ba, bf) 75 | 76 | 77 | class CompressAbstractLayer(nn.Module): 78 | def __init__(self, att_w, f_w, att_b, f_b): 79 | super(CompressAbstractLayer, self).__init__() 80 | self.att_w = nn.Parameter(att_w) 81 | self.f_w = nn.Parameter(f_w) 82 | self.att_bias = nn.Parameter(att_b[None, :, :]) 83 | self.f_bias = nn.Parameter(f_b[None, :, :]) 84 | 85 | def forward(self, x): 86 | att = torch.sigmoid(torch.einsum('poi,bi->bpo', self.att_w, x) + self.att_bias) # (2 * i + 2) * p * o 87 | y = torch.einsum('poi,bi->bpo', self.f_w, x) + self.f_bias # (2 * i + 1) * p * o 88 | return torch.sum(F.relu(att * y), dim=-2, keepdim=False) # 3 * p * o 89 | 90 | 91 | if __name__ == '__main__': 92 | import torch.optim as optim 93 | from DANet import AbstractLayer 94 | 95 | input_feat = torch.rand((8, 10), requires_grad=False) 96 | loss_function = nn.L1Loss() 97 | target = torch.rand((8, 20), requires_grad=False) 98 | abs_layer = AbstractLayer(base_input_dim=10, base_output_dim=20, k=6, virtual_batch_size=4, bias=False) 99 | y_ = abs_layer(input_feat) 100 | optimizer = 
optim.SGD(abs_layer.parameters(), lr=0.3) 101 | abs_layer.zero_grad() 102 | loss_function(y_, target).backward() 103 | optimizer.step() 104 | 105 | abs_layer = abs_layer.eval() 106 | y = abs_layer(input_feat) 107 | computer = Extractor(k=6) 108 | (a, b, e, m, v, s, pw, pb) = computer.get_parameter(abs_layer) 109 | wa, wf, ba, bf = computer.compute_weights(a, b, e, m, v, s, pw, pb, 10, 20, 6) 110 | acc_abs = CompressAbstractLayer(wa, wf, ba, bf) 111 | y2 = acc_abs(input_feat) 112 | -------------------------------------------------------------------------------- /model/DANet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import model.sparsemax as sparsemax 6 | 7 | def initialize_glu(module, input_dim, output_dim): 8 | gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(input_dim)) 9 | torch.nn.init.xavier_normal_(module.weight, gain=gain_value) 10 | return 11 | 12 | class GBN(torch.nn.Module): 13 | """ 14 | Ghost Batch Normalization 15 | https://arxiv.org/abs/1705.08741 16 | """ 17 | def __init__(self, input_dim, virtual_batch_size=512): 18 | super(GBN, self).__init__() 19 | self.input_dim = input_dim 20 | self.virtual_batch_size = virtual_batch_size 21 | self.bn = nn.BatchNorm1d(self.input_dim) 22 | 23 | def forward(self, x): 24 | if self.training == True: 25 | chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0) 26 | res = [self.bn(x_) for x_ in chunks] 27 | return torch.cat(res, dim=0) 28 | else: 29 | return self.bn(x) 30 | 31 | class LearnableLocality(nn.Module): 32 | 33 | def __init__(self, input_dim, k): 34 | super(LearnableLocality, self).__init__() 35 | self.register_parameter('weight', nn.Parameter(torch.rand(k, input_dim))) 36 | self.smax = sparsemax.Entmax15(dim=-1) 37 | 38 | def forward(self, x): 39 | mask = self.smax(self.weight) 40 | masked_x = torch.einsum('nd,bd->bnd', mask, x) # [B, k, D] 41 | return masked_x 42 | 43 | class AbstractLayer(nn.Module): 44 | def __init__(self, base_input_dim, base_output_dim, k, virtual_batch_size, bias=True): 45 | super(AbstractLayer, self).__init__() 46 | self.masker = LearnableLocality(input_dim=base_input_dim, k=k) 47 | self.fc = nn.Conv1d(base_input_dim * k, 2 * k * base_output_dim, kernel_size=1, groups=k, bias=bias) 48 | initialize_glu(self.fc, input_dim=base_input_dim * k, output_dim=2 * k * base_output_dim) 49 | self.bn = GBN(2 * base_output_dim * k, virtual_batch_size) 50 | self.k = k 51 | self.base_output_dim = base_output_dim 52 | 53 | def forward(self, x): 54 | b = x.size(0) 55 | x = self.masker(x) # [B, D] -> [B, k, D] 56 | x = self.fc(x.view(b, -1, 1)) # [B, k, D] -> [B, k * D, 1] -> [B, k * (2 * D'), 1] 57 | x = self.bn(x) 58 | chunks = x.chunk(self.k, 1) # k * [B, 2 * D', 1] 59 | x = sum([F.relu(torch.sigmoid(x_[:, :self.base_output_dim, :]) * x_[:, self.base_output_dim:, :]) for x_ in chunks]) # k * [B, D', 1] -> [B, D', 1] 60 | return x.squeeze(-1) 61 | 62 | 63 | class BasicBlock(nn.Module): 64 | def __init__(self, input_dim, base_outdim, k, virtual_batch_size, fix_input_dim, drop_rate): 65 | super(BasicBlock, self).__init__() 66 | self.conv1 = AbstractLayer(input_dim, base_outdim // 2, k, virtual_batch_size) 67 | self.conv2 = AbstractLayer(base_outdim // 2, base_outdim, k, virtual_batch_size) 68 | 69 | self.downsample = nn.Sequential( 70 | nn.Dropout(drop_rate), 71 | AbstractLayer(fix_input_dim, base_outdim, k, virtual_batch_size) 72 | ) 73 | 74 | def 
forward(self, x, pre_out=None): 75 | if pre_out == None: 76 | pre_out = x 77 | out = self.conv1(pre_out) 78 | out = self.conv2(out) 79 | identity = self.downsample(x) 80 | out += identity 81 | return F.leaky_relu(out, 0.01) 82 | 83 | 84 | class DANet(nn.Module): 85 | def __init__(self, input_dim, num_classes, layer_num, base_outdim, k, virtual_batch_size, drop_rate=0.1): 86 | super(DANet, self).__init__() 87 | params = {'base_outdim': base_outdim, 'k': k, 'virtual_batch_size': virtual_batch_size, 88 | 'fix_input_dim': input_dim, 'drop_rate': drop_rate} 89 | self.init_layer = BasicBlock(input_dim, **params) 90 | self.lay_num = layer_num 91 | self.layer = nn.ModuleList() 92 | for i in range((layer_num // 2) - 1): 93 | self.layer.append(BasicBlock(base_outdim, **params)) 94 | self.drop = nn.Dropout(0.1) 95 | 96 | self.fc = nn.Sequential(nn.Linear(base_outdim, 256), 97 | nn.ReLU(inplace=True), 98 | nn.Linear(256, 512), 99 | nn.ReLU(inplace=True), 100 | nn.Linear(512, num_classes)) 101 | 102 | def forward(self, x): 103 | out = self.init_layer(x) 104 | for i in range(len(self.layer)): 105 | out = self.layer[i](x, out) 106 | out = self.drop(out) 107 | out = self.fc(out) 108 | return out 109 | -------------------------------------------------------------------------------- /model/sparsemax.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch.nn.functional as F 4 | 5 | import torch 6 | 7 | """ 8 | Other possible implementations: 9 | https://github.com/KrisKorrel/sparsemax-pytorch/blob/master/sparsemax.py 10 | https://github.com/msobroza/SparsemaxPytorch/blob/master/mnist/sparsemax.py 11 | https://github.com/vene/sparse-structured-attention/blob/master/pytorch/torchsparseattn/sparsemax.py 12 | """ 13 | 14 | 15 | # credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py 16 | def _make_ix_like(input, dim=0): 17 | d = input.size(dim) 18 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 19 | view = [1] * input.dim() 20 | view[0] = -1 21 | return rho.view(view).transpose(0, dim) 22 | 23 | 24 | class SparsemaxFunction(Function): 25 | """ 26 | An implementation of sparsemax (Martins & Astudillo, 2016). See 27 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 
28 | By Ben Peters and Vlad Niculae 29 | """ 30 | 31 | @staticmethod 32 | def forward(ctx, input, dim=-1): 33 | """sparsemax: normalizing sparse transform (a la softmax) 34 | 35 | Parameters 36 | ---------- 37 | ctx : torch.autograd.function._ContextMethodMixin 38 | input : torch.Tensor 39 | any shape 40 | dim : int 41 | dimension along which to apply sparsemax 42 | 43 | Returns 44 | ------- 45 | output : torch.Tensor 46 | same shape as input 47 | 48 | """ 49 | ctx.dim = dim 50 | max_val, _ = input.max(dim=dim, keepdim=True) 51 | input -= max_val # same numerical stability trick as for softmax 52 | tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) 53 | output = torch.clamp(input - tau, min=0) 54 | ctx.save_for_backward(supp_size, output) 55 | return output 56 | 57 | @staticmethod 58 | def backward(ctx, grad_output): 59 | supp_size, output = ctx.saved_tensors 60 | dim = ctx.dim 61 | grad_input = grad_output.clone() 62 | grad_input[output == 0] = 0 63 | 64 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 65 | v_hat = v_hat.unsqueeze(dim) 66 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 67 | return grad_input, None 68 | 69 | @staticmethod 70 | def _threshold_and_support(input, dim=-1): 71 | """Sparsemax building block: compute the threshold 72 | 73 | Parameters 74 | ---------- 75 | input: torch.Tensor 76 | any dimension 77 | dim : int 78 | dimension along which to apply the sparsemax 79 | 80 | Returns 81 | ------- 82 | tau : torch.Tensor 83 | the threshold value 84 | support_size : torch.Tensor 85 | 86 | """ 87 | 88 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 89 | input_cumsum = input_srt.cumsum(dim) - 1 90 | rhos = _make_ix_like(input, dim) 91 | support = rhos * input_srt > input_cumsum 92 | 93 | support_size = support.sum(dim=dim).unsqueeze(dim) 94 | tau = input_cumsum.gather(dim, support_size - 1) 95 | tau /= support_size.to(input.dtype) 96 | return tau, support_size 97 | 98 | 99 | sparsemax = SparsemaxFunction.apply 100 | 101 | 102 | class Sparsemax(nn.Module): 103 | 104 | def __init__(self, dim=-1): 105 | self.dim = dim 106 | super(Sparsemax, self).__init__() 107 | 108 | def forward(self, input): 109 | return sparsemax(input, self.dim) 110 | 111 | 112 | class Entmax15Function(Function): 113 | """ 114 | An implementation of exact Entmax with alpha=1.5 (B. Peters, V. Niculae, A. Martins). See 115 | :cite:`https://arxiv.org/abs/1905.05702 for detailed description. 
116 | Source: https://github.com/deep-spin/entmax 117 | """ 118 | 119 | @staticmethod 120 | def forward(ctx, input, dim=-1): 121 | ctx.dim = dim 122 | 123 | max_val, _ = input.max(dim=dim, keepdim=True) 124 | input = input - max_val # same numerical stability trick as for softmax 125 | input = input / 2 # divide by 2 to solve actual Entmax 126 | 127 | tau_star, _ = Entmax15Function._threshold_and_support(input, dim) 128 | output = torch.clamp(input - tau_star, min=0) ** 2 129 | ctx.save_for_backward(output) 130 | return output 131 | 132 | @staticmethod 133 | def backward(ctx, grad_output): 134 | Y, = ctx.saved_tensors 135 | gppr = Y.sqrt() # = 1 / g'' (Y) 136 | dX = grad_output * gppr 137 | q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) 138 | q = q.unsqueeze(ctx.dim) 139 | dX -= q * gppr 140 | return dX, None 141 | 142 | @staticmethod 143 | def _threshold_and_support(input, dim=-1): 144 | Xsrt, _ = torch.sort(input, descending=True, dim=dim) 145 | 146 | rho = _make_ix_like(input, dim) 147 | mean = Xsrt.cumsum(dim) / rho 148 | mean_sq = (Xsrt ** 2).cumsum(dim) / rho 149 | ss = rho * (mean_sq - mean ** 2) 150 | delta = (1 - ss) / rho 151 | 152 | # NOTE this is not exactly the same as in reference algo 153 | # Fortunately it seems the clamped values never wrongly 154 | # get selected by tau <= sorted_z. Prove this! 155 | delta_nz = torch.clamp(delta, 0) 156 | tau = mean - torch.sqrt(delta_nz) 157 | 158 | support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim) 159 | tau_star = tau.gather(dim, support_size - 1) 160 | return tau_star, support_size 161 | 162 | 163 | class Entmoid15(Function): 164 | """ A highly optimized equivalent of lambda x: Entmax15([x, 0]) """ 165 | 166 | @staticmethod 167 | def forward(ctx, input): 168 | output = Entmoid15._forward(input) 169 | ctx.save_for_backward(output) 170 | return output 171 | 172 | @staticmethod 173 | def _forward(input): 174 | input, is_pos = abs(input), input >= 0 175 | tau = (input + torch.sqrt(F.relu(8 - input ** 2))) / 2 176 | tau.masked_fill_(tau <= input, 2.0) 177 | y_neg = 0.25 * F.relu(tau - input, inplace=True) ** 2 178 | return torch.where(is_pos, 1 - y_neg, y_neg) 179 | 180 | @staticmethod 181 | def backward(ctx, grad_output): 182 | return Entmoid15._backward(ctx.saved_tensors[0], grad_output) 183 | 184 | @staticmethod 185 | def _backward(output, grad_output): 186 | gppr0, gppr1 = output.sqrt(), (1 - output).sqrt() 187 | grad_input = grad_output * gppr0 188 | q = grad_input / (gppr0 + gppr1) 189 | grad_input -= q * gppr0 190 | return grad_input 191 | 192 | 193 | entmax15 = Entmax15Function.apply 194 | entmoid15 = Entmoid15.apply 195 | 196 | 197 | class Entmax15(nn.Module): 198 | 199 | def __init__(self, dim=-1): 200 | self.dim = dim 201 | super(Entmax15, self).__init__() 202 | 203 | def forward(self, input): 204 | return entmax15(input, self.dim) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from DAN_Task import DANetClassifier, DANetRegressor 2 | from sklearn.metrics import accuracy_score, mean_squared_error 3 | from lib.multiclass_utils import infer_output_dim 4 | from lib.utils import normalize_reg_label 5 | import numpy as np 6 | import argparse 7 | from data.dataset import get_data 8 | import os 9 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser(description='PyTorch v1.4, DANet Testing') 13 | parser.add_argument('-d', '--dataset', type=str, 
default='forest', help='Dataset Name for extracting data') 14 | parser.add_argument('-m', '--model_file', type=str, default='./weights/forest_layer32.pth', metavar="FILE", help='Inference model path') 15 | parser.add_argument('-g', '--gpu_id', type=str, default='1', help='GPU ID') 16 | args = parser.parse_args() 17 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 18 | dataset = args.dataset 19 | model_file = args.model_file 20 | task = 'regression' if dataset in ['year', 'yahoo', 'MSLR'] else 'classification' 21 | 22 | return dataset, model_file, task, len(args.gpu_id) 23 | 24 | def set_task_model(task): 25 | if task == 'classification': 26 | clf = DANetClassifier() 27 | metric = accuracy_score 28 | elif task == 'regression': 29 | clf = DANetRegressor() 30 | metric = mean_squared_error 31 | return clf, metric 32 | 33 | def prepare_data(task, y_train, y_valid, y_test): 34 | output_dim = 1 35 | mu, std = None, None 36 | if task == 'classification': 37 | output_dim, train_labels = infer_output_dim(y_train) 38 | target_mapper = {class_label: index for index, class_label in enumerate(train_labels)} 39 | y_train = np.vectorize(target_mapper.get)(y_train) 40 | y_valid = np.vectorize(target_mapper.get)(y_valid) 41 | y_test = np.vectorize(target_mapper.get)(y_test) 42 | 43 | elif task == 'regression': 44 | mu, std = y_train.mean(), y_train.std() 45 | print("mean = %.5f, std = %.5f" % (mu, std)) 46 | y_train = normalize_reg_label(y_train, mu, std) 47 | y_valid = normalize_reg_label(y_valid, mu, std) 48 | y_test = normalize_reg_label(y_test, mu, std) 49 | 50 | return output_dim, std, y_train, y_valid, y_test 51 | 52 | if __name__ == '__main__': 53 | dataset, model_file, task, n_gpu = get_args() 54 | print('===> Getting data ...') 55 | X_train, y_train, X_valid, y_valid, X_test, y_test = get_data(dataset) 56 | output_dim, std, y_train, y_valid, y_test = prepare_data(task, y_train, y_valid, y_test) 57 | clf, metric = set_task_model(task) 58 | 59 | filepath = model_file 60 | clf.load_model(filepath, input_dim=X_test.shape[1], output_dim=output_dim, n_gpu=n_gpu) 61 | 62 | preds_test = clf.predict(X_test) 63 | test_value = metric(y_pred=preds_test, y_true=y_test) 64 | 65 | if task == 'classification': 66 | print(f"FINAL TEST ACCURACY FOR {dataset} : {test_value}") 67 | 68 | elif task == 'regression': 69 | print(f"FINAL TEST MSE FOR {dataset} : {test_value}") 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | category_encoders 3 | yacs 4 | tensorboard>=2.2.2 5 | qhoptim --------------------------------------------------------------------------------
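
The short sketches below exercise pieces of the code above. They are illustrative only: they assume the repository root is the working directory and that torch, numpy, pandas and scikit-learn are installed, and none of the values are taken from the provided config files. First, the metric classes in lib/metrics.py are looked up by name through Metric.get_metrics_by_names, so a MetricContainer can be driven with plain strings:

# Sketch: computing named metrics from a score matrix with MetricContainer (lib/metrics.py).
import numpy as np
from lib.metrics import MetricContainer

y_true = np.array([0, 1, 1, 0, 1])
y_score = np.array([[0.9, 0.1],   # one row of per-class scores per sample
                    [0.2, 0.8],
                    [0.4, 0.6],
                    [0.7, 0.3],
                    [0.6, 0.4]])

container = MetricContainer(metric_names=['accuracy', 'auc', 'logloss'], prefix='valid_')
print(container(y_true, y_score))  # {'valid_accuracy': 0.8, 'valid_auc': ..., 'valid_logloss': ...}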
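
lib/multiclass_utils.py trims scikit-learn's target-type utilities down to what DANetClassifier needs; a tiny sketch, under the same assumptions, of how the output dimension is inferred from raw labels:

# Sketch: label handling helpers from lib/multiclass_utils.py.
import numpy as np
from lib.multiclass_utils import type_of_target, infer_output_dim, check_output_dim

y_train = np.array(['cat', 'dog', 'dog', 'bird'])
print(type_of_target(y_train))              # 'multiclass'

output_dim, train_labels = infer_output_dim(y_train)
print(output_dim, train_labels)             # 3 ['bird' 'cat' 'dog']

check_output_dim(train_labels, np.array(['dog', 'cat']))  # ok: validation labels are a subset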
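
lib/utils.py swaps TensorDataset + DataLoader for FastTensorDataLoader, which slices whole tensors per batch instead of collating individual indices; a minimal iteration sketch:

# Sketch: iterating FastTensorDataLoader (lib/utils.py) over random tensors.
import torch
from lib.utils import FastTensorDataLoader

X = torch.randn(100, 8)
y = torch.randint(0, 2, (100,))

loader = FastTensorDataLoader(X, y, batch_size=32, shuffle=True)
print(len(loader))                  # 4 batches: three full ones plus a remainder of 4
for xb, yb in loader:
    print(xb.shape, yb.shape)       # contiguous slices of the shuffled tensors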
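
As a quick shape check for model/DANet.py, the sketch below builds a small DANet and runs one forward pass on random features; every hyper-parameter value here is illustrative rather than taken from the configs:

# Sketch: forward-pass shape check for DANet (model/DANet.py).
import torch
from model.DANet import DANet

net = DANet(input_dim=10, num_classes=3, layer_num=8,
            base_outdim=64, k=5, virtual_batch_size=256, drop_rate=0.1)
net.eval()                          # use running BN statistics instead of ghost batches

x = torch.randn(16, 10)             # 16 samples, 10 tabular features
with torch.no_grad():
    logits = net(x)
print(logits.shape)                 # expected: torch.Size([16, 3])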
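
Finally, the sparse mappings in model/sparsemax.py are what LearnableLocality uses to build its feature masks; a small comparison of softmax, sparsemax and entmax-1.5 on the same logits shows the increasing sparsity while each row still sums to 1:

# Sketch: softmax vs. sparsemax vs. entmax-1.5 (model/sparsemax.py).
import torch
from model.sparsemax import Sparsemax, Entmax15

logits = torch.tensor([[2.0, 1.0, 0.1, -1.0, -2.0]])

soft = torch.softmax(logits, dim=-1)        # dense: every entry stays positive
sparse = Sparsemax(dim=-1)(logits.clone())  # clone: SparsemaxFunction subtracts the max in place
ent15 = Entmax15(dim=-1)(logits)            # sparsity between softmax and sparsemax

for name, p in [("softmax", soft), ("sparsemax", sparse), ("entmax15", ent15)]:
    print(name, p, "sum =", p.sum().item())  # each row still sums to 1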