├── kuma_utils ├── __init__.py ├── torch │ ├── optimizer │ │ ├── __init__.py │ │ └── sam.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── manual_scheduler.py │ │ └── cyclic_scheduler.py │ ├── __init__.py │ ├── hooks │ │ ├── __init__.py │ │ ├── base.py │ │ └── simple_hook.py │ ├── modules │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── calibration.py │ │ ├── groupnorm.py │ │ ├── pooling.py │ │ └── attention.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── base.py │ │ ├── stopping.py │ │ ├── snapshot.py │ │ └── logger.py │ ├── extras.py │ ├── distributed.py │ ├── ddp_worker.py │ ├── sampler.py │ ├── clip_grad.py │ ├── utils.py │ └── trainer.py ├── visualization │ ├── __init__.py │ └── eda.py ├── stats │ ├── __init__.py │ ├── propensity_score.py │ └── tables.py ├── metrics │ ├── __init__.py │ ├── regression.py │ ├── base.py │ └── classification.py ├── training │ ├── __init__.py │ ├── optuna │ │ ├── __init__.py │ │ └── configs.py │ ├── splitter.py │ ├── logger.py │ ├── utils.py │ └── validator.py ├── preprocessing │ ├── __init__.py │ ├── utils.py │ ├── base.py │ ├── misc.py │ ├── pipeline.py │ ├── feature_selection.py │ ├── transformer.py │ └── imputer.py └── utils.py ├── examples ├── images │ ├── eda_wandb.png │ ├── cifar_wandb.png │ └── cifar_tensorboard.png ├── Train_CNN_distributed.py └── Train_CNN_model.ipynb ├── .gitignore ├── LICENSE.txt ├── pyproject.toml └── README.md /kuma_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kuma_utils/torch/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .sam import * -------------------------------------------------------------------------------- /kuma_utils/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .eda import * 2 | -------------------------------------------------------------------------------- /kuma_utils/stats/__init__.py: -------------------------------------------------------------------------------- 1 | from .tables import * 2 | from .propensity_score import * 3 | -------------------------------------------------------------------------------- /kuma_utils/torch/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .manual_scheduler import * 2 | from .cyclic_scheduler import * -------------------------------------------------------------------------------- /kuma_utils/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | from .callbacks import * 3 | from .hooks import * 4 | -------------------------------------------------------------------------------- /examples/images/eda_wandb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/analokmaus/kuma_utils/HEAD/examples/images/eda_wandb.png -------------------------------------------------------------------------------- /kuma_utils/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .regression import * 3 | from .classification import * 4 | -------------------------------------------------------------------------------- /examples/images/cifar_wandb.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/analokmaus/kuma_utils/HEAD/examples/images/cifar_wandb.png -------------------------------------------------------------------------------- /kuma_utils/torch/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import HookTemplate 2 | from .simple_hook import TrainHook, SimpleHook 3 | -------------------------------------------------------------------------------- /examples/images/cifar_tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/analokmaus/kuma_utils/HEAD/examples/images/cifar_tensorboard.png -------------------------------------------------------------------------------- /kuma_utils/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | from .validator import * 3 | from .logger import * 4 | from .utils import * 5 | from .splitter import * -------------------------------------------------------------------------------- /kuma_utils/torch/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import * 2 | from .attention import * 3 | from .pooling import * 4 | from .groupnorm import * 5 | from .calibration import * 6 | -------------------------------------------------------------------------------- /kuma_utils/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .imputer import * 3 | from .pipeline import * 4 | from .feature_selection import * 5 | from .misc import * 6 | from .utils import * -------------------------------------------------------------------------------- /kuma_utils/torch/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import CallbackTemplate 2 | from .stopping import SaveEveryEpoch, EarlyStopping, CollectTopK 3 | from .logger import TorchLogger, DummyLogger 4 | from .snapshot import SaveSnapshot, SaveAllSnapshots, SaveAverageSnapshot -------------------------------------------------------------------------------- /kuma_utils/preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def analyze_column(input_series: pd.Series) -> str: 6 | if pd.api.types.is_numeric_dtype(input_series): 7 | return 'numerical' 8 | else: 9 | return 'categorical' 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | nohup.out 3 | .DS_Store 4 | .vscode/ 5 | 6 | # Python-related 7 | .ipynb_checkpoints/ 8 | __pycache__/ 9 | 10 | # Package specific 11 | catboost_info/ 12 | 13 | # Project 14 | *.lock 15 | input/ 16 | results/ 17 | datasets/ 18 | wandb/ 19 | ./*.ipynb 20 | ./.py -------------------------------------------------------------------------------- /kuma_utils/training/optuna/__init__.py: -------------------------------------------------------------------------------- 1 | from .configs import * 2 | 3 | 4 | PARAMS_ZOO = { 5 | 'CatBoostClassifier': catboost_cls_params, 6 | 'CatBoostRegressor': catboost_reg_params, 7 | 'XGBClassifier': xgboost_params, 8 | 'XGBRegressor': xgboost_params, 9 | 'SVC': svm_params, 10 | 'SVR': svm_params, 11 | 
'RandomForestClassifier': random_forest_params, 12 | 'RandomForestRegressor': random_forest_params, 13 | } 14 | -------------------------------------------------------------------------------- /kuma_utils/preprocessing/base.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | class PreprocessingTemplate: 5 | 6 | def __init__(self): 7 | pass 8 | 9 | def fit(self, X: pd.DataFrame, y: pd.Series): 10 | pass 11 | 12 | def transform(self, X: pd.DataFrame): 13 | pass 14 | 15 | def fit_transform(self, X: pd.DataFrame, y: pd.Series): 16 | pass 17 | 18 | def __repr__(self): 19 | return self.__class__.__name__ 20 | -------------------------------------------------------------------------------- /kuma_utils/torch/modules/activation.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code borrowed from: 3 | https://github.com/lessw2020/mish 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Mish(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, x): 15 | #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) 16 | return x * (torch.tanh(F.softplus(x))) 17 | -------------------------------------------------------------------------------- /kuma_utils/torch/extras.py: -------------------------------------------------------------------------------- 1 | class DummyGradScaler: 2 | ''' 3 | A dummy scaler with the same interface as amp.Gradscaler 4 | ''' 5 | def scale(self, loss): 6 | return loss 7 | 8 | def step(self, optimizer): 9 | optimizer.step() 10 | 11 | def update(self): 12 | pass 13 | 14 | 15 | class DummyAutoCast: 16 | ''' 17 | ''' 18 | def __enter__(self): 19 | return None 20 | 21 | def __exit__(self, *args): 22 | pass 23 | -------------------------------------------------------------------------------- /kuma_utils/preprocessing/misc.py: -------------------------------------------------------------------------------- 1 | from .base import PreprocessingTemplate 2 | 3 | 4 | class Cast(PreprocessingTemplate): 5 | ''' 6 | ''' 7 | def __init__(self, dtype): 8 | self.dtype = dtype 9 | 10 | def fit(self, X, y=None) -> None: 11 | raise RuntimeError("fit() is not supported.") 12 | 13 | def transform(self, X): 14 | return X.copy().astype(self.dtype) 15 | 16 | def fit_transform(self, X, y=None): 17 | return X.copy().astype(self.dtype) 18 | -------------------------------------------------------------------------------- /kuma_utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def is_env_notebook(): 5 | if 'get_ipython' not in globals(): 6 | return False 7 | env_name = get_ipython().__class__.__name__ 8 | if env_name == 'TerminalInteractiveShell': 9 | return False 10 | return True 11 | 12 | 13 | def vector_normalize(v, axis=-1, order=2): 14 | l2 = np.linalg.norm(v, ord=order, axis=axis, keepdims=True) 15 | l2[l2 == 0] = 1 16 | return v/l2 17 | 18 | 19 | def sigmoid(x): 20 | return 1 / (1 + np.exp(-np.clip(x, -709, 100000))) 21 | -------------------------------------------------------------------------------- /kuma_utils/torch/callbacks/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | class CallbackTemplate: 5 | ''' 6 | Callback is called before or after each epoch. 
7 | ''' 8 | 9 | def __init__(self): 10 | pass 11 | 12 | def before_epoch(self, env, loader=None, loader_valid=None): 13 | pass 14 | 15 | def after_epoch(self, env, loader=None, loader_valid=None): 16 | pass 17 | 18 | def save_snapshot(self, trainer, path): 19 | pass 20 | 21 | def load_snapshot(self, trainer, path, device): 22 | pass 23 | 24 | def state_dict(self): 25 | return {} 26 | 27 | def load_state_dict(self, checkpoint): 28 | pass 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ 32 | -------------------------------------------------------------------------------- /kuma_utils/torch/hooks/base.py: -------------------------------------------------------------------------------- 1 | 2 | class HookTemplate: 3 | ''' 4 | Hook is called on each mini-batch during training / inference 5 | and after all mini-batches have been processed, 6 | in order to define the training process and evaluate the results of each epoch. 7 | ''' 8 | 9 | def __init__(self): 10 | pass 11 | 12 | def forward_train(self, trainer, inputs): 13 | # return loss, approx 14 | pass 15 | 16 | forward_valid = forward_train 17 | 18 | def forward_test(self, trainer, inputs, approx): 19 | # return approx 20 | pass 21 | 22 | def backprop(self, trainer, loss, inputs): 23 | pass 24 | 25 | def evaluate_batch(self, trainer, inputs): 26 | # return None 27 | pass 28 | 29 | def evaluate_epoch(self, trainer): 30 | # return metric_total, monitor_metrics_total 31 | pass 32 | -------------------------------------------------------------------------------- /kuma_utils/torch/lr_scheduler/manual_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class ManualScheduler(_LRScheduler): 6 | ''' 7 | Example: 8 | config = { 9 | # epoch: learning rate 10 | 0: 1e-3, 11 | 10: 5e-4, 12 | 20: 1e-4 13 | } 14 | ''' 15 | def __init__(self, optimizer, config, verbose=False, **kwargs): 16 | self.config = config 17 | self.verbose = verbose 18 | super().__init__(optimizer, **kwargs) 19 | 20 | def get_lr(self): 21 | if self.last_epoch not in self.config.keys(): 22 | return [group['lr'] for group in self.optimizer.param_groups] 23 | else: 24 | new_lr = [ 25 | self.config[self.last_epoch] for group in self.optimizer.param_groups] 26 | if self.verbose: 27 | print(f'learning rate -> {new_lr}') 28 | return new_lr 29 | -------------------------------------------------------------------------------- /kuma_utils/metrics/regression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import mean_squared_error, r2_score 3 | from .base import MetricTemplate 4 | 5 | 6 | class RMSE(MetricTemplate): 7 | ''' 8 | Root mean square error 9 | ''' 10 | def __init__(self): 11 | super().__init__(maximize=False) 12 | 13 | def _test(self, target, approx): 14 | return mean_squared_error(target, approx, squared=False) 15 | 16 | 17 | class PearsonCorr(MetricTemplate): 18 | ''' 19 | Pearson product-moment correlation coefficient (higher is better) 20 | ''' 21 | def __init__(self): 22 | super().__init__(maximize=True) 23 | 24 | def _test(self, target, approx): 25 | return np.corrcoef(target, approx)[0, 1] 26 | 27 | 28 | class R2Score(MetricTemplate): 29 | ''' 30 | Coefficient of determination score (higher is better) 31 | ''' 32 | def __init__(self): 33 | super().__init__(maximize=True) 34 | 35 | def _test(self, target, approx): 36 | return r2_score(target, approx) 37 |
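A minimal usage sketch for the regression metrics above, following the `Metric()(target, approx)` convention defined by `MetricTemplate` in `kuma_utils/metrics/base.py`; the arrays are arbitrary illustration values, not library defaults.

```python
import numpy as np
from kuma_utils.metrics import RMSE, PearsonCorr, R2Score

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.8])

print(RMSE()(y_true, y_pred))         # root mean squared error (lower is better)
print(PearsonCorr()(y_true, y_pred))  # Pearson correlation (higher is better)
print(R2Score()(y_true, y_pred))      # coefficient of determination (higher is better)
```

The same metric objects can also be handed to gradient boosting libraries, e.g. `feval=RMSE().lgb` for LightGBM or `eval_metric=RMSE()` for CatBoost, as described in the `MetricTemplate` docstring.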
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2020 Hiroshi Yoshihara 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /kuma_utils/torch/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | import socket 5 | 6 | 7 | def get_host_ip(): 8 | try: 9 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 10 | s.connect(('8.8.8.8', 80)) 11 | ip = s.getsockname()[0] 12 | finally: 13 | s.close() 14 | return ip 15 | 16 | 17 | def find_free_port(): 18 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 19 | sock.bind(("", 0)) 20 | port = sock.getsockname()[1] 21 | sock.close() 22 | return port 23 | 24 | 25 | def gather_tensor(tensor): 26 | world_size = dist.get_world_size() 27 | if world_size == 1: 28 | return tensor 29 | tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)] 30 | dist.all_gather(tensor_list, tensor) 31 | return torch.cat(tensor_list) 32 | 33 | 34 | def sync(): 35 | if not dist.is_available() or not dist.is_initialized(): 36 | return 37 | world_size = dist.get_world_size() 38 | if world_size == 1: 39 | return 40 | dist.barrier() 41 | -------------------------------------------------------------------------------- /kuma_utils/torch/modules/calibration.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Temperature Scaling 3 | https://github.com/gpleiss/temperature_scaling 4 | Modified 5 | ''' 6 | import torch 7 | from torch import nn, optim 8 | 9 | 10 | class TemperatureScaler(nn.Module): 11 | 12 | def __init__(self, model, num_classes=1, verbose=False): 13 | super().__init__() 14 | self.model = model 15 | self.verbose = verbose 16 | self.num_classes = num_classes 17 | self.temperature = nn.Parameter(torch.ones((1, self.num_classes)) * 1.5) 18 | 19 | def forward(self, *args): 20 | return self.temperature_scale(self.model(*args)) 21 | 22 | def temperature_scale(self, logits): 23 | temperature = self.temperature.expand(logits.size(0), -1) 24 | return logits / temperature 25 | 26 | def set_temperature(self, logits, labels): 27 | nll_criterion = nn.BCEWithLogitsLoss() 28 | optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50) 29 | 30 | def eval(): 31 | loss = nll_criterion(self.temperature_scale(logits), labels) 32
| loss.backward() 33 | return loss 34 | 35 | optimizer.step(eval) 36 | 37 | if self.verbose: 38 | print(f'Optimal temperature: {self.temperature}') 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "kuma_utils" 3 | version = "0.7.0" 4 | description = "" 5 | authors = ["Hiroshi Yoshihara"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.10,<3.13" 10 | torch = [ 11 | {version = "^2.5.0", source = "torch_cu121", markers = "sys_platform == 'linux' or sys_platform == 'win32'"}, 12 | {version = "^2.5.0", source = "torch_cpu", markers = "sys_platform == 'darwin'"} 13 | ] 14 | torchvision = [ 15 | {version = "^0.20.0", source = "torch_cu121", markers = "sys_platform == 'linux' or sys_platform == 'win32'"}, 16 | {version = "^0.20.0", source = "torch_cpu", markers = "sys_platform == 'darwin'"} 17 | ] 18 | torchaudio = [ 19 | {version = "^2.5.0", source = "torch_cu121", markers = "sys_platform == 'linux' or sys_platform == 'win32'"}, 20 | {version = "^2.5.0", source = "torch_cpu", markers = "sys_platform == 'darwin'"} 21 | ] 22 | ipykernel = "^6.29.0" 23 | jupyter = "^1.0.0" 24 | scipy = "^1.12.0" 25 | seaborn = "^0.13.2" 26 | lightgbm = "^4.3.0" 27 | catboost = "^1.2.2" 28 | xgboost = "^1.7.6" 29 | optuna = "^3.5.0" 30 | torchsummary = "^1.5.1" 31 | pyarrow = "^15.0.0" 32 | tensorboard = "^2.15.1" 33 | scikit-learn = "^1.4.0" 34 | japanize-matplotlib = "^1.1.3" 35 | openpyxl = "^3.1.2" 36 | wandb = "^0.16.3" 37 | timm = "^0.9.12" 38 | matplotlib-venn = "^0.11.10" 39 | 40 | [[tool.poetry.source]] 41 | name = "torch_cpu" 42 | url = "https://download.pytorch.org/whl/cpu" 43 | priority = "explicit" 44 | 45 | [[tool.poetry.source]] 46 | name = "torch_cu121" 47 | url = "https://download.pytorch.org/whl/cu121" 48 | priority = "explicit" 49 | 50 | [build-system] 51 | requires = ["poetry-core"] 52 | build-backend = "poetry.core.masonry.api" 53 | -------------------------------------------------------------------------------- /kuma_utils/torch/ddp_worker.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import traceback 4 | from pathlib import Path 5 | import sys 6 | import os 7 | import multiprocessing 8 | multiprocessing.current_process().authkey = '0'.encode('utf-8') 9 | 10 | 11 | class CustomUnpickler(pickle.Unpickler): 12 | ''' 13 | Unpickle objects defined in your origin __main__ 14 | ''' 15 | def __init__(self, f, main): 16 | super().__init__(f) 17 | if main[-3:] == '.py': 18 | main = main[:-3] 19 | self.main = main 20 | 21 | def find_class(self, module, name): 22 | if module == "__main__": 23 | module = self.main 24 | return super().find_class(module, name) 25 | 26 | 27 | def ddp_worker(path, rank, origin): 28 | origin = Path(origin) 29 | main_file = origin.stem 30 | main_dir = origin.parent 31 | sys.path = sys.path[1:] # prevent internal import in kuma_utils 32 | sys.path.append(str(main_dir)) # add origin directory 33 | 34 | with open(path, 'rb') as f: 35 | unpickler = CustomUnpickler(f, main=main_file) 36 | ddp_tmp = unpickler.load() 37 | trainer = ddp_tmp['trainer'] 38 | dist_url = ddp_tmp['dist_url'] 39 | loader = ddp_tmp['loader'] 40 | loader_valid = ddp_tmp['loader_valid'] 41 | num_epochs = ddp_tmp['num_epochs'] 42 | 43 | assert rank < trainer.world_size 44 | trainer._train_ddp(rank, dist_url, loader, loader_valid, num_epochs) 45 
| 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--path', type=str, required=True) 50 | parser.add_argument('--origin', type=str, required=True) 51 | opt = parser.parse_args() 52 | local_rank = int(os.environ["LOCAL_RANK"]) 53 | try: 54 | ddp_worker(opt.path, local_rank, opt.origin) 55 | except: 56 | print(traceback.format_exc()) 57 | if Path(opt.path).exists(): 58 | Path(opt.path).unlink() 59 | -------------------------------------------------------------------------------- /kuma_utils/preprocessing/pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | class PrepPipeline: 6 | ''' 7 | Data Preprocessing Pipeline 8 | ''' 9 | 10 | def __init__(self, transforms, target_col=None): 11 | self._transforms = transforms 12 | self._target_col = target_col 13 | 14 | def fit(self, input_df: pd.DataFrame) -> None: 15 | raise RuntimeError("fit() is not supported.") 16 | 17 | def fit_transform(self, input_df: pd.DataFrame) -> pd.DataFrame: 18 | if self._target_col is not None: 19 | y = input_df[self._target_col].copy() 20 | else: 21 | y = None 22 | 23 | for transform in self._transforms: 24 | # Save df columns before transform 25 | if isinstance(input_df, pd.DataFrame): 26 | columns = input_df.columns 27 | else: 28 | columns = None 29 | if y is not None: 30 | try: 31 | input_df = transform.fit_transform(input_df, y) 32 | except: 33 | input_df = transform.fit_transform(input_df) 34 | else: 35 | input_df = transform.fit_transform(input_df) 36 | # Restore columns 37 | if (columns is not None) and isinstance(input_df, np.ndarray) and (len(columns) == input_df.shape[1]): 38 | input_df = pd.DataFrame(input_df, columns=columns) 39 | return input_df 40 | 41 | def transform(self, input_df: pd.DataFrame) -> pd.DataFrame: 42 | for transform in self._transforms: 43 | if isinstance(input_df, pd.DataFrame): 44 | columns = input_df.columns 45 | else: 46 | columns = None 47 | input_df = transform.transform(input_df) 48 | if (columns is not None) and isinstance(input_df, np.ndarray) and (len(columns) == input_df.shape[1]): 49 | input_df = pd.DataFrame(input_df, columns=columns) 50 | return input_df 51 | -------------------------------------------------------------------------------- /kuma_utils/metrics/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class MetricTemplate: 5 | ''' 6 | Custom metric template 7 | 8 | # Usage 9 | general: Metric()(target, approx) 10 | catboost: eval_metric=Metric() 11 | lightgbm: metric='Metric_Name', feval=Metric().lgb 12 | pytorch: Metric().torch(output, labels) 13 | ''' 14 | 15 | def __init__(self, maximize=False): 16 | self.maximize = maximize 17 | 18 | def __repr__(self): 19 | return f'{type(self).__name__}(maximize={self.maximize})' 20 | 21 | def _test(self, target, approx): 22 | # Metric calculation 23 | pass 24 | 25 | def __call__(self, target, approx): 26 | return self._test(target, approx) 27 | 28 | ''' CatBoost ''' 29 | def get_final_error(self, error, weight): 30 | return error / weight 31 | 32 | def is_max_optimal(self): 33 | return self.maximize 34 | 35 | def evaluate(self, approxes, target, weight=None): 36 | # approxes - list of list-like objects (one object per approx dimension) 37 | # target - list-like object 38 | # weight - list-like object, can be None 39 | assert len(approxes[0]) == len(target) 40 | if not isinstance(target, np.ndarray): 41 | 
target = np.array(target) 42 | 43 | approx = np.array(approxes[0]) 44 | error_sum = self._test(target, approx) 45 | weight_sum = 1.0 46 | 47 | return error_sum, weight_sum 48 | 49 | ''' LightGBM ''' 50 | def lgb(self, approx, data): 51 | target = data.get_label() 52 | return self.__class__.__name__, self._test(target, approx), self.maximize 53 | 54 | lgbm = lgb # for compatibility 55 | 56 | ''' XGBoost ''' 57 | def xgb(self, approx, dtrain): 58 | target = dtrain.get_label() 59 | return self.__class__.__name__, self._test(target, approx) 60 | 61 | ''' PyTorch ''' 62 | def torch(self, approx, target): 63 | return self._test(target.detach().cpu().numpy(), 64 | approx.detach().cpu().numpy()) 65 | -------------------------------------------------------------------------------- /kuma_utils/torch/modules/groupnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.modules.normalization import GroupNorm as _GroupNorm 4 | 5 | 6 | class GroupNorm2d(_GroupNorm): 7 | 8 | def _check_input_dim(self, input): 9 | if input.dim() != 4: 10 | raise ValueError('expected 4D input (got {}D input)' 11 | .format(input.dim())) 12 | 13 | 14 | class GroupNorm3d(_GroupNorm): 15 | """ 16 | Assume the data format is (B, C, D, H, W) 17 | """ 18 | 19 | def _check_input_dim(self, input): 20 | if input.dim() != 5: 21 | raise ValueError('expected 5D input (got {}D input)' 22 | .format(input.dim())) 23 | 24 | 25 | class GroupNorm1d(_GroupNorm): 26 | """ 27 | Assume the data format is (N, C, W) 28 | """ 29 | 30 | def _check_input_dim(self, input): 31 | if input.dim() != 3: 32 | raise ValueError('expected 3D input (got {}D input)' 33 | .format(input.dim())) 34 | 35 | 36 | def convert_groupnorm(module, num_groups=32): 37 | if isinstance(module, torch.nn.DataParallel): 38 | mod = module.module 39 | mod = convert_groupnorm(mod) 40 | 41 | mod = module 42 | for batchnorm, groupnorm in zip([torch.nn.modules.batchnorm.BatchNorm1d, 43 | torch.nn.modules.batchnorm.BatchNorm2d, 44 | torch.nn.modules.batchnorm.BatchNorm3d], 45 | [GroupNorm1d, 46 | GroupNorm2d, 47 | GroupNorm3d]): 48 | if isinstance(module, batchnorm): 49 | mod = groupnorm( 50 | num_groups, module.num_features, 51 | module.eps, module.affine) 52 | 53 | if module.affine: 54 | mod.weight.data = module.weight.data.clone().detach() 55 | mod.bias.data = module.bias.data.clone().detach() 56 | 57 | for name, child in module.named_children(): 58 | mod.add_module(name, convert_groupnorm(child)) 59 | 60 | return mod 61 | -------------------------------------------------------------------------------- /kuma_utils/torch/modules/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class AdaptiveConcatPool2d(nn.Module): 8 | def __init__(self, sz=None): 9 | super().__init__() 10 | sz = sz or (1, 1) 11 | self.ap = nn.AdaptiveAvgPool2d(sz) 12 | self.mp = nn.AdaptiveMaxPool2d(sz) 13 | 14 | def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1) 15 | 16 | 17 | class ChannelPool(nn.Module): 18 | 19 | def __init__(self, dim=1, concat=True): 20 | super().__init__() 21 | self.dim = dim 22 | self.concat = concat 23 | 24 | def forward(self, x): 25 | max_out = torch.max(x, self.dim)[0].unsqueeze(1) 26 | avg_out = torch.mean(x, self.dim).unsqueeze(1) 27 | if self.concat: 28 | return torch.cat((max_out, avg_out), 
dim=self.dim) 29 | else: 30 | return max_out, avg_out 31 | 32 | 33 | class AdaptiveConcatPool3d(nn.Module): 34 | def __init__(self, sz=None): 35 | super().__init__() 36 | sz = sz or (1, 1, 1) 37 | self.ap = nn.AdaptiveAvgPool3d(sz) 38 | self.mp = nn.AdaptiveMaxPool3d(sz) 39 | 40 | def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1) 41 | 42 | 43 | class GeM(nn.Module): 44 | def __init__(self, p=3, eps=1e-6): 45 | super().__init__() 46 | self.p = Parameter(torch.ones(1)*p) 47 | self.eps = eps 48 | 49 | def forward(self, x): 50 | return F.avg_pool2d( 51 | x.clamp(min=self.eps).pow(self.p), 52 | (x.size(-2), x.size(-1))).pow(1./self.p) 53 | 54 | def __repr__(self): 55 | return f'GeM(p={self.p}, eps={self.eps})' 56 | 57 | 58 | class AdaptiveGeM(nn.Module): 59 | def __init__(self, size=(1, 1), p=3, eps=1e-6): 60 | super().__init__() 61 | self.size = size 62 | self.p = Parameter(torch.ones(1)*p) 63 | self.eps = eps 64 | 65 | def forward(self, x): 66 | return F.adaptive_avg_pool2d( 67 | x.clamp(min=self.eps).pow(self.p), self.size).pow(1./self.p) 68 | 69 | def __repr__(self): 70 | return f'AdaptiveGeM(size={self.size}, p={self.p}, eps={self.eps})' 71 | -------------------------------------------------------------------------------- /kuma_utils/torch/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Iterator, List, Optional, Union 2 | import torch 3 | from torch.utils.data.distributed import DistributedSampler 4 | from torch.utils.data.sampler import Sampler 5 | 6 | 7 | class DistributedProxySampler(DistributedSampler): 8 | """Distributed sampler proxy to adapt user's sampler for distributed data parallelism configuration. 9 | Code is borrowed from: https://pytorch.org/ignite/_modules/ignite/distributed/auto.html#DistributedProxySampler 10 | Code is based on https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407 11 | 12 | Args: 13 | sampler: Input torch data sampler. 14 | num_replicas: Number of processes participating in distributed training. 15 | rank: Rank of the current process within ``num_replicas``. 16 | 17 | .. note:: 18 | Input sampler is assumed to have a constant size. 
19 | """ 20 | 21 | def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: Optional[int] = None) -> None: 22 | 23 | if not isinstance(sampler, Sampler): 24 | raise TypeError(f"Argument sampler should be instance of torch Sampler, but given: {type(sampler)}") 25 | 26 | if isinstance(sampler, DistributedSampler): 27 | raise TypeError("Argument sampler must not be a distributed sampler already") 28 | 29 | if not hasattr(sampler, "__len__"): 30 | raise TypeError("Argument sampler should have length") 31 | 32 | super(DistributedProxySampler, self).__init__( 33 | sampler, num_replicas=num_replicas, rank=rank, shuffle=False # type: ignore[arg-type] 34 | ) 35 | self.sampler = sampler 36 | 37 | def __iter__(self) -> Iterator: 38 | # deterministically shuffle based on epoch 39 | torch.manual_seed(self.epoch) 40 | 41 | indices = [] # type: List 42 | while len(indices) < self.total_size: 43 | indices += list(self.sampler) 44 | 45 | if len(indices) > self.total_size: 46 | indices = indices[: self.total_size] 47 | 48 | # subsample 49 | indices = indices[self.rank : self.total_size : self.num_replicas] 50 | if len(indices) != self.num_samples: 51 | raise RuntimeError(f"{len(indices)} vs {self.num_samples}") 52 | 53 | return iter(indices) 54 | -------------------------------------------------------------------------------- /kuma_utils/training/splitter.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import Counter, defaultdict 3 | import numpy as np 4 | 5 | 6 | class StratifiedGroupKFold: 7 | ''' 8 | StratifiedGroupKFold 9 | ''' 10 | 11 | def __init__(self, n_splits, random_state=None): 12 | 13 | self.n_splits = n_splits 14 | self.random_state = random_state 15 | 16 | def split(self, X, y, groups): 17 | 18 | labels_num = np.max(y) + 1 19 | y_counts_per_group = defaultdict(lambda: np.zeros(labels_num)) 20 | y_distr = Counter() 21 | for label, g in zip(y, groups): 22 | y_counts_per_group[g][label] += 1 23 | y_distr[label] += 1 24 | 25 | y_counts_per_fold = defaultdict(lambda: np.zeros(labels_num)) 26 | groups_per_fold = defaultdict(set) 27 | 28 | def eval_y_counts_per_fold(y_counts, fold): 29 | y_counts_per_fold[fold] += y_counts 30 | std_per_label = [] 31 | for label in range(labels_num): 32 | label_std = np.std( 33 | [y_counts_per_fold[i][label] / y_distr[label] for i in range(self.n_splits)]) 34 | std_per_label.append(label_std) 35 | y_counts_per_fold[fold] -= y_counts 36 | return np.mean(std_per_label) 37 | 38 | groups_and_y_counts = list(y_counts_per_group.items()) 39 | random.Random(self.random_state).shuffle(groups_and_y_counts) 40 | 41 | for g, y_counts in sorted(groups_and_y_counts, key=lambda x: -np.std(x[1])): 42 | best_fold = None 43 | min_eval = None 44 | for i in range(self.n_splits): 45 | fold_eval = eval_y_counts_per_fold(y_counts, i) 46 | if min_eval is None or fold_eval < min_eval: 47 | min_eval = fold_eval 48 | best_fold = i 49 | y_counts_per_fold[best_fold] += y_counts 50 | groups_per_fold[best_fold].add(g) 51 | 52 | all_groups = set(groups) 53 | for i in range(self.n_splits): 54 | train_groups = all_groups - groups_per_fold[i] 55 | test_groups = groups_per_fold[i] 56 | 57 | train_indices = [i for i, g in enumerate( 58 | groups) if g in train_groups] 59 | test_indices = [i for i, g in enumerate( 60 | groups) if g in test_groups] 61 | 62 | yield train_indices, test_indices 63 | 64 | def get_n_splits(self): 65 | return self.n_splits 66 | 
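A hedged sketch of how the splitter above might be used (the data is made up): labels should be non-negative integers because the class count is taken from `np.max(y) + 1`, and samples sharing a group id are kept on the same side of each split.

```python
import numpy as np
from kuma_utils.training import StratifiedGroupKFold

X = np.random.rand(8, 4)                     # features are not used by the splitter
y = np.array([0, 0, 1, 1, 0, 1, 0, 1])       # integer class labels
groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])  # group ids (e.g. patient ids)

skf = StratifiedGroupKFold(n_splits=2, random_state=42)
for train_idx, test_idx in skf.split(X, y, groups):
    # group ids never overlap between train and test indices
    print(sorted(set(groups[train_idx])), sorted(set(groups[test_idx])))
```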
-------------------------------------------------------------------------------- /kuma_utils/training/logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from pathlib import Path 4 | from pprint import pprint, pformat 5 | 6 | 7 | def get_time(time_format='%H:%M:%S'): 8 | return time.strftime(time_format, time.gmtime()) 9 | 10 | 11 | class LGBMLogger: 12 | 13 | def __init__( 14 | self, 15 | path: str | Path, 16 | stdout: bool = True, 17 | file: bool = False, 18 | logger_name: str = 'LGBMLogger', 19 | default_level: str = 'INFO'): 20 | self.path = path 21 | self.stdout = stdout 22 | self.file = file 23 | self.logger_name = logger_name 24 | self.level = default_level 25 | self.system_logger = logging.getLogger(self.logger_name) 26 | for handler in self.system_logger.handlers[:]: 27 | self.system_logger.removeHandler(handler) 28 | handler.close() 29 | self.system_logger.setLevel(self.level) 30 | formatter = logging.Formatter("%(asctime)s - %(levelname)-8s - %(message)s") 31 | if self.file: 32 | fh = logging.FileHandler(self.path) 33 | fh.setFormatter(formatter) 34 | self.system_logger.addHandler(fh) 35 | if self.stdout: 36 | sh = logging.StreamHandler() 37 | sh.setFormatter(formatter) 38 | self.system_logger.addHandler(sh) 39 | 40 | for level in ['debug', 'info', 'warning', 'error', 'critical']: 41 | setattr(self, level, getattr(self.system_logger, level)) 42 | 43 | def lgbm(self, env): 44 | log_str = '' 45 | log_str += f'[iter {env.iteration:-5}] ' 46 | for inputs in env.evaluation_result_list: 47 | for i in inputs: 48 | if isinstance(i, str): 49 | log_str += f'{i} ' 50 | elif isinstance(i, bool): 51 | pass 52 | else: 53 | log_str += f'{i:.6f} ' 54 | else: 55 | log_str += '/ ' 56 | log_str += '\n' 57 | self.debug(log_str) 58 | 59 | def optuna(self, study, trial): 60 | best_score = study.best_value 61 | curr_score = trial.value 62 | if curr_score == best_score: 63 | log_str = '' 64 | log_str += f'[trial {trial.number:-4}] New best: {best_score:.6f} \n' 65 | log_str += f'{pformat(study.best_params, compact=True, indent=2)}' 66 | self.info(log_str) 67 | 68 | def __call__(self, log_str): 69 | self.info(log_str) 70 | -------------------------------------------------------------------------------- /kuma_utils/torch/clip_grad.py: -------------------------------------------------------------------------------- 1 | """ Adaptive Gradient Clipping 2 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171): 3 | @article{brock2021high, 4 | author={Andrew Brock and Soham De and Samuel L. 
Smith and Karen Simonyan}, 5 | title={High-Performance Large-Scale Image Recognition Without Normalization}, 6 | journal={arXiv preprint arXiv:}, 7 | year={2021} 8 | } 9 | Code references: 10 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets 11 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c 12 | Hacked together by / Copyright 2021 Ross Wightman 13 | """ 14 | import torch 15 | 16 | 17 | def unitwise_norm(x, norm_type=2.0): 18 | if x.ndim <= 1: 19 | return x.norm(norm_type) 20 | else: 21 | # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor 22 | # might need special cases for other weights (possibly MHA) where this may not be true 23 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) 24 | 25 | 26 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): 27 | if isinstance(parameters, torch.Tensor): 28 | parameters = [parameters] 29 | for p in parameters: 30 | if p.grad is None: 31 | continue 32 | p_data = p.detach() 33 | g_data = p.grad.detach() 34 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) 35 | grad_norm = unitwise_norm(g_data, norm_type=norm_type) 36 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) 37 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) 38 | p.grad.detach().copy_(new_grads) 39 | 40 | 41 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): 42 | """ Dispatch to gradient clipping method 43 | Args: 44 | parameters (Iterable): model parameters to clip 45 | value (float): clipping value/factor/norm, mode dependant 46 | mode (str): clipping mode, one of 'norm', 'value', 'agc' 47 | norm_type (float): p-norm, default 2.0 48 | """ 49 | if mode == 'norm': 50 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) 51 | elif mode == 'value': 52 | torch.nn.utils.clip_grad_value_(parameters, value) 53 | elif mode == 'agc': 54 | adaptive_clip_grad(parameters, value, norm_type=norm_type) 55 | elif mode is None: 56 | pass 57 | else: 58 | assert False, f"Unknown clip mode ({mode})." 
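A hedged usage sketch for `dispatch_clip_grad` above, with a toy model and random data; clipping is applied between `backward()` and `optimizer.step()`, and `mode` may be 'norm', 'value' or 'agc' as dispatched in the function.

```python
import torch
from kuma_utils.torch.clip_grad import dispatch_clip_grad

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = torch.nn.MSELoss()

x, y = torch.randn(8, 10), torch.randn(8, 1)
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
dispatch_clip_grad(model.parameters(), value=0.01, mode='agc')  # adaptive gradient clipping
optimizer.step()
```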
59 | -------------------------------------------------------------------------------- /kuma_utils/torch/optimizer/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SAM(torch.optim.Optimizer): 5 | ''' 6 | https://github.com/davda54/sam/blob/main/sam.py 7 | Modified 8 | ''' 9 | def __init__(self, params, base_optimizer, rho=0.05, **kwargs): 10 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 11 | 12 | defaults = dict(rho=rho, **kwargs) 13 | super(SAM, self).__init__(params, defaults) 14 | 15 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 16 | self.param_groups = self.base_optimizer.param_groups 17 | 18 | def __getstate__(self): 19 | return { 20 | 'defaults': self.defaults, 21 | 'state': self.state, 22 | 'param_groups': self.param_groups, 23 | 'base_optimizer': self.base_optimizer # pickle base_optimizer as well 24 | } 25 | 26 | @torch.no_grad() 27 | def first_step(self, zero_grad=False): 28 | grad_norm = self._grad_norm() 29 | for group in self.param_groups: 30 | scale = group["rho"] / (grad_norm + 1e-12) 31 | 32 | for p in group["params"]: 33 | if p.grad is None: 34 | continue 35 | e_w = p.grad * scale.to(p) 36 | p.add_(e_w) # climb to the local maximum "w + e(w)" 37 | self.state[p]["e_w"] = e_w 38 | 39 | if zero_grad: 40 | self.zero_grad() 41 | 42 | @torch.no_grad() 43 | def second_step(self, zero_grad=False): 44 | for group in self.param_groups: 45 | for p in group["params"]: 46 | if p.grad is None: 47 | continue 48 | p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)" 49 | 50 | self.base_optimizer.step() # do the actual "sharpness-aware" update 51 | 52 | if zero_grad: 53 | self.zero_grad() 54 | 55 | @torch.no_grad() 56 | def step(self, closure=None): 57 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 58 | # the closure should do a full forward-backward pass 59 | closure = torch.enable_grad()(closure) 60 | 61 | self.first_step(zero_grad=True) 62 | closure() 63 | self.second_step() 64 | 65 | def _grad_norm(self): 66 | # put everything on the same device, in case of model parallelism 67 | shared_device = self.param_groups[0]["params"][0].device 68 | norm = torch.norm( 69 | torch.stack([ 70 | p.grad.norm(p=2).to(shared_device) 71 | for group in self.param_groups for p in group["params"] 72 | if p.grad is not None 73 | ]), 74 | p=2 75 | ) 76 | return norm 77 | -------------------------------------------------------------------------------- /kuma_utils/preprocessing/feature_selection.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from .base import PreprocessingTemplate 3 | from .utils import analyze_column 4 | 5 | 6 | class SelectNumerical(PreprocessingTemplate): 7 | ''' 8 | ''' 9 | def __init__(self, include_cols: list = [], exclude_cols: list = []): 10 | self.include_cols = include_cols 11 | self.exclude_cols = exclude_cols 12 | 13 | def fit(self, X: pd.DataFrame, y=None) -> None: 14 | raise RuntimeError("fit() is not supported.") 15 | 16 | def transform(self, X: pd.DataFrame) -> pd.DataFrame: 17 | if len(self.include_cols) == 0: 18 | select_cols = [col for col in X.columns if \ 19 | (analyze_column(X[col]) == 'numerical') and (col not in self.exclude_cols)] 20 | else: 21 | select_cols = self.include_cols 22 | return X[select_cols].copy() 23 | 24 | def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame: 25 | return self.transform(X) 26 | 27 
| 28 | class SelectCategorical(PreprocessingTemplate): 29 | ''' 30 | ''' 31 | def __init__(self, include_cols: list = [], exclude_cols: list = []): 32 | self.include_cols = include_cols 33 | self.exclude_cols = exclude_cols 34 | 35 | def fit(self, X: pd.DataFrame, y=None) -> None: 36 | raise RuntimeError("fit() is not supported.") 37 | 38 | def transform(self, X: pd.DataFrame) -> pd.DataFrame: 39 | if len(self.include_cols) == 0: 40 | select_cols = [col for col in X.columns if \ 41 | (analyze_column(X[col]) == 'categorical') and (col not in self.exclude_cols)] 42 | else: 43 | select_cols = self.include_cols 44 | return X[select_cols].copy() 45 | 46 | def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame: 47 | return self.transform(X) 48 | 49 | 50 | class DummyVarible(PreprocessingTemplate): 51 | 52 | def __init__(self, include_cols: list = [], exclude_cols: list = [], dummy_na: bool = False): 53 | self.include_cols = include_cols 54 | self.exclude_cols = exclude_cols 55 | self.dummy_na = dummy_na 56 | 57 | def fit(self, X: pd.DataFrame, y=None) -> None: 58 | raise RuntimeError("fit() is not supported.") 59 | 60 | def transform(self, X: pd.DataFrame) -> pd.DataFrame: 61 | X_new = X.copy() 62 | if len(self.include_cols) == 0: 63 | target_cols = [col for col in X.columns if col not in self.exclude_cols] 64 | else: 65 | target_cols = self.include_cols 66 | X_new = pd.concat([ 67 | X, pd.get_dummies( 68 | X[target_cols], dummy_na=self.dummy_na, prefix=target_cols, prefix_sep='=').astype(int)], axis=1) 69 | X_new.drop(target_cols, axis=1, inplace=True) 70 | return X_new 71 | 72 | def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame: 73 | return self.transform(X) 74 | -------------------------------------------------------------------------------- /kuma_utils/torch/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from .pooling import ChannelPool 5 | 6 | 7 | class ChannelAttention(nn.Module): 8 | def __init__(self, in_planes, ratio=12): 9 | super().__init__() 10 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 11 | self.max_pool = nn.AdaptiveMaxPool2d(1) 12 | 13 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) 14 | self.relu1 = nn.ReLU(inplace=True) 15 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 16 | self.sigmoid = nn.Sigmoid() 17 | 18 | def forward(self, x): 19 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 20 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 21 | out = avg_out + max_out 22 | return self.sigmoid(out) 23 | 24 | 25 | class SpatialAttention(nn.Module): 26 | def __init__(self, kernel_size=7): 27 | super().__init__() 28 | 29 | self.pool = ChannelPool() 30 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=( 31 | kernel_size-1)//2, bias=False) 32 | self.sigmoid = nn.Sigmoid() 33 | 34 | def forward(self, x): 35 | x = self.conv1(self.pool(x)) 36 | return self.sigmoid(x) 37 | 38 | 39 | class CBAM2d(nn.Module): 40 | 41 | def __init__(self, in_planes, kernel_size=7, return_mask=False): 42 | super().__init__() 43 | 44 | self.ch_attn = ChannelAttention(in_planes) 45 | self.sp_attn = SpatialAttention(kernel_size) 46 | self.return_mask = return_mask 47 | 48 | def forward(self, x): 49 | # x: bs x ch x w x h 50 | x = self.ch_attn(x) * x 51 | sp_mask = self.sp_attn(x) 52 | x = sp_mask * x 53 | if self.return_mask: 54 | return sp_mask, x 55 | else: 56 | return x 57 | 58 | 59 | 
class MultiInstanceAttention(nn.Module): 60 | ''' 61 | Implementation of: 62 | Attention-based Multiple Instance Learning 63 | https://arxiv.org/abs/1802.04712 64 | ''' 65 | 66 | def __init__(self, feature_size, 67 | num_classes=1, hidden_size=512, gated_attention=False): 68 | super().__init__() 69 | 70 | self.gated = gated_attention 71 | 72 | self.attn_U = nn.Sequential( 73 | nn.Linear(feature_size, hidden_size), 74 | nn.Tanh() 75 | ) 76 | if self.gated: 77 | self.attn_V = nn.Sequential( 78 | nn.Linear(feature_size, hidden_size), 79 | nn.Sigmoid() 80 | ) 81 | self.attn_W = nn.Linear(hidden_size, num_classes) 82 | 83 | def forward(self, x): 84 | # x: bs x k x f 85 | # k: num of instance 86 | # f: feature dimension 87 | bs, k, f = x.shape 88 | x = x.view(bs*k, f) 89 | if self.gated: 90 | x = self.attn_W(self.attn_U(x) * self.attn_V(x)) 91 | else: 92 | x = self.attn_W(self.attn_U(x)) 93 | x = x.view(bs, k, self.attn_W.out_features) 94 | x = F.softmax(x.transpose(1, 2), dim=2) # Softmax over k 95 | return x # : bs x 1 x k 96 | -------------------------------------------------------------------------------- /kuma_utils/training/optuna/configs.py: -------------------------------------------------------------------------------- 1 | def catboost_reg_params(trial, params): 2 | _params = { 3 | "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.01, 0.1), 4 | "depth": trial.suggest_int("depth", 1, 12), 5 | "boosting_type": trial.suggest_categorical("boosting_type", ["Ordered", "Plain"]), 6 | "bootstrap_type": trial.suggest_categorical( 7 | "bootstrap_type", ["Bayesian", "Bernoulli", "MVS"] 8 | ) 9 | } 10 | if _params["bootstrap_type"] == "Bayesian": 11 | _params["bagging_temperature"] = trial.suggest_float("bagging_temperature", 0, 10) 12 | elif _params["bootstrap_type"] == "Bernoulli": 13 | _params["subsample"] = trial.suggest_float("subsample", 0.1, 1) 14 | 15 | _params.update(params) 16 | return _params 17 | 18 | 19 | def catboost_cls_params(trial, params): 20 | _params = { 21 | "objective": trial.suggest_categorical("objective", ["Logloss", "CrossEntropy"]), 22 | } 23 | _params = catboost_reg_params(trial, params) 24 | _params.update(params) 25 | return _params 26 | 27 | 28 | def xgboost_params(trial, params): 29 | _params = { 30 | # "booster": trial.suggest_categorical("booster", ["gbtree", "gblinear", "dart"]), 31 | "booster": trial.suggest_categorical("booster", ["gbtree", "dart"]), 32 | "lambda": trial.suggest_float("lambda", 1e-8, 1.0, log=True), 33 | "alpha": trial.suggest_float("alpha", 1e-8, 1.0, log=True), 34 | } 35 | if _params["booster"] == "gbtree" or _params["booster"] == "dart": 36 | _params["max_depth"] = trial.suggest_int("max_depth", 1, 9) 37 | _params["eta"] = trial.suggest_float("eta", 1e-8, 1.0, log=True) 38 | _params["gamma"] = trial.suggest_float("gamma", 1e-8, 1.0, log=True) 39 | _params["grow_policy"] = trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"]) 40 | if _params["booster"] == "dart": 41 | _params["sample_type"] = trial.suggest_categorical("sample_type", ["uniform", "weighted"]) 42 | _params["normalize_type"] = trial.suggest_categorical("normalize_type", ["tree", "forest"]) 43 | _params["rate_drop"] = trial.suggest_float("rate_drop", 1e-8, 1.0, log=True) 44 | _params["skip_drop"] = trial.suggest_float("skip_drop", 1e-8, 1.0, log=True) 45 | _params.update(params) 46 | return _params 47 | 48 | 49 | def random_forest_params(trial, params): 50 | _params = { 51 | 'max_depth': trial.suggest_int("max_depth", 2, 32, log=True), 52 | 
'n_estimators': trial.suggest_categorical('n_estimators', [5, 10, 20, 30, 50, 100]), 53 | 'max_features': trial.suggest_float('max_features', 2, 32, log=True) 54 | } 55 | _params.update(params) 56 | return _params 57 | 58 | 59 | def svm_params(trial, params): 60 | _params = { 61 | 'kernel': trial.suggest_categorical('kernel', ['rbf', 'linear', 'poly']), 62 | 'C': trial.suggest_float("C", 1e-3, 1e3, log=True) 63 | } 64 | if _params['kernel'] in ['rbf', 'poly']: 65 | _params['gamma'] = trial.suggest_int('gamma', 1, 1e3, log=True) 66 | if _params['kernel'] == 'poly': 67 | _params['degree'] = trial.suggest_int('degree', 0, 5) 68 | _params.update(params) 69 | return _params 70 | -------------------------------------------------------------------------------- /kuma_utils/preprocessing/transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.preprocessing import StandardScaler, MinMaxScaler, PowerTransformer, QuantileTransformer 4 | from copy import copy 5 | from tqdm.auto import tqdm 6 | from .base import PreprocessingTemplate 7 | 8 | 9 | class _DistTransformer1d(PreprocessingTemplate): 10 | 11 | def __init__(self, transform='standard'): 12 | assert transform in [ 13 | 'standard', 'min-max', 'box-cox', 'yeo-johnson', 'rankgauss'] 14 | self.t = transform 15 | 16 | def fit(self, X: pd.Series, y=None) -> None: 17 | if self.t == 'standard': 18 | self.transformer = StandardScaler() 19 | elif self.t == 'min-max': 20 | self.transformer = MinMaxScaler() 21 | elif self.t == 'box-cox': 22 | self.transformer = PowerTransformer(method='box-cox') 23 | elif self.t == 'yeo-johnson': 24 | self.transformer = PowerTransformer(method='yeo-johnson') 25 | elif self.t == 'rankgauss': 26 | self.transformer = QuantileTransformer( 27 | n_quantiles=len(X), random_state=0, 28 | output_distribution='normal') 29 | else: 30 | raise ValueError(self.transform) 31 | 32 | if isinstance(X, pd.Series): 33 | self.transformer.fit(X.values.reshape(-1, 1)) 34 | elif isinstance(X, np.ndarray): 35 | self.transformer.fit(X.reshape(-1, 1)) 36 | else: 37 | raise TypeError(type(X)) 38 | 39 | def transform(self, X: pd.Series) -> np.ndarray: 40 | if isinstance(X, pd.Series): 41 | return self.transformer.transform(X.values.reshape(-1, 1)) 42 | elif isinstance(X, np.ndarray): 43 | return self.transformer.transform(X.reshape(-1, 1)) 44 | else: 45 | raise TypeError(type(X)) 46 | 47 | def fit_transform(self, X: pd.Series, y=None) -> np.ndarray: 48 | self.fit(X) 49 | return self.transform(X) 50 | 51 | def copy(self): 52 | return copy(self) 53 | 54 | 55 | class DistTransformer(PreprocessingTemplate): 56 | ''' 57 | Distribution Transformer for numerical features 58 | 59 | Availbale transforms: 60 | ['standard', 'min-max', 'box-cox', 'yeo-johnson', 'rankgauss'] 61 | ''' 62 | 63 | def __init__(self, transform='standard', verbose=False): 64 | self.t = transform 65 | self.transformers = {} 66 | self.verbose = verbose 67 | 68 | def fit(self, X: pd.DataFrame, y=None) -> None: 69 | self._input_cols = X.columns.tolist() 70 | 71 | col_iter = tqdm(self._input_cols) if self.verbose else self._input_cols 72 | for col in col_iter: 73 | self.transformers[col] = _DistTransformer1d(self.t) 74 | self.transformers[col].fit(X[col]) 75 | 76 | def transform(self, X: pd.DataFrame) -> pd.DataFrame: 77 | out_df = X.copy() 78 | col_iter = tqdm(self._input_cols) if self.verbose else self._input_cols 79 | for col in col_iter: 80 | out_df[col] = 
self.transformers[col].transform(X[col]) 81 | 82 | return out_df 83 | 84 | def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame: 85 | self.fit(X) 86 | return self.transform(X) 87 | -------------------------------------------------------------------------------- /kuma_utils/training/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import lightgbm as lgb 3 | from lightgbm.compat import _LGBMLabelEncoder 4 | from xgboost.compat import XGBoostLabelEncoder 5 | import xgboost as xgb 6 | from sklearn.metrics import ( 7 | roc_auc_score, mean_squared_error, mean_absolute_error) 8 | 9 | 10 | def booster2sklearn(booster, model, X, y): 11 | assert isinstance(booster, (lgb.Booster, xgb.Booster)) 12 | new_model = model() 13 | new_model._Booster = booster 14 | new_model._n_features = X.shape[1] 15 | new_model.fitted_ = True 16 | if new_model.__class__.__name__ == 'LGBMClassifier': 17 | new_model._le = _LGBMLabelEncoder().fit(y) 18 | new_model._class_map = dict(zip( 19 | new_model._le.classes_, 20 | new_model._le.transform(new_model._le.classes_))) 21 | new_model._classes = new_model._le.classes_ 22 | new_model._n_classes = len(new_model._classes) 23 | elif new_model.__class__.__name__ == 'XGBClassifier': 24 | new_model._le = XGBoostLabelEncoder().fit(y) 25 | new_model.classes_ = new_model._le.classes_ 26 | return new_model 27 | 28 | 29 | def acc_metric(model, data): 30 | if isinstance(model, (lgb.Booster, xgb.Booster)): 31 | target = data.get_label() 32 | approx = (model.predict(data) >= 0.5).astype(int) 33 | else: 34 | target = data[1] 35 | approx = model.predict(data[0]) 36 | return np.mean(target == approx) 37 | 38 | 39 | def auc_metric(model, data): 40 | if isinstance(model, (lgb.Booster, xgb.Booster)): 41 | target = data.get_label() 42 | approx = model.predict(data) 43 | else: 44 | target = data[1] 45 | approx = model.predict_proba(data[0])[:, 1] 46 | return roc_auc_score(target, approx) 47 | 48 | 49 | def mae_metric(model, data): 50 | if isinstance(model, (lgb.Booster, xgb.Booster)): 51 | target = data.get_label() 52 | approx = model.predict(data) 53 | else: 54 | target = data[1] 55 | approx = model.predict(data[0]) 56 | return mean_absolute_error(target, approx) 57 | 58 | 59 | def mse_metric(model, data): 60 | if isinstance(model, (lgb.Booster, xgb.Booster)): 61 | target = data.get_label() 62 | approx = model.predict(data) 63 | else: 64 | target = data[1] 65 | approx = model.predict(data[0]) 66 | return mean_squared_error(target, approx) 67 | 68 | 69 | def rmse_metric(model, data): 70 | if isinstance(model, (lgb.Booster, xgb.Booster)): 71 | target = data.get_label() 72 | approx = model.predict(data) 73 | else: 74 | target = data[1] 75 | approx = model.predict(data[0]) 76 | return mean_squared_error(target, approx, squared=False) 77 | 78 | 79 | class ModelExtractor: 80 | ''' 81 | Model extractor for lightgbm and xgboost .cv() 82 | ''' 83 | 84 | def __init__(self): 85 | self.model = None 86 | 87 | def __call__(self, env): 88 | if env.model is not None: 89 | # lightgbm 90 | self.model = env.model 91 | else: 92 | # xgboost 93 | self.model = [cvpack.bst for cvpack in env.cvfolds] 94 | 95 | def get_model(self): 96 | return self.model 97 | 98 | def get_best_iteration(self): 99 | return self.model.best_iteration 100 | 101 | 102 | class XGBModelExtractor(xgb.callback.TrainingCallback): 103 | def __init__(self, cvboosters): 104 | self._cvboosters = cvboosters 105 | 106 | def after_training(self, model): 107 | 
self._cvboosters[:] = [cvpack.bst for cvpack in model.cvfolds] 108 | return model 109 | -------------------------------------------------------------------------------- /kuma_utils/torch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import subprocess 4 | import numpy as np 5 | import torch 6 | import time 7 | import resource 8 | try: 9 | import torch_xla 10 | import torch_xla.core.xla_model as xm 11 | XLA = True 12 | except ModuleNotFoundError: 13 | XLA = False 14 | 15 | 16 | def freeze_module(module): 17 | for i, param in enumerate(module.parameters()): 18 | param.requires_grad = False 19 | 20 | 21 | def fit_state_dict(state_dict, model): 22 | ''' 23 | Ignore size mismatch when loading state_dict 24 | ''' 25 | for name, param in model.named_parameters(): 26 | if name in state_dict.keys(): 27 | new_param = state_dict[name] 28 | else: 29 | continue 30 | if new_param.size() != param.size(): 31 | print(f'Size mismatch in {name}: {new_param.shape} -> {param.shape}') 32 | state_dict.pop(name) 33 | 34 | 35 | def get_device(arg): 36 | if isinstance(arg, torch.device) or \ 37 | (XLA and isinstance(arg, xm.xla_device)): 38 | device = arg 39 | elif arg is None or isinstance(arg, (list, tuple)): 40 | if XLA: 41 | device = xm.xla_device() 42 | else: 43 | device = torch.device( 44 | 'cuda' if torch.cuda.is_available() else 'cpu') 45 | elif isinstance(arg, str): 46 | if arg == 'xla' and XLA: 47 | device = xm.xla_device() 48 | else: 49 | device = torch.device(arg) 50 | 51 | if isinstance(arg, (list, tuple)): 52 | if isinstance(arg[0], int): 53 | device_ids = list(arg) 54 | elif isinstance(arg[0], str) and arg[0].isnumeric(): 55 | device_ids = [int(a) for a in arg] 56 | else: 57 | raise ValueError(f'Invalid device: {arg}') 58 | else: 59 | if device.type == 'cuda': 60 | assert torch.cuda.is_available() 61 | if device.index is None: 62 | device_count = torch.cuda.device_count() 63 | if device_count > 1: 64 | device_ids = list(range(device_count)) 65 | else: 66 | device_ids = [0] 67 | else: 68 | device_ids = [device.index] 69 | else: 70 | device_ids = [device.index] 71 | 72 | return device, device_ids 73 | 74 | 75 | def seed_everything(random_state=0, deterministic=False): 76 | random.seed(random_state) 77 | os.environ['PYTHONHASHSEED'] = str(random_state) 78 | np.random.seed(random_state) 79 | torch.manual_seed(random_state) 80 | torch.cuda.manual_seed(random_state) 81 | if deterministic: 82 | torch.backends.cudnn.deterministic = True 83 | torch.backends.cudnn.benchmark = False 84 | else: 85 | torch.backends.cudnn.deterministic = False 86 | 87 | 88 | def get_gpu_memory(): 89 | """ 90 | Code borrowed from: 91 | https://discuss.pytorch.org/t/access-gpu-memory-usage-in-pytorch/3192/4 92 | 93 | Get the current gpu usage. 94 | 95 | Returns 96 | ------- 97 | usage: dict 98 | Keys are device ids as integers. 99 | Values are memory usage as integers in MB. 
100 | """ 101 | result = subprocess.check_output( 102 | [ 103 | 'nvidia-smi', '--query-gpu=memory.used', 104 | '--format=csv,nounits,noheader' 105 | ], encoding='utf-8') 106 | # Convert lines into a dictionary 107 | gpu_memory = [int(x) for x in result.strip().split('\n')] 108 | gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory)) 109 | return gpu_memory_map 110 | 111 | 112 | def get_system_usage(): 113 | try: 114 | ram_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // 1024 # MB 115 | except Exception: 116 | ram_usage = None 117 | try: 118 | gram_usage = ' / '.join([f'{gid}:{gram} MB' for gid, gram in get_gpu_memory().items()]) 119 | except Exception: 120 | gram_usage = None 121 | return {'ram_usage': ram_usage, 'gram_usage': gram_usage} 122 | 123 | 124 | def get_time(time_format='%H:%M:%S'): 125 | return time.strftime(time_format, time.localtime()) 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kuma's Toolkit 2025 2 | 3 | ``` 4 |      ┼╂┼ 5 |     ∩_┃_∩ 6 | |ノ ヽ 7 | / ● ● | 8 | | (_●_) ミ < There is absolutely no warranty. > 9 | 彡、 |∪| 、`\ 10 | / __ ヽノ /´> ) 11 | (___) / (_/ 12 | ``` 13 | 14 | # Overview 15 | Using this library, you can: 16 | 17 | - Simplify the structuring of table data and feature engineering 18 | - implify the training and hyperparameter search for ML tools with Sklearn API (including sklearn, lightgbm, catboost, etc.) 19 | - Simplify the training of Pytorch models (including the use of amp and parallelization across multiple GPUs) 20 | - Customize training with Hook/Callback interface (such as Earlystop, logging functions integrated with wandb, etc.) 21 | - Automated exploratory data analysis 22 | - Convenient functions for basic biostatistical analysis. 23 | 24 | ## Work in progress 25 | - Multi-node DDP 26 | 27 | # Setup 28 | ## Pip 29 | **Stable** 30 | ```bash 31 | pip install git+https://github.com/analokmaus/kuma_utils.git@v0.7.0 # Stable 32 | ``` 33 | 34 | **Latest** 35 | ```bash 36 | pip install git+https://github.com/analokmaus/kuma_utils.git@master # Latest 37 | ``` 38 | 39 | ### IMPORTANT: 40 | **Mac users must install `libomp` before installing this package.** 41 | ```bash 42 | brew install libomp 43 | pip install git+https://github.com/analokmaus/kuma_utils.git 44 | ``` 45 | 46 | ## Poetry 47 | ```bash 48 | poetry add git+https://github.com/analokmaus/kuma_utils.git 49 | ``` 50 | 51 | ## Alternative installation methods 52 | WIP 53 | 54 | # Tutorials 55 | - [Exploratory data analysis](examples/Exploratory_data_analysis.ipynb) 56 | - [Data preprocessing](examples/Data_preprocessing.ipynb) 57 | - [Train and validate scikit-learn API models](examples/Train_and_validate_models.ipynb) 58 | - [Train pytorch models on single GPU](examples/Train_CNN_model.ipynb) 59 | - [Train pytorch models on multiple GPU](examples/Train_CNN_distributed.py) 60 | - [Statistical analysis (propensity score matching)](examples/Statistical_analysis.ipynb) 61 | 62 | # Directory 63 | ``` 64 | ┣ visualization 65 | ┃ ┣ explore_data - Simple exploratory data analysis. 66 | ┃ 67 | ┣ preprocessing 68 | ┃ ┣ SelectNumerical 69 | ┃ ┣ SelectCategorical 70 | ┃ ┣ DummyVariable 71 | ┃ ┣ DistTransformer - Distribution transformer for numerical features. 72 | ┃ ┣ LGBMImputer - Regression imputer for missing values using LightGBM. 73 | ┃ 74 | ┣ stats 75 | ┃ ┣ make_demographic_table - Automated demographic table generator. 
76 | ┃ ┣ PropensityScoreMatching - Fast and capable of using all sklearn API models as a backend. 77 | ┃ 78 | ┣ training 79 | ┃ ┣ Trainer - Wrapper for scikit-learn API models. 80 | ┃ ┣ CrossValidator - Ccross validation wrapper. 81 | ┃ ┣ LGBMLogger - Logger callback for LightGBM/XGBoost/Optuna. 82 | ┃ ┣ StratifiedGroupKFold - Stratified group k-fold split. 83 | ┃ ┣ optuna - optuna modifications. 84 | ┃ 85 | ┣ metrics - Universal metrics 86 | ┃ ┣ SensitivityAtFixedSpecificity 87 | ┃ ┣ RMSE 88 | ┃ ┣ Pearson correlation coefficient 89 | ┃ ┣ R2 score 90 | ┃ ┣ AUC 91 | ┃ ┣ Accuracy 92 | ┃ ┣ QuandricWeightKappa 93 | ┃ 94 | ┣ torch 95 | ┣ lr_scheduler 96 | ┃ ┣ ManualScheduler 97 | ┃ ┣ CyclicCosAnnealingLR 98 | ┃ ┣ CyclicLinearLR 99 | ┃ 100 | ┣ optimizer 101 | ┃ ┣ SAM 102 | ┃ 103 | ┣ modules 104 | ┃ ┣ Mish 105 | ┃ ┣ AdaptiveConcatPool2d/3d 106 | ┃ ┣ GeM 107 | ┃ ┣ CBAM2d 108 | ┃ ┣ GroupNorm1d/2d/3d 109 | ┃ ┣ convert_groupnorm - Convert all BatchNorm to GroupNorm. 110 | ┃ ┣ TemperatureScaler - Probability calibration for pytorch models. 111 | ┃ ┣ etc... 112 | ┃ 113 | ┣ TorchTrainer - PyTorch Trainer. 114 | ┣ EarlyStopping - Early stopping callback for TorchTrainer. Save snapshot when best score is achieved. 115 | ┣ SaveEveryEpoch - Save snapshot at the end of every epoch. 116 | ┣ SaveSnapshot - Snapshot callback. 117 | ┣ SaveAverageSnapshot - Moving average snapshot callback. 118 | ┣ TorchLogger - Logger 119 | ┣ SimpleHook - Simple train hook for almost any tasks (see tutorial). 120 | 121 | ``` 122 | 123 | # License 124 | The source code in this repository is released under the MIT license. -------------------------------------------------------------------------------- /kuma_utils/metrics/classification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import confusion_matrix, roc_auc_score 3 | from .base import MetricTemplate 4 | 5 | 6 | class AUC(MetricTemplate): 7 | ''' 8 | Area under ROC curve 9 | ''' 10 | def __init__(self): 11 | super().__init__(maximize=True) 12 | 13 | def _test(self, target, approx): 14 | if len(approx.shape) == 1: 15 | approx = approx 16 | elif approx.shape[1] == 1: 17 | approx = np.squeeze(approx) 18 | elif approx.shape[1] == 2: 19 | approx = approx[:, 1] 20 | else: 21 | raise ValueError(f'Invalid approx shape: {approx.shape}') 22 | return roc_auc_score(target, approx) 23 | 24 | 25 | class Accuracy(MetricTemplate): 26 | ''' 27 | Accuracy 28 | ''' 29 | def __init__(self): 30 | super().__init__(maximize=True) 31 | 32 | def _test(self, target, approx): 33 | assert len(target) == len(approx) 34 | target = np.asarray(target, dtype=int) 35 | approx = np.asarray(approx, dtype=float) 36 | if len(approx.shape) == 1: 37 | approx = approx 38 | elif approx.shape[1] == 1: 39 | approx = np.squeeze(approx) 40 | elif approx.shape[1] >= 2: 41 | approx = np.argmax(approx, axis=1) 42 | approx = approx.round().astype(int) 43 | return np.mean((target == approx).astype(int)) 44 | 45 | 46 | class SensitivityAtFixedSpecificity(MetricTemplate): 47 | ''' 48 | Maximize sensitivity at fixed specificity 49 | ''' 50 | def __init__(self, sp=0.9, return_sp=False): 51 | super().__init__(maximize=True) 52 | self.sp = sp 53 | self.return_sp = return_sp 54 | 55 | def _get_threshold(self, target, approx): 56 | tn_idx = (target == 0) 57 | p_tn = np.sort(approx[tn_idx]) 58 | 59 | return p_tn[int(len(p_tn) * self.sp)] 60 | 61 | def _test(self, target, approx): 62 | if not isinstance(target, np.ndarray): 63 | target = 
np.array(target) 64 | if not isinstance(approx, np.ndarray): 65 | approx = np.array(approx) 66 | 67 | if len(approx.shape) == 1: 68 | pass 69 | elif approx.shape[1] == 1: 70 | approx = np.squeeze(approx) 71 | elif approx.shape[1] == 2: 72 | approx = approx[:, 1] 73 | else: 74 | raise ValueError(f'Invalid approx shape: {approx.shape}') 75 | 76 | if min(approx) < 0: 77 | approx -= min(approx) # make all values positive 78 | target = target.astype(int) 79 | thres = self._get_threshold(target, approx) 80 | pred = (approx > thres).astype(int) 81 | tn, fp, fn, tp = confusion_matrix(target, pred).ravel() 82 | se = tp / (tp + fn) 83 | sp = tn / (tn + fp) 84 | 85 | if self.return_sp: 86 | return se, sp 87 | else: 88 | return se 89 | 90 | 91 | class QuandricWeightKappa(MetricTemplate): 92 | ''' 93 | Quandric Weight Kappa 94 | ''' 95 | def __init__(self, max_rat): 96 | super().__init__(maximize=True) 97 | self.max_rat = max_rat 98 | 99 | def _test(self, target, approx): 100 | assert len(target) == len(approx) 101 | target = np.asarray(target, dtype=int) 102 | approx = np.asarray(approx, dtype=float) 103 | if len(approx.shape) == 1: 104 | approx = approx 105 | elif approx.shape[1] == 1: 106 | approx = np.squeeze(approx) 107 | elif approx.shape[1] >= 2: 108 | approx = np.argmax(approx, axis=1) 109 | approx = np.clip(approx.round(), 0, self.max_rat-1).astype(int) 110 | 111 | hist1 = np.zeros((self.max_rat+1, )) 112 | hist2 = np.zeros((self.max_rat+1, )) 113 | 114 | o = 0 115 | for k in range(target.shape[0]): 116 | i, j = target[k], approx[k] 117 | hist1[i] += 1 118 | hist2[j] += 1 119 | o += (i - j) * (i - j) 120 | 121 | e = 0 122 | for i in range(self.max_rat + 1): 123 | for j in range(self.max_rat + 1): 124 | e += hist1[i] * hist2[j] * (i - j) * (i - j) 125 | 126 | e = e / target.shape[0] 127 | 128 | return 1 - o / e 129 | 130 | 131 | ''' 132 | Alias 133 | ''' 134 | SeAtFixedSp = SensitivityAtFixedSpecificity 135 | QWK = QuandricWeightKappa 136 | -------------------------------------------------------------------------------- /kuma_utils/torch/hooks/simple_hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import HookTemplate 3 | from ..clip_grad import dispatch_clip_grad 4 | 5 | 6 | class TrainHook(HookTemplate): 7 | 8 | def __init__(self, evaluate_in_batch=False, clip_grad=None, max_grad_norm=100., sam_optimizer=False): 9 | super().__init__() 10 | self.evaluate_in_batch = evaluate_in_batch 11 | self.clip_grad = clip_grad 12 | self.max_grad_norm = max_grad_norm 13 | self.sam_optimizer = sam_optimizer 14 | 15 | def forward_train(self, trainer, inputs): 16 | target = inputs[-1] 17 | approx = trainer.model(*inputs[:-1]) 18 | loss = trainer.criterion(approx, target) 19 | return loss, approx.detach() 20 | 21 | forward_valid = forward_train 22 | 23 | def forward_test(self, trainer, inputs): 24 | approx = trainer.model(*inputs[:-1]) 25 | return approx 26 | 27 | def _backprop_normal(self, trainer, loss, inputs=None): 28 | trainer.scaler.scale(loss).backward() 29 | dispatch_clip_grad(trainer.model.parameters(), self.max_grad_norm, mode=self.clip_grad) 30 | trainer.scaler.step(trainer.optimizer) 31 | trainer.scaler.update() 32 | trainer.optimizer.zero_grad() 33 | 34 | def _backprop_sam(self, trainer, loss, inputs): 35 | trainer.scaler.scale(loss).backward() 36 | dispatch_clip_grad(trainer.model.parameters(), self.max_grad_norm, mode=self.clip_grad) 37 | if trainer.fp16: 38 | # first step 39 | optimizer_state = 
trainer.scaler._per_optimizer_states[id(trainer.optimizer)] 40 | trainer.scaler.unscale_(trainer.optimizer) 41 | if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): 42 | trainer.optimizer.first_step(zero_grad=True) 43 | optimizer_state["stage"] = 2 44 | trainer.scaler.update() 45 | # second step 46 | with trainer.autocast: 47 | loss2, _ = trainer.forward_train(trainer, inputs) 48 | trainer.scaler.scale(loss2).backward() 49 | trainer.scaler.unscale_(trainer.optimizer) 50 | if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()): 51 | trainer.optimizer.second_step(zero_grad=True) 52 | optimizer_state["stage"] = 2 53 | else: 54 | trainer.optimizer.first_step(zero_grad=True) 55 | loss2, _ = trainer.forward_train(trainer, inputs) 56 | loss2.backward() 57 | trainer.optimizer.second_step(zero_grad=True) 58 | trainer.scaler.update() 59 | trainer.optimizer.zero_grad() 60 | 61 | def backprop(self, trainer, loss, inputs=None): 62 | if self.sam_optimizer: 63 | self._backprop_sam(trainer, loss, inputs) 64 | else: 65 | self._backprop_normal(trainer, loss, inputs) 66 | 67 | def _evaluate(self, trainer, approx, target): 68 | if trainer.eval_metric is None: 69 | if trainer.criterion is None: 70 | metric_score = 0. 71 | else: 72 | metric_score = trainer.criterion(approx, target).item() 73 | else: 74 | metric_score = trainer.eval_metric(approx, target) 75 | if isinstance(metric_score, torch.Tensor): 76 | metric_score = metric_score.item() 77 | monitor_score = [] 78 | for monitor_metric in trainer.monitor_metrics: 79 | score = monitor_metric(approx, target) 80 | if isinstance(score, torch.Tensor): 81 | score = score.item() 82 | monitor_score.append(score) 83 | return metric_score, monitor_score 84 | 85 | def evaluate_batch(self, trainer, inputs, approx): 86 | target = inputs[-1] 87 | storage = trainer.epoch_storage 88 | if self.evaluate_in_batch: 89 | # Add scores to storage 90 | metric_score, monitor_score = self._evaluate(trainer, approx, target) 91 | storage['batch_metric'].append(metric_score) 92 | storage['batch_monitor'].append(monitor_score) 93 | else: 94 | # Add prediction and target to storage 95 | storage['approx'].append(approx) 96 | storage['target'].append(target) 97 | 98 | def evaluate_epoch(self, trainer): 99 | storage = trainer.epoch_storage 100 | if self.evaluate_in_batch: 101 | # Calculate mean metrics from all batches 102 | metric_total = storage['batch_metric'].mean(0) 103 | monitor_total = storage['batch_monitor'].mean(0).tolist() 104 | else: 105 | # Calculate scores 106 | metric_total, monitor_total = self._evaluate( 107 | trainer, storage['approx'], storage['target']) 108 | return metric_total, monitor_total 109 | 110 | 111 | SimpleHook = TrainHook # Compatibility 112 | -------------------------------------------------------------------------------- /kuma_utils/stats/propensity_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from typing import Any 4 | from sklearn.linear_model import LogisticRegression 5 | from sklearn.preprocessing import LabelEncoder 6 | from scipy.optimize import linear_sum_assignment 7 | from kuma_utils.training import Trainer 8 | from kuma_utils.preprocessing import PreprocessingTemplate, PrepPipeline, SelectCategorical, SelectNumerical 9 | 10 | 11 | class PropensityScoreMatching: 12 | ''' 13 | Propensity score matching with various backend models 14 | ''' 15 | def __init__(self, 16 | match_cols: list[str], 17 | group_col: 
str, 18 | return_zscore: bool = True, 19 | categorical_encoder: PreprocessingTemplate = PrepPipeline([SelectCategorical()]), 20 | numerical_encoder: PreprocessingTemplate = PrepPipeline([SelectNumerical()]), 21 | matching_method: str = 'hungarian', 22 | model: Any = LogisticRegression, 23 | trainer_params: dict[dict] = {'params': {}, 'fit_params': {}}, 24 | fit_method: str = 'fit', 25 | caliper: str | float = 'auto'): 26 | self.match_cols = match_cols 27 | self.group_col = group_col 28 | self.return_zscore = return_zscore 29 | self.matching_method = matching_method 30 | self._le = LabelEncoder() 31 | self._cat_enc = categorical_encoder 32 | self._num_enc = numerical_encoder 33 | self.trainer = Trainer(model) 34 | self.trainer_params = trainer_params 35 | self.fit_method = fit_method 36 | self.caliper = caliper 37 | assert self.matching_method in ['greedy', 'hungarian'] 38 | assert self.caliper == 'auto' or isinstance(caliper, float) 39 | assert self.fit_method in ['fit', 'cv'] 40 | 41 | def _match(self, ps1, ps2, caliper): 42 | distance_matrix = np.abs(ps1 - ps2) 43 | if self.matching_method == 'greedy': 44 | ps1_index = [] 45 | ps2_index = [] 46 | while np.min(distance_matrix) < caliper: 47 | # get index of minimum distance element 48 | i, j = np.unravel_index(np.argmin(distance_matrix), distance_matrix.shape) 49 | ps1_index.append(i) 50 | ps2_index.append(j) 51 | distance_matrix[i, :] = 1 52 | distance_matrix[:, j] = 1 53 | elif self.matching_method == 'hungarian': 54 | row_idx, col_idx = linear_sum_assignment(distance_matrix) 55 | valid_idx = distance_matrix[row_idx, col_idx] < caliper 56 | ps1_index, ps2_index = row_idx[valid_idx], col_idx[valid_idx] 57 | del distance_matrix 58 | return ps1_index, ps2_index 59 | 60 | def run(self, df: pd.DataFrame): 61 | assert df[self.group_col].nunique() == 2 62 | X_match = df[self.match_cols].copy() 63 | X_match = pd.concat([ 64 | self._cat_enc.fit_transform(X_match).reset_index(drop=True), 65 | self._num_enc.fit_transform(X_match).reset_index(drop=True)], axis=1) 66 | y_match = self._le.fit_transform(df[self.group_col].copy()) 67 | 68 | if self.fit_method == 'fit': 69 | self.trainer.fit( 70 | train_data=(X_match, y_match), 71 | valid_data=(X_match, y_match), 72 | **self.trainer_params) 73 | X_match['__z_score'] = self.trainer.predict_proba(X_match)[:, 1] 74 | elif self.fit_method == 'cv': 75 | self.trainer.cv( 76 | data=(X_match, y_match), 77 | **self.trainer_params) 78 | X_match['__z_score'] = np.mean(self.trainer.smart_predict(X_match), axis=0) 79 | 80 | if self.caliper == 'auto': 81 | caliper = X_match['__z_score'].std() * 0.2 82 | else: 83 | caliper = self.caliper 84 | zero_index, one_index = self._match( 85 | X_match.loc[y_match == 0, '__z_score'].values.reshape(-1, 1), 86 | X_match.loc[y_match == 1, '__z_score'].values, 87 | caliper=caliper) 88 | if self.return_zscore: 89 | df['_z_score'] = X_match['__z_score'].copy() 90 | matched_data = pd.concat([ 91 | df.loc[y_match == 0].iloc[zero_index], 92 | df.loc[y_match == 1].iloc[one_index], 93 | ], axis=0).reset_index(drop=True) 94 | 95 | return matched_data 96 | 97 | def plot_feature_importance(self, df: pd.DataFrame): 98 | assert df[self.group_col].nunique() == 2 99 | X_match = df[self.match_cols].copy() 100 | X_match = pd.concat([ 101 | self._cat_enc.fit_transform(X_match).reset_index(drop=True), 102 | self._num_enc.fit_transform(X_match).reset_index(drop=True)], axis=1) 103 | y_match = self._le.fit_transform(df[self.group_col].copy()) 104 | self.trainer.plot_feature_importance( 105 | 
importance_type='permutation', fit_params={'X': X_match, 'y': y_match}) 106 | -------------------------------------------------------------------------------- /examples/Train_CNN_distributed.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | from copy import deepcopy 3 | import numpy as np 4 | from sklearn.model_selection import StratifiedKFold 5 | import torch 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.data as D 11 | import torch.optim as optim 12 | import timm 13 | from pathlib import Path 14 | import wandb 15 | 16 | from kuma_utils.torch import TorchTrainer, TorchLogger 17 | from kuma_utils.torch.callbacks import EarlyStopping, SaveSnapshot 18 | from kuma_utils.torch.hooks import SimpleHook 19 | from kuma_utils.metrics import Accuracy 20 | 21 | 22 | @dataclass 23 | class Config: 24 | num_workers: int = 32 25 | batch_size: int = 64 26 | num_epochs: int = 10 27 | early_stopping_rounds: int = 5 28 | 29 | 30 | def get_dataset(): 31 | transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 34 | ]) 35 | train = torchvision.datasets.CIFAR10( 36 | root='input', train=True, download=True, transform=transform) 37 | test = torchvision.datasets.CIFAR10( 38 | root='input', train=False, download=True, transform=transform) 39 | return train, test 40 | 41 | 42 | def split_dataset(dataset, index): 43 | new_dataset = deepcopy(dataset) 44 | new_dataset.data = new_dataset.data[index] 45 | new_dataset.targets = np.array(new_dataset.targets)[index] 46 | return new_dataset 47 | 48 | 49 | def get_model(num_classes): 50 | model = timm.create_model('tf_efficientnet_b0.ns_jft_in1k', pretrained=True, num_classes=num_classes) 51 | return model 52 | 53 | cfg = Config( 54 | num_workers=32, 55 | batch_size=2048, 56 | num_epochs=10, 57 | early_stopping_rounds=5, 58 | ) 59 | export_dir = Path('results/demo') 60 | export_dir.mkdir(parents=True, exist_ok=True) 61 | 62 | train, test = get_dataset() 63 | print('classes', train.classes) 64 | 65 | predictions = [] 66 | splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=0) 67 | for fold, (train_idx, valid_idx) in enumerate( 68 | splitter.split(train.targets, train.targets)): 69 | 70 | print(f'fold{fold} starting') 71 | 72 | valid_fold = split_dataset(train, valid_idx) 73 | train_fold = split_dataset(train, train_idx) 74 | 75 | print(f'train: {len(train_fold)} / valid: {len(valid_fold)}') 76 | 77 | loader_train = D.DataLoader( 78 | train_fold, batch_size=cfg.batch_size, num_workers=cfg.num_workers, 79 | shuffle=True, pin_memory=True) 80 | loader_valid = D.DataLoader( 81 | valid_fold, batch_size=cfg.batch_size, num_workers=cfg.num_workers, 82 | shuffle=False, pin_memory=True) 83 | loader_test = D.DataLoader( 84 | test, batch_size=cfg.batch_size, num_workers=cfg.num_workers, 85 | shuffle=False, pin_memory=True) 86 | 87 | model = get_model(num_classes=len(train.classes)) 88 | optimizer = optim.Adam(model.parameters(), lr=2e-3) 89 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 90 | optimizer, mode='max', factor=0.5, patience=2) 91 | logger = TorchLogger( 92 | path='results/demo/log', 93 | log_items='epoch train_loss train_metric valid_loss valid_metric learning_rate early_stop', 94 | file=True, 95 | use_wandb=True, wandb_params={ 96 | 'project': 'kuma_utils_demo', 97 | 'group': 
'demo_cross_validation_ddp', 98 | 'name': f'fold{fold}', 99 | 'config': asdict(cfg), 100 | }, 101 | use_tensorboard=True, 102 | tensorboard_dir=export_dir/'tensorboard', 103 | ) 104 | 105 | trn = TorchTrainer(model, serial=f'fold{fold}') 106 | trn.train( 107 | loader=loader_train, 108 | loader_valid=loader_valid, 109 | criterion=nn.CrossEntropyLoss(), 110 | eval_metric=Accuracy().torch, 111 | monitor_metrics=[ 112 | Accuracy().torch 113 | ], 114 | optimizer=optimizer, 115 | scheduler=scheduler, 116 | scheduler_target='valid_loss', # ReduceLROnPlateau reads metric each epoch 117 | num_epochs=cfg.num_epochs, 118 | hook=SimpleHook( 119 | evaluate_in_batch=False, clip_grad=None, sam_optimizer=False), 120 | callbacks=[ 121 | EarlyStopping( 122 | patience=cfg.early_stopping_rounds, 123 | target='valid_metric', 124 | maximize=True), 125 | SaveSnapshot() # Default snapshot path: {export_dir}/{serial}.pt 126 | ], 127 | logger=logger, 128 | export_dir=export_dir, 129 | parallel='ddp', # Supported parallel methods: 'dp', 'ddp' 130 | fp16=True, # Pytorch mixed precision 131 | deterministic=True, 132 | random_state=0, 133 | progress_bar=True # Progress bar shows batches done 134 | ) 135 | 136 | oof = trn.predict(loader_valid) 137 | predictions.append(trn.predict(loader_test)) 138 | 139 | score = Accuracy()(valid_fold.targets, oof) 140 | print(f'Folf{fold} score: {score:.6f}') 141 | -------------------------------------------------------------------------------- /kuma_utils/torch/callbacks/stopping.py: -------------------------------------------------------------------------------- 1 | from pprint import pformat 2 | import numpy as np 3 | from .base import CallbackTemplate 4 | 5 | 6 | class SaveEveryEpoch(CallbackTemplate): 7 | ''' 8 | Save snapshot every epoch 9 | ''' 10 | 11 | def __init__(self, ): 12 | super().__init__() 13 | 14 | def after_epoch(self, env, loader=None, loader_valid=None): 15 | env.checkpoint = True 16 | 17 | 18 | class EarlyStopping(CallbackTemplate): 19 | ''' 20 | Early stops the training if validation loss doesn't improve after a given patience. 
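The monitored value is read from the trainer state under `target` each epoch; whenever it improves, `env.checkpoint` is set so a snapshot of the best model can be saved, and training stops once the no-improvement counter reaches `patience`. The first `skip_epoch` epochs are treated as warm-up: the best score is simply refreshed and no early-stop counting happens.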
21 | patience: int = 22 | target: str = 23 | maximize: bool = 24 | skip_epoch: int = 25 | ''' 26 | 27 | def __init__(self, patience=5, target='valid_metric', maximize=False, skip_epoch=0): 28 | super().__init__() 29 | self.state = { 30 | 'patience': patience, 31 | 'target': target, 32 | 'maximize': maximize, 33 | 'skip_epoch': skip_epoch, 34 | 'counter': 0, 35 | 'best_score': None, 36 | 'best_epoch': None 37 | } 38 | 39 | def after_epoch(self, env, loader=None, loader_valid=None): 40 | score = env.state[self.state['target']] 41 | epoch = env.state['epoch'] # local epoch 42 | if epoch < self.state['skip_epoch'] or epoch == 0: 43 | self.state['best_score'] = score 44 | self.state['best_epoch'] = env.global_epoch 45 | env.checkpoint = True 46 | env.state['best_score'] = self.state['best_score'] 47 | env.state['best_epoch'] = self.state['best_epoch'] 48 | else: 49 | if (self.state['maximize'] and score > self.state['best_score']) or \ 50 | (not self.state['maximize'] and score < self.state['best_score']): 51 | self.state['best_score'] = score 52 | self.state['best_epoch'] = env.global_epoch 53 | self.state['counter'] = 0 54 | env.checkpoint = True 55 | env.state['best_score'] = self.state['best_score'] 56 | env.state['best_epoch'] = self.state['best_epoch'] 57 | else: 58 | self.state['counter'] += 1 59 | 60 | env.state['patience'] = self.state['counter'] 61 | if self.state['counter'] >= self.state['patience']: 62 | env.stop_train = True 63 | 64 | def state_dict(self): 65 | return self.state 66 | 67 | def load_state_dict(self, checkpoint): 68 | self.state = checkpoint 69 | 70 | def __repr__(self): 71 | return f'EarlyStopping(patience={self.state["patience"]}, skip_epoch={self.state["skip_epoch"]})' 72 | 73 | 74 | class CollectTopK(CallbackTemplate): 75 | ''' 76 | Collect top k checkpoints for weight average snapshot 77 | k: int = 78 | target: str = 79 | maximize: bool = 80 | ''' 81 | 82 | def __init__(self, k=3, target='valid_metric', maximize=False, ): 83 | super().__init__() 84 | self.state = { 85 | 'k': k, 86 | 'target': target, 87 | 'maximize': maximize, 88 | 'best_scores': np.array([]), 89 | 'best_epochs': np.array([]), 90 | 'counter': 0 91 | } 92 | 93 | def after_epoch(self, env, loader=None, loader_valid=None): 94 | score = env.state[self.state['target']] 95 | epoch = env.state['epoch'] # local epoch 96 | if len(self.state['best_scores']) < self.state['k']: 97 | self.state['best_scores'] = np.append(self.state['best_scores'], score) 98 | self.state['best_epochs'] = np.append(self.state['best_epochs'], env.global_epoch) 99 | if self.state['maximize']: 100 | rank = np.argsort(-self.state['best_scores']) 101 | else: 102 | rank = np.argsort(self.state['best_scores']) 103 | 104 | env.checkpoint = True 105 | env.state['best_score'] = self.state['best_scores'][rank][0] 106 | env.state['best_epoch'] = self.state['best_epochs'][rank][0] 107 | 108 | elif (self.state['maximize'] and score > np.min(self.state['best_scores'])) or \ 109 | (not self.state['maximize'] and score < np.max(self.state['best_scores'])): 110 | if self.state['maximize']: 111 | del_idx = np.argmin(self.state['best_scores']) 112 | else: 113 | del_idx = np.argmax(self.state['best_scores']) 114 | self.state['best_scores'] = np.delete(self.state['best_scores'], del_idx) 115 | self.state['best_epochs'] = np.delete(self.state['best_epochs'], del_idx) 116 | self.state['best_scores'] = np.append(self.state['best_scores'], score) 117 | self.state['best_epochs'] = np.append(self.state['best_epochs'], env.global_epoch) 118 | if 
self.state['maximize']: 119 | rank = np.argsort(-self.state['best_scores']) 120 | else: 121 | rank = np.argsort(self.state['best_scores']) 122 | 123 | env.checkpoint = True 124 | env.state['best_score'] = self.state['best_scores'][rank][0] 125 | env.state['best_epoch'] = self.state['best_epochs'][rank][0] 126 | self.state['counter'] = 0 127 | else: 128 | self.state['counter'] += 1 129 | 130 | env.state['patience'] = self.state['counter'] 131 | 132 | def __repr__(self): 133 | return f'CollectTopK(k={self.state["k"]})' -------------------------------------------------------------------------------- /kuma_utils/torch/lr_scheduler/cyclic_scheduler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of Cyclic learning rate schedulers 3 | https://github.com/bluesky314/Cyclical_LR_Scheduler_With_Decay_Pytorch 4 | ''' 5 | import math 6 | from bisect import bisect_right, bisect_left 7 | 8 | import torch 9 | import numpy as np 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | from torch.optim.optimizer import Optimizer 12 | 13 | 14 | class CyclicCosAnnealingLR(_LRScheduler): 15 | r""" 16 | 17 | Implements reset on milestones inspired from CosineAnnealingLR pytorch 18 | 19 | Set the learning rate of each parameter group using a cosine annealing 20 | schedule, where :math:`\eta_{max}` is set to the initial lr and 21 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 22 | .. math:: 23 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 24 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 25 | When last_epoch > last set milestone, lr is automatically set to \eta_{min} 26 | It has been proposed in 27 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only 28 | implements the cosine annealing part of SGDR, and not the restarts. 29 | Args: 30 | optimizer (Optimizer): Wrapped optimizer. 31 | milestones (list of ints): List of epoch indices. Must be increasing. 32 | decay_milestones(list of ints):List of increasing epoch indices. Ideally,decay values should overlap with milestone points 33 | gamma (float): factor by which to decay the max learning rate at each decay milestone 34 | eta_min (float): Minimum learning rate. Default: 1e-6 35 | last_epoch (int): The index of last epoch. Default: -1. 36 | 37 | 38 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 39 | https://arxiv.org/abs/1608.03983 40 | """ 41 | 42 | def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, eta_min=1e-6, last_epoch=-1): 43 | if not list(milestones) == sorted(milestones): 44 | raise ValueError('Milestones should be a list of' 45 | ' increasing integers. 
Got {}', milestones) 46 | self.eta_min = eta_min 47 | self.milestones = milestones 48 | self.milestones2 = decay_milestones 49 | 50 | self.gamma = gamma 51 | super(CyclicCosAnnealingLR, self).__init__(optimizer, last_epoch) 52 | 53 | def get_lr(self): 54 | 55 | if self.last_epoch >= self.milestones[-1]: 56 | return [self.eta_min for base_lr in self.base_lrs] 57 | 58 | idx = bisect_right(self.milestones, self.last_epoch) 59 | 60 | left_barrier = 0 if idx == 0 else self.milestones[idx-1] 61 | right_barrier = self.milestones[idx] 62 | 63 | width = right_barrier - left_barrier 64 | curr_pos = self.last_epoch - left_barrier 65 | 66 | if self.milestones2: 67 | return [self.eta_min + (base_lr * self.gamma ** bisect_right(self.milestones2, self.last_epoch) - self.eta_min) * 68 | (1 + math.cos(math.pi * curr_pos / width)) / 2 69 | for base_lr in self.base_lrs] 70 | else: 71 | return [self.eta_min + (base_lr - self.eta_min) * 72 | (1 + math.cos(math.pi * curr_pos / width)) / 2 73 | for base_lr in self.base_lrs] 74 | 75 | 76 | class CyclicLinearLR(_LRScheduler): 77 | r""" 78 | Implements reset on milestones inspired from Linear learning rate decay 79 | 80 | Set the learning rate of each parameter group using a linear decay 81 | schedule, where :math:`\eta_{max}` is set to the initial lr and 82 | :math:`T_{cur}` is the number of epochs since the last restart: 83 | .. math:: 84 | \eta_t = \eta_{min} + (\eta_{max} - \eta_{min})(1 -\frac{T_{cur}}{T_{max}}) 85 | When last_epoch > last set milestone, lr is automatically set to \eta_{min} 86 | 87 | Args: 88 | optimizer (Optimizer): Wrapped optimizer. 89 | milestones (list of ints): List of epoch indices. Must be increasing. 90 | decay_milestones(list of ints):List of increasing epoch indices. Ideally,decay values should overlap with milestone points 91 | gamma (float): factor by which to decay the max learning rate at each decay milestone 92 | eta_min (float): Minimum learning rate. Default: 1e-6 93 | last_epoch (int): The index of last epoch. Default: -1. 94 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 95 | https://arxiv.org/abs/1608.03983 96 | """ 97 | 98 | def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, eta_min=1e-6, last_epoch=-1): 99 | if not list(milestones) == sorted(milestones): 100 | raise ValueError('Milestones should be a list of' 101 | ' increasing integers. Got {}', milestones) 102 | self.eta_min = eta_min 103 | 104 | self.gamma = gamma 105 | self.milestones = milestones 106 | self.milestones2 = decay_milestones 107 | super(CyclicLinearLR, self).__init__(optimizer, last_epoch) 108 | 109 | def get_lr(self): 110 | 111 | if self.last_epoch >= self.milestones[-1]: 112 | return [self.eta_min for base_lr in self.base_lrs] 113 | 114 | idx = bisect_right(self.milestones, self.last_epoch) 115 | 116 | left_barrier = 0 if idx == 0 else self.milestones[idx-1] 117 | right_barrier = self.milestones[idx] 118 | 119 | width = right_barrier - left_barrier 120 | curr_pos = self.last_epoch - left_barrier 121 | 122 | if self.milestones2: 123 | return [self.eta_min + (base_lr * self.gamma ** bisect_right(self.milestones2, self.last_epoch) - self.eta_min) * 124 | (1. - 1.0*curr_pos / width) 125 | for base_lr in self.base_lrs] 126 | 127 | else: 128 | return [self.eta_min + (base_lr - self.eta_min) * 129 | (1. 
- 1.0*curr_pos / width) 130 | for base_lr in self.base_lrs] 131 | -------------------------------------------------------------------------------- /kuma_utils/torch/callbacks/snapshot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parallel import DataParallel, DistributedDataParallel 3 | from .base import CallbackTemplate 4 | 5 | 6 | def _save_snapshot(trainer, path, 7 | save_optimizer=False, 8 | save_scheduler=False): 9 | if isinstance( 10 | trainer.model, 11 | (DataParallel, DistributedDataParallel)): 12 | module = trainer.model.module 13 | else: 14 | module = trainer.model 15 | 16 | serialized = { 17 | 'global_epoch': trainer.global_epoch, 18 | 'model': module.state_dict(), 19 | 'state': trainer.state, 20 | 'all_states': trainer._states 21 | } 22 | if save_optimizer: 23 | serialized['optimizer'] = trainer.optimizer.state_dict() 24 | if save_scheduler: 25 | serialized['scheduler'] = trainer.scheduler.state_dict() 26 | 27 | torch.save(serialized, str(path)) 28 | 29 | 30 | def _load_snapshot(trainer, path, device): 31 | checkpoint = torch.load(str(path), map_location=device) 32 | 33 | if isinstance( 34 | trainer.model, 35 | (DataParallel, DistributedDataParallel)): 36 | trainer.model.module.load_state_dict(checkpoint['model']) 37 | else: 38 | trainer.model.load_state_dict(checkpoint['model']) 39 | 40 | if hasattr(trainer, 'optimizer') and 'optimizer' in checkpoint.keys(): 41 | trainer.optimizer.load_state_dict(checkpoint['optimizer']) 42 | if hasattr(trainer, 'scheduler') and 'scheduler' in checkpoint.keys(): 43 | trainer.scheduler.load_state_dict(checkpoint['scheduler']) 44 | if hasattr(trainer, 'global_epoch'): 45 | trainer.global_epoch = checkpoint['global_epoch'] 46 | trainer.state = checkpoint['state'] 47 | trainer._states = checkpoint['all_states'] 48 | 49 | 50 | def _save_average_snapshot( 51 | trainer, path, num_snapshot=3, save_optimizer=False, save_scheduler=False): 52 | if isinstance( 53 | trainer.model, 54 | (DataParallel, DistributedDataParallel)): 55 | module = trainer.model.module 56 | else: 57 | module = trainer.model 58 | 59 | if path.exists(): 60 | try: 61 | checkpoints = torch.load(str(path), map_location='cpu')['checkpoints'] 62 | except Exception: 63 | checkpoints = [] 64 | else: 65 | checkpoints = [] 66 | 67 | if len(checkpoints) >= num_snapshot: 68 | del checkpoints[0] 69 | checkpoints.append({k: v.cpu() for k, v in module.state_dict().items()}) 70 | 71 | # average checkpoints 72 | model_weights = checkpoints[-1].copy() 73 | for k, v in model_weights.items(): 74 | model_weights[k] = v / len(checkpoints) 75 | for i in range(len(checkpoints)-1): 76 | model_weights[k] += checkpoints[i][k] / len(checkpoints) 77 | 78 | serialized = { 79 | 'global_epoch': trainer.global_epoch, 80 | 'model': model_weights, 81 | 'checkpoints': checkpoints, 82 | 'state': trainer.state, 83 | 'all_states': trainer._states 84 | } 85 | if save_optimizer: 86 | serialized['optimizer'] = trainer.optimizer.state_dict() 87 | if save_scheduler: 88 | serialized['scheduler'] = trainer.scheduler.state_dict() 89 | 90 | torch.save(serialized, str(path)) 91 | 92 | 93 | class SaveAllSnapshots(CallbackTemplate): 94 | def __init__(self, path=None, save_optimizer=False, save_scheduler=False): 95 | super().__init__() 96 | self.path = path 97 | self.save_optimizer = save_optimizer 98 | self.save_scheduler = save_scheduler 99 | 100 | def save_snapshot(self, trainer, path): 101 | if path is None: 102 | path = trainer.base_dir / 
f'{trainer.serial}_epoch_{trainer.global_epoch}.pt' 103 | _save_snapshot(trainer, path, self.save_optimizer, self.save_scheduler) 104 | 105 | def load_snapshot(self, trainer, path=None, device=None): 106 | if path is None or not path.exists(): 107 | # Pickup latest 108 | path = sorted(list(trainer.base_dir.glob(f'{trainer.serial}_epoch_*.pt')))[-1] 109 | 110 | if device is None: 111 | device = trainer.device 112 | _load_snapshot(trainer, path, device) 113 | 114 | 115 | class SaveSnapshot(CallbackTemplate): 116 | ''' 117 | Path priority: path argument > BestEpoch.path > trainer.snapshot_path 118 | ''' 119 | 120 | def __init__(self, path=None, save_optimizer=False, save_scheduler=False): 121 | super().__init__() 122 | self.path = path 123 | self.save_optimizer = save_optimizer 124 | self.save_scheduler = save_scheduler 125 | 126 | def save_snapshot(self, trainer, path): 127 | if path is None: 128 | path = self.path if self.path is not None else trainer.snapshot_path 129 | _save_snapshot(trainer, path, self.save_optimizer, self.save_scheduler) 130 | 131 | def load_snapshot(self, trainer, path=None, device=None): 132 | if path is None: 133 | path = self.path if self.path is not None else trainer.snapshot_path 134 | if device is None: 135 | device = trainer.device 136 | _load_snapshot(trainer, path, device) 137 | 138 | 139 | class SaveAverageSnapshot(CallbackTemplate): 140 | ''' 141 | Path priority: path argument > BestEpoch.path > trainer.snapshot_path 142 | ''' 143 | 144 | def __init__(self, num_snapshot=3, path=None, save_optimizer=False, save_scheduler=False): 145 | super().__init__() 146 | self.num_snapshot = num_snapshot 147 | self.path = path 148 | self.save_optimizer = save_optimizer 149 | self.save_scheduler = save_scheduler 150 | 151 | def save_snapshot(self, trainer, path): 152 | if path is None: 153 | path = self.path if self.path is not None else trainer.snapshot_path 154 | _save_average_snapshot(trainer, path, self.num_snapshot, self.save_optimizer, self.save_scheduler) 155 | 156 | def load_snapshot(self, trainer, path=None, device=None): 157 | if path is None: 158 | path = self.path if self.path is not None else trainer.snapshot_path 159 | if device is None: 160 | device = trainer.device 161 | _load_snapshot(trainer, path, device) 162 | -------------------------------------------------------------------------------- /kuma_utils/torch/callbacks/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from torch.utils.tensorboard import SummaryWriter 4 | try: 5 | import wandb 6 | WANDB = True 7 | except Exception as e: 8 | WANDB = False 9 | pass 10 | 11 | from kuma_utils.torch.utils import get_gpu_memory, get_time 12 | 13 | 14 | class TorchLogger: 15 | 16 | def __init__( 17 | self, 18 | path: str | Path, 19 | log_items: list[str] | str = [ 20 | 'epoch', 21 | 'train_loss', 'valid_loss', 22 | 'train_metric', 'valid_metric', 23 | 'train_monitor', 'valid_monitor', 24 | 'learning_rate', 'early_stop'], 25 | verbose_eval: int = 1, 26 | stdout: bool = True, 27 | file: bool = False, 28 | logger_name: str = 'TorchLogger', 29 | default_level: str = 'INFO', 30 | use_wandb: bool = False, 31 | wandb_params: dict = {'project': 'test', 'config': {}}, 32 | use_tensorboard: bool = False, 33 | tensorboard_dir: str | Path = None): 34 | if isinstance(log_items, str): 35 | log_items = log_items.split(' ') 36 | self.path = path 37 | if isinstance(self.path, str): 38 | self.path = Path(self.path) 39 | self.log_items = 
log_items 40 | self.verbose_eval = verbose_eval 41 | self.stdout = stdout 42 | self.file = file 43 | self.logger_name = logger_name 44 | self.level = default_level 45 | self.use_wandb = use_wandb 46 | self.wandb_params = wandb_params 47 | self.use_tensorboard = use_tensorboard 48 | self.tensorboard_dir = tensorboard_dir 49 | self.sep = ' | ' 50 | self.dataframe = [] 51 | 52 | self.system_logger = logging.getLogger(self.logger_name) 53 | for handler in self.system_logger.handlers[:]: 54 | self.system_logger.removeHandler(handler) 55 | handler.close() 56 | self.system_logger.setLevel(self.level) 57 | formatter = logging.Formatter("%(asctime)s - %(levelname)-8s - %(message)s") 58 | if self.file: 59 | fh = logging.FileHandler(self.path) 60 | fh.setFormatter(formatter) 61 | self.system_logger.addHandler(fh) 62 | if self.stdout: 63 | sh = logging.StreamHandler() 64 | sh.setFormatter(formatter) 65 | self.system_logger.addHandler(sh) 66 | for level in ['debug', 'info', 'warning', 'error', 'critical']: 67 | setattr(self, level, getattr(self.system_logger, level)) 68 | 69 | def init_wandb(self, serial: str = None): # This is called in Trainer._train() 70 | if not WANDB: 71 | raise ValueError('wandb is not installed.') 72 | wandb_params = self.wandb_params.copy() 73 | if serial is not None and 'name' not in wandb_params.keys(): # No override 74 | wandb_params.update({'name': serial}) 75 | wandb.init(**wandb_params) 76 | 77 | def init_tensorboard(self, serial): # This is called in Trainer._train() 78 | if self.tensorboard_dir is None and self.path is not None: 79 | self.tensorboard_dir = self.path.parent/'tensorboard' 80 | self.tensorboard_dir = self.tensorboard_dir/serial 81 | if not self.tensorboard_dir.exists(): 82 | self.tensorboard_dir.mkdir(exist_ok=True, parents=True) 83 | self.tb_writer = SummaryWriter(log_dir=self.tensorboard_dir) 84 | 85 | def __call__(self, log_str): 86 | self.info(log_str) 87 | 88 | def after_epoch(self, env, loader=None, loader_valid=None): 89 | ''' callback ''' 90 | epoch = env.state['epoch'] 91 | if epoch % self.verbose_eval != 0: 92 | return 93 | log_str = '' 94 | log_dict = {} 95 | for item in self.log_items: 96 | if item == 'epoch': 97 | num_len = len(str(env.max_epochs)) 98 | log_str += f'Epoch {env.global_epoch:-{num_len}}/' 99 | log_str += f'{env.max_epochs:-{num_len}}' 100 | log_dict['global_epoch'] = env.global_epoch 101 | elif item == 'early_stop': 102 | best_score = env.state['best_score'] 103 | counter = env.state['patience'] 104 | if counter > 0: 105 | log_str += f'best={best_score:.6f}(*{counter})' 106 | log_dict.update({ 107 | 'early_stopping_counter': counter, 108 | 'best_score': best_score}) 109 | elif item == 'gpu_memory': 110 | log_str += 'gpu_mem=' 111 | for gpu_i, gpu_mem in get_gpu_memory().items(): 112 | log_str += f'({gpu_i}:{int(gpu_mem)}MB)' 113 | else: 114 | val = env.state[item] 115 | if val is None: 116 | continue 117 | elif isinstance(val, list): 118 | metrics_str = '[' + \ 119 | ', '.join([f'{v:.6f}' for v in val]) + ']' 120 | if len(val) > 0: 121 | log_str += f"{item}={metrics_str}" 122 | for iv, v in enumerate(val): 123 | log_dict[f'{item}{iv}'] = v 124 | else: 125 | log_str += f"{item}={val:.6f}" 126 | log_dict[item] = val 127 | log_str += self.sep 128 | self.__call__(log_str) 129 | self.write_log(log_dict, epoch) 130 | 131 | def write_log(self, 132 | logs: dict, 133 | step: int, 134 | log_wandb: bool = True, 135 | log_tensorboard: bool = True): 136 | if self.use_wandb and log_wandb: 137 | wandb.log(logs, step=step) 138 | if 
self.use_tensorboard and log_tensorboard: 139 | for k, v in logs.items(): 140 | self.tb_writer.add_scalar(k, v, step) 141 | 142 | 143 | class DummyLogger: 144 | 145 | def __init__(self, path): 146 | for level in ['debug', 'info', 'warning', 'error', 'critical']: 147 | setattr(self, level, self.__call__) 148 | 149 | def __call__(self, log_str): 150 | pass 151 | 152 | def after_epoch(self, env): 153 | pass 154 | 155 | def write_log(self, logs, step): 156 | pass 157 | -------------------------------------------------------------------------------- /kuma_utils/visualization/eda.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pathlib import Path 4 | from typing import List, Optional, Union 5 | from pprint import pprint, pformat 6 | import matplotlib.pyplot as plt 7 | from matplotlib_venn import venn2 8 | import seaborn as sns 9 | 10 | from kuma_utils.preprocessing import analyze_column 11 | from kuma_utils.stats import make_demographic_table, make_summary_table 12 | 13 | try: 14 | import japanize_matplotlib 15 | except ModuleNotFoundError: 16 | pass 17 | try: 18 | import wandb 19 | WANDB = True 20 | except ModuleNotFoundError: 21 | WANDB = False 22 | 23 | 24 | def explore_data( 25 | train_df: pd.DataFrame, 26 | test_df: pd.DataFrame = None, 27 | exclude_columns: List[str] = [], 28 | normalize: bool = True, 29 | histogram_n_bins: int = 20, 30 | use_wandb: bool = False, 31 | wandb_params: dict = {}, 32 | plot: bool = True, 33 | scale_plot: Union[int, float] = 1.0, 34 | save_to: Union[str, Path] = None, 35 | verbose: bool = True, 36 | ) -> None: 37 | 38 | if use_wandb: 39 | if WANDB: 40 | wandb.init(**wandb_params) 41 | print('wandb export is enabled. Local plot will be disabled.') 42 | plot = False 43 | else: 44 | print('wandb is not installed.') 45 | use_wandb = False 46 | 47 | if verbose: 48 | print(f'data shape: {train_df.shape}, {test_df.shape if test_df is not None else None}\n') 49 | 50 | ''' Scan columns ''' 51 | train_columns = train_df.columns.tolist() 52 | if test_df is not None: 53 | test_columns = test_df.columns.tolist() 54 | include_columns = set(train_columns) & set(test_columns) 55 | include_columns = list(include_columns - set(exclude_columns)) 56 | concat_df = pd.concat([ 57 | train_df[include_columns].assign(_group='train'), 58 | test_df[include_columns].assign(_group='test')], axis=0) 59 | else: 60 | include_columns = list(set(train_columns) - set(exclude_columns)) 61 | concat_df = train_df[include_columns] 62 | 63 | n_columns = len(include_columns) 64 | if verbose: 65 | print(f'Included columns: \n{pformat(include_columns, compact=True)} ({n_columns})\n') 66 | if plot: 67 | fig_length = int(np.ceil(np.sqrt(n_columns))) 68 | fig = plt.figure( 69 | figsize=(fig_length*6*scale_plot, fig_length*3*scale_plot), 70 | tight_layout=True) 71 | 72 | ''' Check columns ''' 73 | for icol, col in enumerate(include_columns): 74 | print(f'{icol:-{len(str(n_columns))+1}} {col}: ') 75 | vals = train_df[col] 76 | if test_df is not None: 77 | test_vals = test_df[col] 78 | 79 | column_type = analyze_column(vals) 80 | if column_type == 'numerical': 81 | nan_count = vals.isnull().sum() 82 | print(f'\n train NaN: {nan_count} ({nan_count/len(vals):.3f})') 83 | if test_df is not None: 84 | test_nan_count = test_vals.isnull().sum() 85 | print(f' test NaN: {test_nan_count} ({test_nan_count/len(test_vals):.3f}\n') 86 | bin_edges = np.histogram_bin_edges( 87 | np.concatenate([vals.values, test_vals.values]), 
bins=histogram_n_bins) 88 | summary = make_demographic_table(concat_df, group_col='_group', display_cols=[col]).drop( 89 | ['_ks_stat', '_nan_info'], axis=1).iloc[1:] 90 | else: 91 | bin_edges = np.histogram_bin_edges(vals.values, bins=histogram_n_bins) 92 | summary = make_summary_table(concat_df, display_cols=[col]).drop( 93 | ['_nan_info'], axis=1).iloc[1:] 94 | 95 | print(summary) 96 | if use_wandb: 97 | wandb.log({f"{col}/summary": wandb.Table(dataframe=summary)}) 98 | 99 | if plot or use_wandb: 100 | if plot: 101 | ax1 = fig.add_subplot(fig_length, fig_length, icol+1) 102 | elif use_wandb: 103 | fig, ax1 = plt.subplots() 104 | 105 | sns.histplot(vals.dropna(), ax=ax1, label='train', bins=bin_edges, 106 | element="step", stat="density", common_norm=False, kde=True) 107 | if test_df is not None: 108 | sns.histplot(test_vals.dropna(), ax=ax1, label='test', bins=bin_edges, 109 | element="step", stat="density", common_norm=False, kde=True) 110 | plt.legend() 111 | 112 | if use_wandb: 113 | wandb.log({f"{col}/distribution": wandb.Image(fig)}) 114 | plt.close() 115 | 116 | elif column_type == 'categorical': 117 | if test_df is not None: 118 | summary = make_demographic_table(concat_df, group_col='_group', display_cols=[col]).drop( 119 | ['_ks_stat', '_nan_info'], axis=1).iloc[1:] 120 | else: 121 | summary = make_summary_table(concat_df, display_cols=[col]).drop( 122 | ['_nan_info'], axis=1).iloc[1:] 123 | 124 | print(summary) 125 | if use_wandb: 126 | wandb.log({f"{col}/summary": wandb.Table(dataframe=summary)}) 127 | 128 | if plot or use_wandb: 129 | if plot: 130 | ax1 = fig.add_subplot(fig_length, fig_length, icol+1) 131 | ax1.set_title(col) 132 | elif use_wandb: 133 | fig, ax1 = plt.subplots() 134 | 135 | train_uni = set(vals.unique().tolist()) 136 | if test_df is not None: 137 | test_uni = set(test_vals.unique().tolist()) 138 | common_uni = train_uni & test_uni 139 | train_uni = train_uni - common_uni 140 | test_uni = test_uni - common_uni 141 | venn2(subsets=( 142 | len(train_uni), len(test_uni), len(common_uni)), 143 | set_labels=('train', 'test'), ax=ax1 144 | ) 145 | else: 146 | venn2(subsets=( 147 | 0, 0, len(train_uni)), 148 | set_labels=('train', 'train'), ax=ax1 149 | ) 150 | plt.legend() 151 | 152 | if use_wandb: 153 | wandb.log({f"{col}/venn": wandb.Image(fig)}) 154 | plt.close() 155 | 156 | if plot: 157 | if save_to: 158 | fig.savefig(save_to) 159 | plt.show() 160 | -------------------------------------------------------------------------------- /kuma_utils/preprocessing/imputer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import lightgbm as lgb 4 | from sklearn.preprocessing import OrdinalEncoder 5 | from sklearn.impute import SimpleImputer 6 | from tqdm.auto import tqdm 7 | from .base import PreprocessingTemplate 8 | from .utils import analyze_column 9 | 10 | 11 | class LGBMImputer(PreprocessingTemplate): 12 | ''' 13 | Regression imputer using LightGBM 14 | 15 | ## Arguments 16 | target_cols: default = [] 17 | 18 | Specify additional columns to impute for cases where you want to include features without missing values, 19 | such as when missing values exist only in the test data. 20 | 21 | cat_features: default = [] 22 | 23 | By default, the algorithm will detect feature type (categorical or numerical) based on the data type. 24 | Columns specified in this argument will be considered categorical regardless of its data type. 
25 | 26 | fit_params: default = {'num_boost_round': 100} 27 | 28 | Parameters for `lightgbm.train()`, such as callbacks, can be added. 29 | 30 | fit_method: default = 'fit' 31 | 32 | In case you want to run cross validation, set this to 'cv'. 33 | 34 | verbose: default = False 35 | 36 | Turning on verbose allows you to visualize the fitting process, providing a sense of reassurance. 37 | ''' 38 | def __init__( 39 | self, 40 | target_cols: list[str] = [], 41 | cat_features: list[str, int] = [], 42 | fit_params: dict = {'num_boost_round': 100}, 43 | fit_method: str = 'fit', 44 | verbose: bool = False): 45 | self.target_cols = target_cols 46 | self.cat_features = cat_features 47 | self.fit_params = fit_params 48 | self.fit_method = fit_method 49 | self.verbose = verbose 50 | self.n_features = None 51 | self.feature_names = None 52 | self._feature_scanned = False 53 | self.imputers = {} 54 | self.cat_encoder = OrdinalEncoder() 55 | 56 | def _transform_dataframe(self, X: pd.DataFrame, fit: bool = False): 57 | if fit: 58 | self.n_features = X.shape[1] 59 | if isinstance(X, pd.DataFrame): 60 | self.feature_names = X.columns.tolist() 61 | else: 62 | self.feature_names = [f'f{i}' for i in range(self.n_features)] 63 | X = pd.DataFrame(X, columns=self.feature_names) 64 | if len(self.cat_features) > 0 and isinstance(self.cat_features[0], int): 65 | self.cat_features = [self.feature_names[i] for i in self.cat_features] 66 | else: 67 | if isinstance(X, pd.DataFrame): 68 | pass 69 | else: 70 | assert X.shape[1] == self.n_features 71 | X = pd.DataFrame(X, columns=self.feature_names) 72 | return X.copy() 73 | 74 | def _scan_features(self, X: pd.DataFrame): 75 | self.col_dict = {} 76 | self.target_columns = [] 77 | self.categorical_columns = [] 78 | 79 | for col in X.columns: 80 | self.col_dict[col] = {} 81 | col_arr = X[col] 82 | col_type = analyze_column(col_arr) 83 | if col in self.cat_features: # override 84 | col_type = 'categorical' 85 | self.col_dict[col]['col_type'] = col_type 86 | 87 | if col_type == 'categorical': 88 | num_class = col_arr.dropna().nunique() 89 | self.col_dict[col]['num_class'] = num_class 90 | if num_class == 2: 91 | self.col_dict[col]['params'] = { 92 | 'objective': 'binary' 93 | } 94 | elif num_class > 2: 95 | self.col_dict[col]['params'] = { 96 | 'objective': 'multiclass', 97 | 'num_class': num_class 98 | } 99 | elif num_class == 1: 100 | self.col_dict[col]['params'] = None 101 | self.categorical_columns.append(col) 102 | else: # numerical features 103 | self.col_dict[col]['params'] = { 104 | 'objective': 'regression' 105 | } 106 | 107 | null_mask = col_arr.isnull() 108 | is_target = null_mask.sum() > 0 109 | if col in self.target_cols: # override 110 | is_target = True 111 | self.col_dict[col]['null_mask'] = null_mask 112 | if is_target: 113 | self.target_columns.append(col) 114 | 115 | self._feature_scanned = True 116 | 117 | def _fit_lgb(self, X: pd.DataFrame, col: str): 118 | col_info = self.col_dict[col] 119 | null_mask = col_info['null_mask'] 120 | _categorical_columns = self.categorical_columns.copy() 121 | if col in _categorical_columns: 122 | _categorical_columns.remove(col) 123 | if col_info['params'] is None: # Single class 124 | model = SimpleImputer() 125 | model.fit(X[col]) 126 | else: 127 | params = col_info['params'] 128 | params['verbose'] = -1 129 | x_train = X.loc[~null_mask].drop(col, axis=1) 130 | y_train = X.loc[~null_mask, col].copy() 131 | dtrain = lgb.Dataset(data=x_train, label=y_train) 132 | if self.fit_method == 'fit': 133 | model = lgb.train( 134 | 
params, dtrain, valid_sets=[dtrain], 135 | categorical_feature=_categorical_columns, **self.fit_params) 136 | elif self.fit_method == 'cv': 137 | res = lgb.cv( 138 | params, dtrain, return_cvbooster=True, 139 | stratified=False if col_info['params']['objective'] == 'regression' else True, 140 | categorical_feature=_categorical_columns, 141 | **self.fit_params) 142 | model = res['cvbooster'] 143 | return model 144 | 145 | def fit(self, X: pd.DataFrame, y=None): 146 | X = self._transform_dataframe(X, fit=True) 147 | self._scan_features(X) 148 | X[self.categorical_columns] = self.cat_encoder.fit_transform(X[self.categorical_columns]) 149 | 150 | if self.verbose: 151 | pbar = tqdm(self.target_columns) 152 | iterator = enumerate(pbar) 153 | else: 154 | iterator = enumerate(self.target_columns) 155 | 156 | for _, col in iterator: 157 | model = self._fit_lgb(X, col) 158 | self.imputers[col] = model 159 | if self.verbose: 160 | pbar.set_description(col) 161 | 162 | def transform(self, X: pd.DataFrame): 163 | assert self._feature_scanned 164 | output_X = self._transform_dataframe(X, fit=False) 165 | output_X[self.categorical_columns] = self.cat_encoder.transform(output_X[self.categorical_columns]) 166 | 167 | for col in self.target_columns: 168 | if self.col_dict[col]['params'] is None: 169 | model = self.imputers[col] 170 | output_X[col] = model.transform(output_X[col]) 171 | else: 172 | objective = self.col_dict[col]['params']['objective'] 173 | model = self.imputers[col] 174 | null_mask = output_X[col].isnull() 175 | x_test = output_X.loc[null_mask].drop(col, axis=1) 176 | y_test = model.predict(x_test) 177 | if self.fit_method == 'cv': 178 | y_test = np.mean(y_test, axis=0) 179 | if objective == 'multiclass': 180 | y_test = np.argmax(y_test, axis=1).astype(float) 181 | elif objective == 'binary': 182 | y_test = (y_test > 0.5).astype(float) 183 | output_X.loc[null_mask, col] = y_test 184 | 185 | output_X[self.categorical_columns] = self.cat_encoder.inverse_transform(output_X[self.categorical_columns]) 186 | return output_X 187 | 188 | def fit_transform(self, X: pd.DataFrame, y=None): 189 | self.fit(X) 190 | return self.transform(X) 191 | -------------------------------------------------------------------------------- /kuma_utils/stats/tables.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | pd.set_option("future.no_silent_downcasting", True) 4 | import scipy 5 | from kuma_utils.preprocessing.utils import analyze_column 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | 9 | 10 | def _mean_std(arr): 11 | return arr.mean(), arr.std() 12 | 13 | 14 | def make_demographic_table( 15 | df: pd.DataFrame, 16 | group_col: str, 17 | display_cols: list[str], 18 | categorical_cols: list[str] = [], 19 | numerical_cols: list[str] = [], 20 | categorical_cutoff: int = 5, 21 | handle_missing: str = 'value', 22 | categorical_omission_count: int = 50, 23 | ): 24 | ''' 25 | Demographic table generator 26 | 27 | for numeric variables: 28 | - run KS test 29 | - if the varible follows normal distribution, run T test 30 | - if no, run Mann Whitney U test 31 | 32 | for categorical variables: 33 | - run chi-squared test 34 | ''' 35 | res = [] 36 | group_vals = df[group_col].unique() 37 | assert len(group_vals) == 2 38 | g1, g2 = group_vals 39 | g1_index = np.where(df[group_col] == g1)[0] 40 | g2_index = np.where(df[group_col] == g2)[0] 41 | res.append({ 42 | '_item': 'N', 43 | '_type': 'numerical', 44 | '_ks_stat': None, 45 
| '_stat_test': None, 46 | '_nan_info': False, 47 | g1: len(g1_index), 48 | g2: len(g2_index), 49 | 'p-value': None 50 | }) 51 | 52 | if len(categorical_cols) == 0 and len(numerical_cols) == 0: 53 | numerical_cols = [] 54 | categorical_cols = [] 55 | for col in display_cols: 56 | if col == group_col: 57 | continue 58 | if df[col].nunique() <= categorical_cutoff or analyze_column(df[col]) == 'categorical': 59 | categorical_cols.append(col) 60 | else: 61 | numerical_cols.append(col) 62 | 63 | for col in categorical_cols: 64 | if handle_missing == 'value': 65 | col_arr = df[col].copy().fillna('NaN') 66 | elif handle_missing == 'ignore': 67 | col_arr = df[col].copy().dropna() 68 | else: 69 | col_arr = df[col].copy() 70 | val_cnt = col_arr.value_counts() 71 | if len(val_cnt) > categorical_omission_count: 72 | col_arr = col_arr.apply( 73 | lambda x: f'[{col}_other_categories]' if x in val_cnt.iloc[categorical_omission_count:].index else x) 74 | val_cnt = col_arr.value_counts() 75 | _le = {k: v for v, k in enumerate(val_cnt.index)} 76 | vals_all = col_arr.replace(_le).infer_objects(copy=False).values 77 | vals_g1 = vals_all[g1_index] 78 | vals_g2 = vals_all[g2_index] 79 | if len(vals_all) == 1: 80 | continue 81 | for dec_val, enc_val in _le.items(): 82 | if dec_val == 0: 83 | continue 84 | pos_g1 = (vals_g1 == enc_val).sum() 85 | pos_g2 = (vals_g2 == enc_val).sum() 86 | pval = scipy.stats.chi2_contingency([ 87 | [pos_g1, len(vals_g1) - pos_g1], 88 | [pos_g2, len(vals_g2) - pos_g2] 89 | ])[1] 90 | res.append({ 91 | '_item': f'{col}={dec_val}, n(%)', 92 | '_type': 'categorical', 93 | '_ks_stat': None, 94 | '_stat_test': 'Chi2', 95 | '_nan_info': dec_val == 'NaN', 96 | g1: f'{pos_g1} ({pos_g1/len(vals_g1)*100:.1f}%)', 97 | g2: f'{pos_g2} ({pos_g2/len(vals_g2)*100:.1f}%)', 98 | 'p-value': f'{pval:.3f}' 99 | }) 100 | 101 | for col in numerical_cols: 102 | col_arr = df[col].dropna().copy() 103 | vals_all = df[col].values.reshape(-1) 104 | vals_g1 = vals_all[g1_index] 105 | vals_g2 = vals_all[g2_index] 106 | mean_all, std_all = _mean_std(vals_all) 107 | mean_g1, std_g1 = _mean_std(vals_g1) 108 | mean_g2, std_g2 = _mean_std(vals_g2) 109 | ks_res = scipy.stats.kstest(vals_all, 'norm', args=(mean_all, std_all)) 110 | if ks_res.pvalue < 0.05: 111 | col_is_norm = False 112 | else: 113 | col_is_norm = True 114 | if col_is_norm: 115 | t_res = scipy.stats.ttest_ind(vals_g1, vals_g2, equal_var=False) 116 | res.append({ 117 | '_item': f'{col}, mean(std)', 118 | '_type': 'numerical', 119 | '_ks_stat': ks_res.pvalue, 120 | '_stat_test': 'T', 121 | '_nan_info': False, 122 | g1: f'{mean_g1:.3f} ({std_g1:.3f})', 123 | g2: f'{mean_g2:.3f} ({std_g2:.3f})', 124 | 'p-value': f'{t_res.pvalue:.3f}' 125 | }) 126 | else: 127 | u_res = scipy.stats.mannwhitneyu(vals_g1, vals_g2, alternative='two-sided') 128 | res.append({ 129 | '_item': f'{col}, mean(std)', 130 | '_type': 'numerical', 131 | '_ks_stat': ks_res.pvalue, 132 | '_stat_test': 'U', 133 | '_nan_info': False, 134 | g1: f'{mean_g1:.3f} ({std_g1:.3f})', 135 | g2: f'{mean_g2:.3f} ({std_g2:.3f})', 136 | 'p-value': f'{u_res.pvalue:.3f}' 137 | }) 138 | 139 | return pd.DataFrame(res) 140 | 141 | 142 | def make_summary_table( 143 | df: pd.DataFrame, 144 | display_cols: list[str], 145 | categorical_cols: list[str] = [], 146 | numerical_cols: list[str] = [], 147 | categorical_cutoff: int = 5, 148 | handle_missing: str = 'value', 149 | categorical_omission_count: int = 50 150 | ): 151 | ''' 152 | Summary table generator 153 | ''' 154 | res = [] 155 | res.append({ 156 | '_item': 
'N', 157 | '_type': 'numerical', 158 | '_nan_info': False, 159 | '_stat': len(df) 160 | }) 161 | 162 | if len(categorical_cols) == 0 and len(numerical_cols) == 0: 163 | numerical_cols = [] 164 | categorical_cols = [] 165 | for col in display_cols: 166 | if df[col].nunique() <= categorical_cutoff or analyze_column(df[col]) == 'categorical': 167 | categorical_cols.append(col) 168 | else: 169 | numerical_cols.append(col) 170 | 171 | for col in categorical_cols: 172 | if handle_missing == 'value': 173 | col_arr = df[col].copy().fillna('NaN') 174 | elif handle_missing == 'ignore': 175 | col_arr = df[col].copy().dropna() 176 | else: 177 | col_arr = df[col].copy() 178 | val_cnt = col_arr.value_counts() 179 | if len(val_cnt) > categorical_omission_count: 180 | col_arr = col_arr.apply( 181 | lambda x: f'[{col}_other_categories]' if x in val_cnt.iloc[categorical_omission_count:].index else x) 182 | val_cnt = col_arr.value_counts() 183 | _le = {k: v for v, k in enumerate(val_cnt.index)} 184 | vals_all = col_arr.replace(_le).infer_objects(copy=False).values 185 | if len(vals_all) == 1: 186 | continue 187 | for dec_val, enc_val in _le.items(): 188 | if dec_val == 0: 189 | continue 190 | pos_all = (vals_all == enc_val).sum() 191 | res.append({ 192 | '_item': f'{col}={dec_val}, n(%)', 193 | '_type': 'categorical', 194 | '_nan_info': dec_val == 'NaN', 195 | '_stat': f'{pos_all} ({pos_all/len(vals_all)*100:.1f}%)', 196 | }) 197 | 198 | for col in numerical_cols: 199 | col_arr = df[col].dropna().copy() 200 | vals_all = df[col].values.reshape(-1) 201 | mean_all, std_all = _mean_std(vals_all) 202 | res.append({ 203 | '_item': f'{col}, mean(std)', 204 | '_type': 'numerical', 205 | '_nan_info': False, 206 | '_stat': f'{mean_all:.3f} ({std_all:.3f})', 207 | }) 208 | 209 | return pd.DataFrame(res) 210 | 211 | 212 | def love_plot( 213 | data: pd.DataFrame, 214 | match_col: str, 215 | treatment_col: str, 216 | covariates: list[str], 217 | fig_params: dict = {}, 218 | return_df: bool = False, 219 | title: str = "Covariate Balance", 220 | ): 221 | data_matched = data.query(f'{match_col} == 1') 222 | means_treated_before = data.loc[data[treatment_col] == 1, covariates].mean() 223 | means_control_before = data.loc[data[treatment_col] == 0, covariates].mean() 224 | means_treated_after = data_matched.loc[data_matched[treatment_col] == 1, covariates].mean() 225 | means_control_after = data_matched.loc[data_matched[treatment_col] == 0, covariates].mean() 226 | std_before = data[covariates].std() 227 | std_after = data_matched[covariates].std() 228 | mean_diffs = pd.concat([ 229 | pd.DataFrame( 230 | (means_treated_before - means_control_before).abs()/std_before).rename( 231 | columns={0: 'value'}).reset_index().assign(name='before matching'), 232 | pd.DataFrame( 233 | (means_treated_after - means_control_after).abs()/std_after).rename( 234 | columns={0: 'value'}).reset_index().assign(name='after matching'), 235 | ]) 236 | if return_df: 237 | return mean_diffs 238 | else: 239 | fig, ax = plt.subplots(**fig_params) 240 | sns.barplot(data=mean_diffs, x='value', y='index', hue='name', ax=ax) 241 | ax.axvline(0.1, color="red", linestyle="--") 242 | ax.set_title(title) 243 | ax.set_xlabel("Absolute Standardized Mean Differences") 244 | ax.set_ylabel("Covariates") 245 | plt.tight_layout() 246 | return fig 247 | -------------------------------------------------------------------------------- /kuma_utils/training/validator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 
| import pandas as pd 3 | from sklearn.model_selection import KFold 4 | from sklearn.calibration import calibration_curve 5 | from copy import deepcopy 6 | 7 | import matplotlib.pyplot as plt 8 | import matplotlib.cm as colormap 9 | import seaborn as sns 10 | try: 11 | import japanize_matplotlib 12 | except ModuleNotFoundError: 13 | pass 14 | 15 | from .trainer import Trainer, MODEL_ZOO 16 | from .logger import LGBMLogger 17 | 18 | from pathlib import Path 19 | import pickle 20 | 21 | 22 | class CrossValidator: 23 | ''' 24 | Cross validation wrapper for sklearn API models 25 | 26 | features: 27 | - Automated parameter tuning using Optuna 28 | - Most features of Trainer are included 29 | ''' 30 | def __init__(self, model=None, path=None, serial='cv0'): 31 | self.serial = serial 32 | self.snapshot_items = [ 33 | 'serial', 'models', 'is_trained', 34 | 'fold_indices', 'scores', 'best_score', 'outoffold' 35 | ] 36 | if model is not None: 37 | self.model = model 38 | self.models = [] 39 | self.fold_indices = None 40 | self.scores = [] 41 | self.iterations = [] 42 | self.outoffold = None 43 | self.prediction = None 44 | self.best_score = None 45 | self.is_trained = False 46 | elif path is not None: 47 | self.load(path) 48 | else: 49 | raise ValueError('either model or path must be given.') 50 | 51 | def _data_check(self, data): 52 | assert isinstance(data, (list, tuple)) 53 | assert len(data) >= 2 54 | for i, d in enumerate(data): 55 | assert isinstance(d, (pd.Series, pd.DataFrame, np.ndarray)) 56 | if i == 0: 57 | dshape = len(d) 58 | assert dshape == len(d) 59 | dshape = len(d) 60 | 61 | def _split_data(self, data, idx): 62 | if isinstance(data, (pd.DataFrame, pd.Series)): 63 | return data.iloc[idx] 64 | else: 65 | return data[idx] 66 | 67 | def train(self, 68 | # Dataset 69 | data, cat_features=None, groups=None, folds=KFold(n_splits=5), 70 | # Model 71 | params={}, fit_params={}, 72 | # Optuna 73 | tune_model=False, optuna_params=None, maximize=True, 74 | eval_metric=None, n_trials=None, timeout=None, 75 | # Misc 76 | logger=None, n_jobs=-1): 77 | 78 | self._data_check(data) 79 | self.models = [] 80 | 81 | if isinstance(logger, (str, Path)): 82 | logger = LGBMLogger(logger, stdout=True, file=True) 83 | elif logger is None: 84 | logger = LGBMLogger(logger, stdout=True, file=False) 85 | assert isinstance(logger, LGBMLogger) 86 | 87 | _fit_params = fit_params.copy() 88 | if 'verbose_eval' in fit_params.keys(): 89 | _fit_params.update({'verbose_eval': False}) 90 | 91 | if callable(getattr(folds, 'split', None)): 92 | # splitter 93 | fold_iter = enumerate(folds.split(X=data[0], y=data[1], groups=groups)) 94 | self.nfold = folds.get_n_splits() 95 | 96 | else: 97 | # index 98 | fold_iter = enumerate(folds) 99 | self.nfold = len(folds) 100 | 101 | self.fold_indices = [] 102 | self.scores = [] 103 | self.iterations = [] 104 | if not isinstance(eval_metric, (list, tuple)): 105 | eval_metric = [eval_metric] 106 | 107 | for fold_i, (train_idx, valid_idx) in fold_iter: 108 | 109 | logger(f'[{self.serial}] Starting fold {fold_i}') 110 | 111 | self.fold_indices.append([train_idx, valid_idx]) 112 | 113 | train_data = [self._split_data(d, train_idx) for d in data] 114 | valid_data = [self._split_data(d, valid_idx) for d in data] 115 | 116 | trn = Trainer(self.model, serial=f'{self.serial}_fold{fold_i}') 117 | trn.train( 118 | train_data=train_data, valid_data=valid_data, cat_features=cat_features, 119 | params=params, fit_params=deepcopy(_fit_params), 120 | tune_model=tune_model, optuna_params=optuna_params, 
121 | maximize=maximize, eval_metric=eval_metric[0], 122 | n_trials=n_trials, timeout=timeout, 123 | logger=logger, n_jobs=n_jobs 124 | ) 125 | 126 | best_score = trn.get_best_score() 127 | best_iter = trn.get_best_iteration() 128 | 129 | all_metrics = [best_score] + [ 130 | m(trn.get_model(), valid_data) for m in eval_metric if m is not None] 131 | log_str = f'[{self.serial}] Fold {fold_i}: ' 132 | for i in range(len(all_metrics)): 133 | if i == 0: 134 | name_metric = 'eval' 135 | else: 136 | name_metric = f'monitor{i-1}' 137 | log_str += f'{name_metric}={all_metrics[i]:.6f} ' 138 | log_str += f'(iter={best_iter})' 139 | logger(log_str) 140 | 141 | if fold_i == 0: 142 | _outoffold = trn.smart_predict(valid_data[0]) 143 | self.outoffold = np.empty((data[0].shape[0], *_outoffold.shape[1:]), dtype=np.float16) 144 | self.outoffold[valid_idx] = _outoffold 145 | else: 146 | self.outoffold[valid_idx] = trn.smart_predict(valid_data[0]) 147 | 148 | self.scores.append(best_score) 149 | self.iterations.append(best_iter) 150 | self.models.append(trn) 151 | 152 | mean_score = np.mean(self.scores) 153 | se_score = np.std(self.scores) 154 | self.best_score = [mean_score, se_score] 155 | logger(f'[{self.serial}] Overall metric: {mean_score:.6f} + {se_score:.6f}') 156 | 157 | self.is_trained = True 158 | 159 | fit = train 160 | 161 | def predict(self, X, **kwargs): 162 | assert self.is_trained 163 | self.prediction = [] 164 | for trn in self.models: 165 | self.prediction.append(trn.predict(X, **kwargs)) 166 | return self.prediction 167 | 168 | def predict_proba(self, X, **kwargs): 169 | assert self.is_trained 170 | self.prediction = [] 171 | for trn in self.models: 172 | self.prediction.append(trn.predict_proba(X, **kwargs)) 173 | return self.prediction 174 | 175 | def smart_predict(self, X, **kwargs): 176 | assert self.is_trained 177 | self.prediction = [] 178 | for trn in self.models: 179 | self.prediction.append(trn.smart_predict(X, **kwargs)) 180 | return self.prediction 181 | 182 | def get_model(self): 183 | return self.models 184 | 185 | def get_feature_importance(self, importance_type='auto', normalize=True, fit_params=None, 186 | as_pandas='auto'): 187 | imps = [] 188 | for trn in self.models: 189 | imps.append(trn.get_feature_importance( 190 | importance_type, normalize, fit_params, as_pandas=False)) 191 | if as_pandas in ['auto', True]: 192 | return pd.DataFrame(imps) 193 | else: 194 | return imps 195 | 196 | def plot_feature_importance(self, importance_type='auto', normalize=True, fit_params=None, 197 | sort=True, size=5, save_to=None): 198 | imp_df = self.get_feature_importance( 199 | importance_type, normalize, fit_params, as_pandas=True) 200 | plt.figure(figsize=(size, imp_df.shape[1]/3)) 201 | order = imp_df.mean().sort_values(ascending=False).index.tolist() \ 202 | if sort else None 203 | sns.barplot(data=imp_df, orient='h', errorbar='sd', 204 | order=order, palette="coolwarm") 205 | if save_to is not None: 206 | plt.savefig(save_to) 207 | plt.show() 208 | 209 | def plot_calibration_curve(self, data, predict_params={}, size=4, save_to=None): 210 | X, y = data[0], data[1] 211 | approx = self.smart_predict(X, **predict_params) 212 | if isinstance(approx, list): 213 | approx = np.stack(approx).mean(0) 214 | fig = plt.figure(figsize=(size, size*1.5), tight_layout=True) 215 | gs = fig.add_gridspec(3, 1) 216 | ax1 = fig.add_subplot(gs[0:2, 0]) 217 | ax2 = fig.add_subplot(gs[2, 0]) 218 | fraction_of_positives, mean_predicted_value = \ 219 | calibration_curve(y, approx, n_bins=10) 220 | 
ax1.plot([0, 1], [0, 1], color='gray') 221 | ax1.plot(mean_predicted_value, fraction_of_positives, "s-") 222 | ax1.set_xlabel('Mean of prediction values') 223 | ax1.set_ylabel('Fraction of positives') 224 | ax1.grid() 225 | ax1.set_xlim([0.0, 1.0]) 226 | sns.histplot( 227 | approx, bins=10, element="step", stat="density", common_norm=False, ax=ax2) 228 | ax2.set_xlim([0.0, 1.0]) 229 | ax2.set_ylabel('Density') 230 | if save_to is not None: 231 | plt.savefig(save_to) 232 | plt.show() 233 | 234 | def save(self, path): 235 | with open(path, 'wb') as f: 236 | snapshot = tuple([getattr(self, item) 237 | for item in self.snapshot_items]) 238 | pickle.dump(snapshot, f) 239 | 240 | def load(self, path): 241 | with open(path, 'rb') as f: 242 | snapshot = pickle.load(f) 243 | for i, item in enumerate(self.snapshot_items): 244 | setattr(self, item, snapshot[i]) 245 | 246 | def __repr__(self): 247 | desc = f'CrossValidator: {self.serial}\n' 248 | items = ['models', 'is_trained', 'best_score'] 249 | for i in items: 250 | desc += f'{i}: {getattr(self, i)}\n' 251 | return desc 252 | 253 | def info(self): 254 | print(self.__repr__()) 255 | -------------------------------------------------------------------------------- /examples/Train_CNN_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from dataclasses import dataclass, asdict\n", 10 | "from copy import deepcopy\n", 11 | "import numpy as np\n", 12 | "from sklearn.model_selection import StratifiedKFold\n", 13 | "import torch\n", 14 | "import torchvision\n", 15 | "import torchvision.transforms as transforms\n", 16 | "import torch.nn as nn\n", 17 | "import torch.nn.functional as F\n", 18 | "import torch.utils.data as D\n", 19 | "import torch.optim as optim\n", 20 | "import timm\n", 21 | "from pathlib import Path\n", 22 | "\n", 23 | "from kuma_utils.torch import TorchTrainer, TorchLogger\n", 24 | "from kuma_utils.torch.callbacks import EarlyStopping, SaveSnapshot\n", 25 | "from kuma_utils.torch.hooks import SimpleHook\n", 26 | "from kuma_utils.metrics import Accuracy\n", 27 | "from kuma_utils.torch.optimizer import SAM\n", 28 | "\n", 29 | "\n", 30 | "@dataclass\n", 31 | "class Config:\n", 32 | "    num_workers: int = 32\n", 33 | "    batch_size: int = 64\n", 34 | "    num_epochs: int = 100\n", 35 | "    early_stopping_rounds: int = 5\n" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "def get_dataset():\n", 45 | "    transform = transforms.Compose([\n", 46 | "        transforms.ToTensor(),\n", 47 | "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n", 48 | "    ])\n", 49 | "    train = torchvision.datasets.CIFAR10(\n", 50 | "        root='input', train=True, download=True, transform=transform)\n", 51 | "    test = torchvision.datasets.CIFAR10(\n", 52 | "        root='input', train=False, download=True, transform=transform)\n", 53 | "    return train, test\n", 54 | "\n", 55 | "\n", 56 | "def split_dataset(dataset, index):\n", 57 | "    new_dataset = deepcopy(dataset)\n", 58 | "    new_dataset.data = new_dataset.data[index]\n", 59 | "    new_dataset.targets = np.array(new_dataset.targets)[index]\n", 60 | "    return new_dataset\n", 61 | "\n", 62 | "\n", 63 | "def get_model(num_classes):\n", 64 | "    model = timm.create_model('tf_efficientnet_b0.ns_jft_in1k', pretrained=True, num_classes=num_classes)\n", 65 | "    return model" 66 |
] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "Files already downloaded and verified\n", 78 | "Files already downloaded and verified\n", 79 | "classes ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n", 80 | "fold0 starting\n", 81 | "train: 40000 / valid: 10000\n" 82 | ] 83 | }, 84 | { 85 | "name": "stderr", 86 | "output_type": "stream", 87 | "text": [ 88 | "2025-01-12 22:49:11,553 - INFO - Mixed precision training on torch amp.\n", 89 | "2025-01-12 22:49:11,615 - INFO - Model on device: cuda\n", 90 | "2025-01-12 22:49:15,127 - INFO - Epoch 1/10 | train_loss=2.860259 | train_metric=0.372400 | valid_loss=1.918867 | valid_metric=0.480800 | learning_rate=0.002000 | | \n", 91 | "2025-01-12 22:49:18,029 - INFO - Epoch 2/10 | train_loss=1.212211 | train_metric=0.601775 | valid_loss=1.234655 | valid_metric=0.621500 | learning_rate=0.002000 | | \n", 92 | "2025-01-12 22:49:20,924 - INFO - Epoch 3/10 | train_loss=0.870441 | train_metric=0.695300 | valid_loss=0.945847 | valid_metric=0.683500 | learning_rate=0.002000 | | \n", 93 | "2025-01-12 22:49:23,770 - INFO - Epoch 4/10 | train_loss=0.671615 | train_metric=0.763250 | valid_loss=0.826375 | valid_metric=0.716000 | learning_rate=0.002000 | | \n", 94 | "2025-01-12 22:49:26,649 - INFO - Epoch 5/10 | train_loss=0.520029 | train_metric=0.816975 | valid_loss=0.822734 | valid_metric=0.723000 | learning_rate=0.001000 | | \n", 95 | "2025-01-12 22:49:29,471 - INFO - Epoch 6/10 | train_loss=0.416496 | train_metric=0.855300 | valid_loss=0.851483 | valid_metric=0.724300 | learning_rate=0.001000 | | \n", 96 | "2025-01-12 22:49:32,359 - INFO - Epoch 7/10 | train_loss=0.311541 | train_metric=0.896300 | valid_loss=0.930869 | valid_metric=0.723900 | learning_rate=0.001000 | best=0.724300(*1) | \n", 97 | "2025-01-12 22:49:35,209 - INFO - Epoch 8/10 | train_loss=0.211898 | train_metric=0.934550 | valid_loss=1.014133 | valid_metric=0.727300 | learning_rate=0.000500 | | \n", 98 | "2025-01-12 22:49:38,083 - INFO - Epoch 9/10 | train_loss=0.157332 | train_metric=0.954250 | valid_loss=1.094280 | valid_metric=0.722800 | learning_rate=0.000500 | best=0.727300(*1) | \n", 99 | "2025-01-12 22:49:40,968 - INFO - Epoch 10/10 | train_loss=0.108718 | train_metric=0.972200 | valid_loss=1.190615 | valid_metric=0.723800 | learning_rate=0.000500 | best=0.727300(*2) | \n" 100 | ] 101 | }, 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "Folf0 score: 0.723400\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "cfg = Config(\n", 112 | " num_workers=32, \n", 113 | " batch_size=2048,\n", 114 | " num_epochs=10,\n", 115 | " early_stopping_rounds=5,\n", 116 | ")\n", 117 | "export_dir = Path('results/demo')\n", 118 | "export_dir.mkdir(parents=True, exist_ok=True)\n", 119 | "\n", 120 | "train, test = get_dataset()\n", 121 | "print('classes', train.classes)\n", 122 | "\n", 123 | "predictions = []\n", 124 | "splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)\n", 125 | "for fold, (train_idx, valid_idx) in enumerate(\n", 126 | " splitter.split(train.targets, train.targets)):\n", 127 | "\n", 128 | " print(f'fold{fold} starting')\n", 129 | "\n", 130 | " valid_fold = split_dataset(train, valid_idx)\n", 131 | " train_fold = split_dataset(train, train_idx)\n", 132 | "\n", 133 | " print(f'train: {len(train_fold)} / valid: {len(valid_fold)}')\n", 134 | "\n", 135 | " 
loader_train = D.DataLoader(\n", 136 | " train_fold, batch_size=cfg.batch_size, num_workers=cfg.num_workers, \n", 137 | " shuffle=True, pin_memory=True)\n", 138 | " loader_valid = D.DataLoader(\n", 139 | " valid_fold, batch_size=cfg.batch_size, num_workers=cfg.num_workers, \n", 140 | " shuffle=False, pin_memory=True)\n", 141 | " loader_test = D.DataLoader(\n", 142 | " test, batch_size=cfg.batch_size, num_workers=cfg.num_workers, \n", 143 | " shuffle=False, pin_memory=True)\n", 144 | "\n", 145 | " model = get_model(num_classes=len(train.classes))\n", 146 | " optimizer = optim.Adam(model.parameters(), lr=2e-3)\n", 147 | " # optimizer = SAM(model.parameters(), optim.Adam, lr=2e-3)\n", 148 | " scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n", 149 | " optimizer, mode='max', factor=0.5, patience=2)\n", 150 | " logger = TorchLogger(\n", 151 | " path=export_dir/f'fold{fold}.log', \n", 152 | " log_items='epoch train_loss train_metric valid_loss valid_metric learning_rate early_stop', \n", 153 | " file=True,\n", 154 | " use_wandb=False, wandb_params={\n", 155 | " 'project': 'kuma_utils_demo', \n", 156 | " 'group': 'demo_cross_validation',\n", 157 | " 'name': f'fold{fold}',\n", 158 | " 'config': asdict(cfg),\n", 159 | " },\n", 160 | " use_tensorboard=True, # In addition to epoch summaries, tensorboard can record batch summaries.\n", 161 | " tensorboard_dir=export_dir/'tensorboard',\n", 162 | " )\n", 163 | " \n", 164 | " trn = TorchTrainer(model, serial=f'fold{fold}')\n", 165 | " trn.train(\n", 166 | " loader=loader_train,\n", 167 | " loader_valid=loader_valid,\n", 168 | " criterion=nn.CrossEntropyLoss(),\n", 169 | " eval_metric=Accuracy().torch, \n", 170 | " monitor_metrics=[\n", 171 | " Accuracy().torch\n", 172 | " ],\n", 173 | " optimizer=optimizer,\n", 174 | " scheduler=scheduler,\n", 175 | " scheduler_target='valid_loss', # ReduceLROnPlateau reads metric each epoch\n", 176 | " num_epochs=cfg.num_epochs,\n", 177 | " hook=SimpleHook(\n", 178 | " evaluate_in_batch=False, clip_grad=None, sam_optimizer=False),\n", 179 | " callbacks=[\n", 180 | " EarlyStopping(\n", 181 | " patience=cfg.early_stopping_rounds, \n", 182 | " target='valid_metric', \n", 183 | " maximize=True),\n", 184 | " SaveSnapshot() # Default snapshot path: {export_dir}/{serial}.pt\n", 185 | " ],\n", 186 | " logger=logger, \n", 187 | " export_dir=export_dir,\n", 188 | " parallel=None, # Supported parallel methods: None, 'dp', 'ddp'\n", 189 | " fp16=True, # Pytorch mixed precision\n", 190 | " deterministic=True, \n", 191 | " random_state=0, \n", 192 | " progress_bar=False, # Progress bar shows batches done\n", 193 | " )\n", 194 | "\n", 195 | " oof = trn.predict(loader_valid)\n", 196 | " predictions.append(trn.predict(loader_test))\n", 197 | "\n", 198 | " score = Accuracy()(valid_fold.targets, oof)\n", 199 | " print(f'Folf{fold} score: {score:.6f}')\n", 200 | " break" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "![cifar_wandb](images/cifar_wandb.png)\n", 208 | "![cifar_tensorboard](images/cifar_tensorboard.png)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "## `kuma_utils.torch.callbacks.EarlyStopping`\n", 216 | "```python\n", 217 | "EarlyStopping(\n", 218 | " patience: int = 5, \n", 219 | " target: str = 'valid_metric', \n", 220 | " maximize: bool = False, \n", 221 | " skip_epoch: int = 0 \n", 222 | ")\n", 223 | "```\n", 224 | "| argument | description |\n", 225 | 
"|------------|-----------------------------------------------------------------------------------------------------|\n", 226 | "| patience | Epochs to wait before early stop |\n", 227 | "| target | Variable name to watch (choose from `['train_loss', 'train_metric', 'valid_loss', 'valid_metric']`) |\n", 228 | "| maximize | Whether to maximize the target |\n", 229 | "| skip_epoch | Epochs to skip before early stop counter starts |\n", 230 | "\n", 231 | "\n", 232 | "## `kuma_utils.torch.TorchLogger`\n", 233 | "```python\n", 234 | "TorchLogger(\n", 235 | " path: str | Path,\n", 236 | " log_items: list[str] | str = [\n", 237 | " 'epoch',\n", 238 | " 'train_loss', 'valid_loss',\n", 239 | " 'train_metric', 'valid_metric',\n", 240 | " 'train_monitor', 'valid_monitor',\n", 241 | " 'learning_rate', 'early_stop'],\n", 242 | " verbose_eval: int = 1,\n", 243 | " stdout: bool = True,\n", 244 | " file: bool = False,\n", 245 | " logger_name: str = 'TorchLogger',\n", 246 | " default_level: str = 'INFO',\n", 247 | " use_wandb: bool = False,\n", 248 | " wandb_params: dict = {'project': 'test', 'config': {}},\n", 249 | " use_tensorboard: bool = False,\n", 250 | " tensorboard_dir: str | Path = None\n", 251 | ")\n", 252 | "```\n", 253 | "| argument | description |\n", 254 | "|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", 255 | "| path | Path to log. |\n", 256 | "| log_items | Items to be shown in log. Must be a combination of the following items: `['epoch', 'train_loss', 'valid_loss', 'train_metric' , 'valid_metric', 'train_monitor', 'valid_monitor', 'learning_rate', 'early_stop', 'gpu_memory']`. List or string separated by space (e.g. `'epoch valid_loss learning_rate'`).| \n", 257 | "| verbose_eval | Frequency of log (unit: epoch). |\n", 258 | "| stdout | Whether to print log. |\n", 259 | "| file | Whether to export log file to the path. (False by default) |\n", 260 | "| use_wandb | Whether to use wandb. 
|" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "## Hook\n", 268 | "Hook is used to specify detailed training and evaluation process.\n", 269 | "Usually it is not necessary to modify training hook, but in some cases such like: \n", 270 | "\n", 271 | "- training a Graph Neural Network which takes multiple arguments in `.forward`\n", 272 | "- training with a special metric which requires extra variables (other than predictions and targets)\n", 273 | "- calculate metrics on whole dataset (not in each mini-batch)\n", 274 | "\n", 275 | "A Hook class contains the following functions:\n", 276 | "```python\n", 277 | "class TrainHook(HookTemplate):\n", 278 | "\n", 279 | " def __init__(self, evaluate_in_batch=False):\n", 280 | " super().__init__()\n", 281 | " self.evaluate_in_batch = evaluate_in_batch\n", 282 | "\n", 283 | " def _evaluate(self, trainer, approx, target):\n", 284 | " if trainer.eval_metric is None:\n", 285 | " metric_score = None\n", 286 | " else:\n", 287 | " metric_score = trainer.eval_metric(approx, target)\n", 288 | " if isinstance(metric_score, torch.Tensor):\n", 289 | " metric_score = metric_score.item()\n", 290 | " monitor_score = []\n", 291 | " for monitor_metric in trainer.monitor_metrics:\n", 292 | " score = monitor_metric(approx, target)\n", 293 | " if isinstance(score, torch.Tensor):\n", 294 | " score = score.item()\n", 295 | " monitor_score.append(score)\n", 296 | " return metric_score, monitor_score\n", 297 | "\n", 298 | " def forward_train(self, trainer, inputs):\n", 299 | " target = inputs[-1]\n", 300 | " approx = trainer.model(*inputs[:-1])\n", 301 | " loss = trainer.criterion(approx, target)\n", 302 | " return loss, approx.detach()\n", 303 | "\n", 304 | " forward_valid = forward_train\n", 305 | "\n", 306 | " def forward_test(self, trainer, inputs):\n", 307 | " approx = trainer.model(*inputs[:-1])\n", 308 | " return approx\n", 309 | "\n", 310 | " def backprop(self, trainer, loss, inputs=None):\n", 311 | " trainer.scaler.scale(loss).backward()\n", 312 | " dispatch_clip_grad(trainer.model.parameters(), self.max_grad_norm, mode=self.clip_grad)\n", 313 | " trainer.scaler.step(trainer.optimizer)\n", 314 | " trainer.scaler.update()\n", 315 | " trainer.optimizer.zero_grad()\n", 316 | "\n", 317 | " def evaluate_batch(self, trainer, inputs, approx):\n", 318 | " target = inputs[-1]\n", 319 | " storage = trainer.epoch_storage\n", 320 | " if self.evaluate_in_batch:\n", 321 | " # Add scores to storage\n", 322 | " metric_score, monitor_score = self._evaluate(trainer, approx, target)\n", 323 | " storage['batch_metric'].append(metric_score)\n", 324 | " storage['batch_monitor'].append(monitor_score)\n", 325 | " else:\n", 326 | " # Add prediction and target to storage\n", 327 | " storage['approx'].append(approx)\n", 328 | " storage['target'].append(target)\n", 329 | "\n", 330 | " def evaluate_epoch(self, trainer):\n", 331 | " storage = trainer.epoch_storage\n", 332 | " if self.evaluate_in_batch:\n", 333 | " # Calculate mean metrics from all batches\n", 334 | " metric_total = storage['batch_metric'].mean(0)\n", 335 | " monitor_total = storage['batch_monitor'].mean(0).tolist()\n", 336 | "\n", 337 | " else: \n", 338 | " # Calculate scores\n", 339 | " metric_total, monitor_total = self._evaluate(\n", 340 | " trainer, storage['approx'], storage['target'])\n", 341 | " return metric_total, monitor_total\n", 342 | "```\n", 343 | "\n", 344 | "`.forward_train()` is called in each mini-batch in training and validation loop. 
\n", 345 | "This method returns loss and prediction tensors.\n", 346 | "\n", 347 | "`.forward_test()` is called in each mini-batch in the inference loop. \n", 348 | "This method returns the prediction tensor.\n", 349 | "\n", 350 | "`.evaluate_batch()` is called in each mini-batch after back-propagation and optimizer.step(). \n", 351 | "This method returns nothing.\n", 352 | "\n", 353 | "`.evaluate_epoch()` is called at the end of each training and validation loop. \n", 354 | "This method returns eval_metric (scalar) and monitor metrics (list).\n", 355 | "\n", 356 | "Note that `trainer.epoch_storage` is a dictionary object you can use. \n", 357 | "In `SimpleHook`, predictions and targets are added to storage in each mini-batch, \n", 358 | "and at the end of the loop, metrics are calculated on the whole dataset \n", 359 | "(tensors are concatenated batch-wise automatically)." 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": {}, 365 | "source": [] 366 | } 367 | ], 368 | "metadata": { 369 | "kernelspec": { 370 | "display_name": ".venv", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.11.11" 385 | } 386 | }, 387 | "nbformat": 4, 388 | "nbformat_minor": 2 389 | } 390 | -------------------------------------------------------------------------------- /kuma_utils/torch/trainer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | from collections import defaultdict 4 | import time 5 | import pickle 6 | import subprocess 7 | import inspect 8 | import os 9 | import uuid 10 | from pprint import pformat 11 | import __main__ 12 | import pandas as pd 13 | import torch 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DataParallel, DistributedDataParallel 16 | from torch.utils.data.distributed import DistributedSampler 17 | from torch.nn import SyncBatchNorm 18 | from torch.utils.data.sampler import Sampler 19 | 20 | from .utils import get_device, seed_everything, get_system_usage 21 | from .callbacks import ( 22 | TorchLogger, DummyLogger, SaveSnapshot 23 | ) 24 | from .hooks import TrainHook 25 | from . import distributed as comm 26 | from .sampler import DistributedProxySampler 27 | from .extras import DummyAutoCast, DummyGradScaler 28 | 29 | try: 30 | from torch.cuda import amp 31 | AMP = True 32 | except ModuleNotFoundError: 33 | AMP = False 34 | 35 | 36 | class TorchTrainer: 37 | ''' 38 | Simple Trainer for PyTorch models 39 | 40 | This is something similar to PyTorch Lightning, but it works with vanilla PyTorch modules.
41 | ''' 42 | 43 | def __init__(self, 44 | model, device=None, serial='exp00'): 45 | 46 | self.serial = serial 47 | self.device, self.device_ids = get_device(device) 48 | self.world_size = len(self.device_ids) 49 | self.model = model 50 | self.rank = 0 51 | self._register_ready = False 52 | self._model_ready = False 53 | self.logger = None 54 | 55 | # Implicit attributes 56 | # DDP 57 | self.ddp_sync_batch_norm = SyncBatchNorm.convert_sync_batchnorm 58 | self.ddp_average_loss = True 59 | self.ddp_params = dict( 60 | broadcast_buffers=True, 61 | static_graph=True, 62 | # find_unused_parameters=True 63 | ) 64 | self.ddp_workers = -1 65 | # MISC 66 | self.loader_to_callback = False 67 | self.debug = False 68 | self.display_ett_time = 30 69 | self.fix_nan = False 70 | 71 | def _register_callbacks(self, callbacks): 72 | self.before_epoch = [func.before_epoch for func in callbacks] 73 | self.after_epoch = [func.after_epoch for func in callbacks] 74 | self._save_snapshot = [func.save_snapshot for func in callbacks] 75 | self._load_snapshot = [func.load_snapshot for func in callbacks] 76 | 77 | def _register_hook(self, hook): 78 | self.forward_train = hook.forward_train 79 | self.forward_valid = hook.forward_valid 80 | self.forward_test = hook.forward_test 81 | self.backprop = hook.backprop 82 | self.evaluate_batch = hook.evaluate_batch 83 | self.evaluate_epoch = hook.evaluate_epoch 84 | 85 | def _configure_model(self): 86 | ''' Mixed precision ''' 87 | if self.fp16: 88 | if AMP: 89 | if self.rank == 0: 90 | self.logger('Mixed precision training on torch amp.') 91 | else: 92 | self.fp16 = False 93 | if self.rank == 0: 94 | self.logger.warning('No mixed precision training backend found.') 95 | 96 | ''' Parallel training ''' 97 | if self.parallel == 'dp': # DP on cuda 98 | self.model = DataParallel( 99 | self.model, device_ids=self.device_ids).to(self.device) 100 | if hasattr(self, 'criterion') and self.criterion is not None: 101 | self.criterion = self.criterion.to(self.device) 102 | self.logger(f'DataParallel on device: {self.device_ids}') 103 | 104 | elif self.parallel == 'ddp': # DDP on cuda 105 | self.model = self.ddp_sync_batch_norm(self.model) 106 | self.model = DistributedDataParallel( 107 | self.model.to(self.rank), device_ids=[self.rank], 108 | **self.ddp_params 109 | ) 110 | if hasattr(self, 'criterion') and self.criterion is not None: 111 | self.criterion = self.criterion.to(self.rank) 112 | if self.rank == 0: 113 | self.logger( 114 | f'DistributedDataParallel on device: {self.device_ids}') 115 | 116 | elif self.parallel is not None: 117 | raise ValueError(f'Unknown type of parallel {self.parallel}') 118 | 119 | else: # Single device 120 | self.model.to(self.device) 121 | if hasattr(self, 'criterion') and self.criterion is not None: 122 | self.criterion = self.criterion.to(self.device) 123 | self.logger(f'Model on device: {self.device}') 124 | 125 | self._model_ready = True 126 | 127 | def _configure_loader_ddp(self, loader, shuffle=True): 128 | if loader is None: 129 | return None 130 | skip_keys = ['sampler', 'batch_sampler', 'dataset_kind'] 131 | dl_args = { 132 | k: v for k, v in loader.__dict__.items() 133 | if not k.startswith('_') and k not in skip_keys 134 | } 135 | if isinstance(loader.sampler, Sampler): 136 | sampler = DistributedProxySampler( 137 | loader.sampler, num_replicas=self.world_size, rank=self.rank) 138 | else: 139 | sampler = DistributedSampler( 140 | loader.dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) 141 | dl_args['sampler'] = sampler 
142 | if self.ddp_workers == -1: 143 | dl_args['num_workers'] = int( 144 | dl_args['num_workers'] / self.world_size) 145 | else: 146 | dl_args['num_workers'] = self.ddp_workers 147 | if dl_args['batch_size'] % self.world_size != 0: 148 | raise ValueError( 149 | f'batch size must be a multiple of world size({self.world_size}).') 150 | dl_args['batch_size'] = int(dl_args['batch_size'] / self.world_size) 151 | return type(loader)(**dl_args) 152 | 153 | def _find_and_fix_nan(self, inputs, loss, approx, prefix=''): 154 | if torch.isnan(loss).any(): 155 | if self.rank == 0: 156 | self.logger.warning(f'{prefix} {torch.isnan(loss).sum()} NaN detected in loss.') 157 | loss = torch.nan_to_num(loss) 158 | for input_i, input_t in enumerate(inputs): 159 | if torch.isnan(input_t).any(): 160 | if self.rank == 0: 161 | self.logger.warning(f'{prefix} NaN detected in {input_i}-th input.') 162 | if torch.isnan(approx).any(): 163 | if self.rank == 0: 164 | self.logger.warning(f'{prefix} {torch.isnan(approx).sum()} NaN detected in output.') 165 | approx = torch.nan_to_num(approx) 166 | return loss, approx 167 | 168 | def _gather_storage(self): 169 | for key, val in self.epoch_storage.items(): 170 | if len(val) > 0: 171 | self.epoch_storage[key] = torch.nan_to_num(comm.gather_tensor(val)) 172 | if self.debug: 173 | self.logger.debug(f'[rank {self.rank}] gather storage {key}: {self.epoch_storage[key].shape}') 174 | 175 | def _concat_storage(self): 176 | for key, val in self.epoch_storage.items(): 177 | if len(val) > 0: 178 | if isinstance(val[0], torch.Tensor): # val: [(batch, ...), (batch, ...), ...] 179 | self.epoch_storage[key] = torch.nan_to_num(torch.cat(val)) 180 | else: # val: [value, value, ...] 181 | self.epoch_storage[key] = torch.nan_to_num(torch.tensor(val)).to(self.device) 182 | if self.debug: 183 | self.logger.debug(f'[rank {self.rank}] concat storage {key}: {self.epoch_storage[key].shape}') 184 | 185 | def _train_one_epoch(self, loader): 186 | loader_time = .0 187 | train_time = .0 188 | start_time = time.time() 189 | curr_time = time.time() 190 | 191 | self.epoch_storage = defaultdict(list) 192 | for key in ['approx', 'target', 'loss', 'batch_metric']: 193 | self.epoch_storage[key] = [] 194 | 195 | self.model.train() 196 | if self.progress_bar and self.rank == 0: 197 | iterator = enumerate(tqdm(loader, desc='train')) 198 | else: 199 | iterator = enumerate(loader) 200 | batch_total = len(loader) 201 | ett_disp = False 202 | 203 | for batch_i, inputs in iterator: 204 | loader_time += time.time() - curr_time 205 | curr_time = time.time() 206 | elapsed_time = curr_time - start_time 207 | if self.rank == 0 and self.state['epoch'] == 0 and elapsed_time > self.display_ett_time and not ett_disp: # show ETA 208 | ett = elapsed_time * batch_total // (batch_i + 1) 209 | system_usage = get_system_usage() 210 | self.logger(f'Estimated epoch training time: {int(ett)} s') 211 | self.logger(f'Maximum RAM usage: {system_usage["ram_usage"]} MB') 212 | self.logger(f'Maximum GRAM usage: {system_usage["gram_usage"]}') 213 | ett_disp = True 214 | 215 | batches_done = batch_total * (self.global_epoch-1) + batch_i 216 | inputs = [t.to(self.device) for t in inputs] 217 | 218 | # forward and backward 219 | with self.autocast: 220 | loss, approx = self.forward_train(self, inputs) 221 | self.evaluate_batch(self, inputs, approx) 222 | loss = loss / self.grad_accumulations 223 | if ((batch_i + 1) % self.grad_accumulations == 0) or ((batch_i + 1) == len(iterator)): 224 | self.backprop(self, loss, inputs) 225 | if 
self.batch_scheduler: 226 | self.scheduler.step() 227 | 228 | # detect nan value 229 | if self.fix_nan: 230 | loss, approx = self._find_and_fix_nan( 231 | inputs, loss, approx, 232 | prefix=f'[{self.rank}] ({batch_i}/{len(loader)})') 233 | 234 | # logging 235 | if self.parallel == 'ddp' and self.ddp_average_loss: 236 | loss_batch = comm.gather_tensor(loss.detach().clone().view(1)).mean().item() 237 | else: # Use loss on device: 0 238 | loss_batch = loss.item() 239 | 240 | learning_rate = [param_group['lr'] 241 | for param_group in self.optimizer.param_groups] 242 | logs = { 243 | 'batch_train_loss': loss_batch, 244 | 'batch_train_lr': learning_rate[0] 245 | } 246 | if len(self.epoch_storage['batch_metric']) > 0: 247 | metric = self.epoch_storage['batch_metric'][-1] 248 | logs['batch_valid_metric'] = metric 249 | if self.rank == 0: 250 | self.logger.write_log(logs, batches_done, log_wandb=False) 251 | self.epoch_storage['loss'].append(loss_batch) 252 | 253 | train_time += time.time() - curr_time 254 | curr_time = time.time() 255 | 256 | self._concat_storage() 257 | loss_total = self.epoch_storage['loss'].mean().item() 258 | 259 | if self.parallel == 'ddp': 260 | self._gather_storage() 261 | metric_total, monitor_metrics_total = self.evaluate_epoch(self) 262 | else: 263 | metric_total, monitor_metrics_total = self.evaluate_epoch(self) 264 | 265 | if self.eval_metric is None: 266 | metric_total = loss_total 267 | 268 | if self.debug: 269 | self.logger.debug(f'[rank {self.rank}] loader: {loader_time:.1f} s | train: {train_time:.1f} s') 270 | 271 | return loss_total, metric_total, monitor_metrics_total 272 | 273 | def _valid_one_epoch(self, loader): 274 | self.epoch_storage = defaultdict(list) 275 | for key in ['approx', 'target', 'loss', 'batch_metric']: 276 | self.epoch_storage[key] = [] 277 | 278 | self.model.eval() 279 | if self.progress_bar and self.rank == 0: 280 | iterator = enumerate(tqdm(loader, desc='valid')) 281 | else: 282 | iterator = enumerate(loader) 283 | 284 | with torch.no_grad(): 285 | for batch_i, inputs in iterator: 286 | batches_done = len(loader) * (self.global_epoch-1) + batch_i 287 | inputs = [t.to(self.device) for t in inputs] 288 | loss, approx = self.forward_valid(self, inputs) 289 | if self.fix_nan: 290 | loss, approx = self._find_and_fix_nan( 291 | inputs, loss, approx, 292 | prefix=f'[{self.rank}] ({batch_i}/{len(loader)})') 293 | self.evaluate_batch(self, inputs, approx) 294 | if self.parallel == 'ddp' and self.ddp_average_loss: 295 | loss_batch = comm.gather_tensor( 296 | loss.detach().clone().view(1)).mean().item() 297 | else: # Use loss on device: 0 298 | loss_batch = loss.item() 299 | 300 | logs = { 301 | 'batch_valid_loss': loss_batch 302 | } 303 | if len(self.epoch_storage['batch_metric']) > 0: 304 | metric = self.epoch_storage['batch_metric'][-1] 305 | logs['batch_valid_metric'] = metric 306 | if self.rank == 0: 307 | self.logger.write_log(logs, batches_done, log_wandb=False) 308 | self.epoch_storage['loss'].append(loss_batch) 309 | 310 | self._concat_storage() 311 | loss_total = self.epoch_storage['loss'].mean().item() 312 | 313 | if self.parallel == 'ddp': 314 | self._gather_storage() 315 | metric_total, monitor_metrics_total = self.evaluate_epoch(self) 316 | else: 317 | metric_total, monitor_metrics_total = self.evaluate_epoch(self) 318 | 319 | if self.eval_metric is None: 320 | metric_total = loss_total 321 | 322 | return loss_total, metric_total, monitor_metrics_total 323 | 324 | def _train(self, loader, loader_valid, num_epochs): 325 | assert 
self._register_ready 326 | assert self._model_ready 327 | if self.rank == 0 and hasattr(self.logger, 'use_wandb') and self.logger.use_wandb: 328 | self.logger.init_wandb(serial=self.serial) 329 | if self.rank == 0 and hasattr(self.logger, 'use_tensorboard') and self.logger.use_tensorboard: 330 | self.logger.init_tensorboard(serial=self.serial) 331 | if self.fp16: 332 | self.scaler = torch.GradScaler(self.device.type) 333 | self.autocast = torch.autocast(self.device.type) 334 | else: 335 | self.scaler = DummyGradScaler() 336 | self.autocast = DummyAutoCast() 337 | 338 | for epoch in range(num_epochs): 339 | if self.parallel == 'ddp': 340 | loader.sampler.set_epoch(epoch) 341 | 342 | self.state.update({'epoch': epoch}) 343 | 344 | ''' before epoch callbacks ''' 345 | for func in self.before_epoch: 346 | if self.loader_to_callback: 347 | func(self, loader, loader_valid) 348 | else: 349 | func(self) 350 | 351 | ''' Training loop ''' 352 | loss_train, metric_train, monitor_metrics_train = \ 353 | self._train_one_epoch(loader) 354 | 355 | ''' Validation loop ''' 356 | if loader_valid is None: 357 | loss_valid, metric_valid, monitor_metrics_valid = \ 358 | None, None, None 359 | else: 360 | if self.parallel == 'ddp': 361 | loader_valid.sampler.set_epoch(epoch) 362 | loss_valid, metric_valid, monitor_metrics_valid = \ 363 | self._valid_one_epoch(loader_valid) 364 | 365 | self.state.update({ 366 | 'epoch': epoch, 367 | 'train_loss': loss_train, 368 | 'train_metric': metric_train, 369 | 'train_monitor': monitor_metrics_train, 370 | 'valid_loss': loss_valid, 371 | 'valid_metric': metric_valid, 372 | 'valid_monitor': monitor_metrics_valid, 373 | 'learning_rate': [group['lr'] for group in self.optimizer.param_groups][0] 374 | }) 375 | 376 | if not self.batch_scheduler: # Epoch scheduler 377 | if self.scheduler_target is not None: 378 | self.scheduler.step(self.state[self.scheduler_target]) 379 | else: 380 | self.scheduler.step() 381 | 382 | ''' After epoch callbacks ''' 383 | after_trains = self.after_epoch + [self.logger.after_epoch] 384 | if self.rank != 0: # export logs on rank 0 device only 385 | after_trains = after_trains[:-1] 386 | for func in after_trains: 387 | if self.loader_to_callback: 388 | func(self, loader, loader_valid) 389 | else: 390 | func(self) 391 | self._states.append(self.state.copy()) 392 | 393 | if self.checkpoint and self.rank == 0: 394 | ''' Save model ''' 395 | self.save_snapshot() 396 | self.checkpoint = False 397 | 398 | if self.stop_train: 399 | ''' Early stop ''' 400 | if self.rank == 0: 401 | self.logger('Training stopped by overfit detector.') 402 | break 403 | 404 | self.global_epoch += 1 405 | 406 | if self.parallel == 'ddp': 407 | dist.destroy_process_group() 408 | 409 | def _train_ddp(self, rank, dist_url, loader, loader_valid, num_epochs): 410 | seed_everything(self.random_state, self.deterministic) 411 | self.rank = rank 412 | dist.init_process_group( 413 | backend='nccl', init_method=dist_url, 414 | world_size=self.world_size, rank=rank) 415 | comm.sync() 416 | torch.cuda.set_device(self.rank) 417 | if self.rank == 0: 418 | self.ddp_tmp_path.unlink() 419 | self.logger('All processes initialized.') 420 | 421 | ''' Configure model and loader ''' 422 | self._configure_model() 423 | loader = self._configure_loader_ddp(loader) 424 | loader_valid = self._configure_loader_ddp(loader_valid, shuffle=False) 425 | 426 | ''' Train ''' 427 | self._train(loader, loader_valid, num_epochs) 428 | 429 | def predict(self, loader, parallel=None, fp16=False, progress_bar=False): 430 
| self.parallel = parallel 431 | if self.logger is None: 432 | self.logger = DummyLogger('') 433 | if not self._register_ready: # is hook and callbacks registered? 434 | raise AttributeError('Register hook and callbacks by .register() method.') 435 | if not self._model_ready: # is model configured? 436 | self.fp16 = fp16 437 | if parallel == 'ddp': 438 | raise NotImplementedError('DDP prediction is not implemented.') 439 | else: 440 | self._configure_model() 441 | 442 | if progress_bar: 443 | iterator = tqdm(loader, desc='inference') 444 | else: 445 | iterator = loader 446 | prediction = [] 447 | self.model.eval() 448 | with torch.no_grad(): 449 | for inputs in iterator: 450 | inputs = [t.to(self.device) for t in inputs] 451 | if self.fp16: 452 | with torch.autocast(self.device.type): 453 | approx = self.forward_test(self, inputs) 454 | else: 455 | approx = self.forward_test(self, inputs) 456 | prediction.append(approx.detach()) 457 | prediction = torch.cat(prediction).float().cpu().numpy() 458 | 459 | return prediction 460 | 461 | def save_snapshot(self, path=None): 462 | for func in self._save_snapshot: 463 | func(self, path) 464 | 465 | def load_snapshot(self, path=None, device=None): 466 | for func in self._load_snapshot: 467 | func(self, path, device) 468 | 469 | def register(self, hook=TrainHook(), callbacks=[SaveSnapshot()]): 470 | # This function must be called 471 | self._register_hook(hook) 472 | self._register_callbacks(callbacks) 473 | self._register_ready = True 474 | 475 | def train(self, 476 | # Essential 477 | criterion, optimizer, scheduler, loader, num_epochs, 478 | batch_scheduler=False, scheduler_target=None, 479 | hook=TrainHook(), callbacks=[SaveSnapshot()], 480 | # Evaluation 481 | loader_valid=None, eval_metric=None, monitor_metrics=[], 482 | # Snapshot 483 | export_dir=None, resume=False, 484 | # Training option 485 | fp16=False, parallel=None, grad_accumulations=1, 486 | deterministic=None, random_state=0, 487 | # Logging 488 | logger=None, progress_bar=False, 489 | **kw_args 490 | ): 491 | # Register params 492 | self.criterion = criterion 493 | self.optimizer = optimizer 494 | self.scheduler = scheduler 495 | self.batch_scheduler = batch_scheduler 496 | self.scheduler_target = scheduler_target 497 | self.grad_accumulations = grad_accumulations 498 | self.deterministic = deterministic 499 | self.random_state = random_state 500 | self.eval_metric = eval_metric 501 | self.monitor_metrics = monitor_metrics 502 | self.logger = logger 503 | self.fp16 = fp16 504 | self.parallel = parallel 505 | self.progress_bar = progress_bar 506 | self.register(hook=hook, callbacks=callbacks) 507 | 508 | # Important flags 509 | self.global_epoch = 1 510 | self.stop_train = False 511 | self.checkpoint = False 512 | self.outoffold = None 513 | self.prediction = None 514 | 515 | ''' Configure directory ''' 516 | if export_dir is None: 517 | export_dir = Path().cwd() 518 | elif isinstance(export_dir, str): 519 | export_dir = Path(export_dir).expanduser() 520 | assert len(export_dir.suffix) == 0 # export_dir must be directory 521 | export_dir.mkdir(parents=True, exist_ok=True) 522 | self.base_dir = export_dir 523 | self.snapshot_path = self.base_dir / f'{self.serial}.pt' 524 | 525 | ''' Configure loggers ''' 526 | if self.logger is None: 527 | self.logger = TorchLogger(self.base_dir / f'{self.serial}.log') 528 | elif isinstance(self.logger, (str, Path)): 529 | self.logger = TorchLogger(self.logger, file=True) 530 | elif isinstance(self.logger, TorchLogger): 531 | pass 532 | else: 533 | 
raise ValueError('Invalid type of logger.') 534 | if len(kw_args) > 0: 535 | self.logger.warning(f'{kw_args} will be ignored.') 536 | 537 | ''' Configure loss function and metrics ''' 538 | if criterion is None: 539 | self.logger.warning('criterion is not set. Make sure loss is calculated in the training hook.') 540 | if eval_metric is None: 541 | self.logger.warning('eval_metric is not set. criterion will be used.') 542 | if not isinstance(self.monitor_metrics, (list, tuple)): 543 | self.monitor_metrics = [self.monitor_metrics] 544 | 545 | ''' Resume training ''' 546 | if resume: 547 | self.load_snapshot(self.snapshot_path, device='cpu') 548 | self.global_epoch += 1 549 | self.logger(f'Continuing from epoch {self.global_epoch}.') 550 | 551 | ''' Train ''' 552 | self.max_epochs = self.global_epoch + num_epochs - 1 553 | self.dataframe = [] 554 | self.state = { 555 | 'train_loss': None, 556 | 'train_metric': None, 557 | 'train_monitor': None, 558 | 'valid_loss': None, 559 | 'valid_metric': None, 560 | 'valid_monitor': None, 561 | 'best_epoch': self.global_epoch, 562 | 'best_score': None, 563 | 'patience': 0, 564 | 'epoch': 0, 565 | 'learning_rate': [group['lr'] for group in self.optimizer.param_groups][0] 566 | } 567 | self._states = [] 568 | 569 | if self.parallel == 'ddp': 570 | dist_url = f'tcp://127.0.0.1:{comm.find_free_port()}' 571 | session_id = str(uuid.uuid4()) 572 | origin = Path.cwd() / __main__.__file__ 573 | self.logger(f'DDP URL :\t{dist_url}') 574 | self.logger(f'session id :\t{session_id}') 575 | self.logger(f'__main__ :\t{origin}') 576 | 577 | ddp_tmp = { 578 | 'trainer': self, 579 | 'dist_url': dist_url, 580 | 'loader': loader, 581 | 'loader_valid': loader_valid, 582 | 'num_epochs': num_epochs 583 | } 584 | ddp_tmp_path = Path(f'.ku_ddp_tmp_{session_id}') 585 | self.ddp_tmp_path = ddp_tmp_path 586 | with open(ddp_tmp_path, 'wb') as f: 587 | pickle.dump(ddp_tmp, f) 588 | ddp_worker_path = Path(inspect.getfile( 589 | self.__class__)).parent/'ddp_worker.py' 590 | env_copy = os.environ.copy() 591 | env_copy['OMP_NUM_THREADS'] = '1' 592 | 593 | command = [ 594 | 'torchrun', 595 | '--standalone', 596 | '--nnodes', '1', 597 | '--nproc_per_node', str(self.world_size), 598 | '--rdzv_endpoint', dist_url, 599 | ddp_worker_path, 600 | '--path', ddp_tmp_path, 601 | '--origin', str(origin) 602 | ] 603 | proc = subprocess.Popen( 604 | command, env=env_copy, cwd=origin.parent) 605 | proc.wait() 606 | if ddp_tmp_path.exists(): 607 | ddp_tmp_path.unlink() 608 | else: 609 | self._configure_model() 610 | self._train(loader, loader_valid, num_epochs) 611 | 612 | fit = train # for compatibility 613 | load_checkpoint = load_snapshot 614 | save_checkpoint = save_snapshot 615 | 616 | def export_dataframe(self): 617 | return pd.DataFrame(self._states) 618 | 619 | def __repr__(self): 620 | print_dict = { 621 | 'model': self.model.__class__.__name__, 622 | 'device': self.device, 623 | 'serial': self.serial 624 | } 625 | return f'TorchTrainer(\n{pformat(print_dict, compact=True, indent=2)})' 626 | --------------------------------------------------------------------------------