├── .gitignore ├── LICENSE ├── README.md ├── docker ├── Dockerfile ├── generate_container.sh └── generate_image.sh ├── requirements.txt ├── scripts └── run.sh └── src ├── __init__.py ├── base ├── __init__.py └── base_trainer.py ├── conf ├── __init__.py ├── cf_train.py └── defaults │ ├── cf_distributed.py │ └── cf_wandb.py ├── loader.py ├── main.py ├── sweep.yaml ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | wandb/ 3 | checkpoints/ 4 | data/ 5 | results.csv 6 | src/results.csv 7 | src/data/ 8 | src/data/cached 9 | src/checkpoints/ 10 | __pycache__ 11 | 12 | /data 13 | *.csv 14 | 15 | # Created by https://www.toptal.com/developers/gitignore/api/windows,vscode,jupyternotebooks,linux,macos 16 | # Edit at https://www.toptal.com/developers/gitignore?templates=windows,vscode,jupyternotebooks,linux,macos 17 | 18 | ### JupyterNotebooks ### 19 | # gitignore template for Jupyter Notebooks 20 | # website: http://jupyter.org/ 21 | 22 | .ipynb_checkpoints 23 | */.ipynb_checkpoints/* 24 | 25 | # IPython 26 | profile_default/ 27 | ipython_config.py 28 | 29 | # Remove previous ipynb_checkpoints 30 | # git rm -r .ipynb_checkpoints/ 31 | 32 | ### Linux ### 33 | *~ 34 | 35 | # temporary files which can be created if a process still has a handle open of a deleted file 36 | .fuse_hidden* 37 | 38 | # KDE directory preferences 39 | .directory 40 | 41 | # Linux trash folder which might appear on any partition or disk 42 | .Trash-* 43 | 44 | # .nfs files are created when an open file is removed but is still being accessed 45 | .nfs* 46 | 47 | ### macOS ### 48 | # General 49 | .DS_Store 50 | .AppleDouble 51 | .LSOverride 52 | 53 | # Icon must end with two \r 54 | Icon 55 | 56 | 57 | # Thumbnails 58 | ._* 59 | 60 | # Files that might appear in the root of a volume 61 | .DocumentRevisions-V100 62 | .fseventsd 63 | .Spotlight-V100 64 | .TemporaryItems 65 | .Trashes 66 | .VolumeIcon.icns 67 | .com.apple.timemachine.donotpresent 68 | 69 | # Directories potentially created on remote AFP share 70 | .AppleDB 71 | .AppleDesktop 72 | Network Trash Folder 73 | Temporary Items 74 | .apdisk 75 | 76 | ### vscode ### 77 | .vscode/* 78 | !.vscode/settings.json 79 | !.vscode/tasks.json 80 | !.vscode/launch.json 81 | !.vscode/extensions.json 82 | *.code-workspace 83 | 84 | ### pycharm ### 85 | .idea 86 | 87 | ### Windows ### 88 | # Windows thumbnail cache files 89 | Thumbs.db 90 | Thumbs.db:encryptable 91 | ehthumbs.db 92 | ehthumbs_vista.db 93 | 94 | # Dump file 95 | *.stackdump 96 | 97 | # Folder config file 98 | [Dd]esktop.ini 99 | 100 | # Recycle Bin used on file shares 101 | $RECYCLE.BIN/ 102 | 103 | # Windows Installer files 104 | *.cab 105 | *.msi 106 | *.msix 107 | *.msm 108 | *.msp 109 | 110 | # Windows shortcuts 111 | *.lnk 112 | 113 | # End of https://www.toptal.com/developers/gitignore/api/windows,vscode,jupyternotebooks,linux,macos 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Takyoung Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons 
to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch NLP Template with Hydra & WandB 🔥 2 | 3 | > Seems too complex to simply implement your research idea? Then how about [this template with pytorch-lightning](https://github.com/youngerous/pytorch-lightning-nlp-template)? 4 | 5 | PyTorch template for easy use! (actually, for my own😉) 6 | - This template especially focuses on **BERT-based NLP tasks**, but it can also be customized for any task. 7 | - This template supports **distributed-data-parallel(ddp)** and **automatic-mixed-precision(amp)** training. 8 | - This template includes a simple BERT classification example. 9 | - This template includes simple [WandB](https://wandb.ai/site) and [Hydra](https://hydra.cc/) integration. 10 | - This template follows [black](https://github.com/psf/black) code formatting. 11 | 12 | ## 1. Structure 13 | ```sh 14 | root/ 15 | ├─ docker/ 16 | │ ├─ Dockerfile 17 | │ ├─ generate_container.sh 18 | │ └─ generate_image.sh 19 | ├─ scripts/ 20 | │ └─ run.sh 21 | ├─ src/ 22 | │ ├─ base/ 23 | │ │ └─ base_trainer.py 24 | │ ├─ checkpoints/ # gitignored 25 | │ │ └─ WANDB_RUN_ID/ # automatically generated 26 | │ │ ├─ best/ 27 | │ │ └─ latest/ 28 | │ ├─ conf/ 29 | │ │ ├─ defaults/ 30 | │ │ │ ├─ cf_distributed.py 31 | │ │ │ └─ cf_wandb.py 32 | │ │ └─ cf_train.py 33 | │ ├─ loader.py 34 | │ ├─ main.py 35 | │ ├─ sweep.yaml 36 | │ ├─ trainer.py 37 | │ └─ utils.py 38 | ├─ .gitignore 39 | ├─ LICENSE 40 | ├─ README.md 41 | └─ requirements.txt 42 | ``` 43 | 44 | ## 2. Requirements 45 | - torch==1.7.1 46 | - transformers==4.9.1 47 | - datasets==1.11.0 48 | - wandb==0.12.6 49 | - hydra-core==1.1.1 50 | 51 | More dependencies are written in [requirements.txt](https://github.com/youngerous/pytorch-nlp-wandb-hydra-template/blob/main/requirements.txt). 52 | 53 | ## 3. Usage 54 | 55 | ### 3.1. Set docker environments 56 | ```bash 57 | # Example: generate image 58 | $ bash docker/generate_image.sh --image_name $IMAGE_NAME 59 | 60 | # Example: generate container 61 | $ bash docker/generate_container.sh --image_name $IMAGE_NAME --container_name $CONTAINER_NAME --port_jupyter 8888 --port_tensorboard 6666 62 | 63 | # Example: start container 64 | $ docker exec -it $CONTAINER_NAME bash 65 | ``` 66 | ### 3.2. Train model 67 | ```sh 68 | $ sh scripts/run.sh 69 | ``` 70 | 71 | ## 4. Sample Experiment Results 72 | 73 | | Task | Dataset | Model | Test Accuracy | 74 | | :----------------------: | :-----: | :---: | :-----------: | 75 | | Sentiment Classification | IMDB | BERT | 93% | 76 | 77 | ## 5. 
LICENSE 78 | [MIT License](https://github.com/youngerous/pytorch-nlp-wandb-hydra-template/blob/main/LICENSE) 79 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 2 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 3 | ENV LC_ALL=C.UTF-8 4 | 5 | 6 | ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH} 7 | RUN apt-get update -y 8 | 9 | RUN apt-get update \ 10 | && apt-get install -y python3-pip python3-dev \ 11 | && cd /usr/local/bin \ 12 | && ln -s /usr/bin/python3 python \ 13 | && pip3 install --upgrade pip 14 | 15 | ARG DEBIAN_FRONTEND=noninteractive 16 | ENV TZ=Asia/Seoul 17 | RUN apt-get install -y tzdata 18 | 19 | COPY requirements.txt /tmp 20 | WORKDIR /tmp 21 | RUN pip install -r requirements.txt 22 | 23 | ARG UNAME 24 | ARG UID 25 | ARG GID 26 | RUN groupadd -g $GID -o $UNAME 27 | RUN useradd -m -u $UID -g $GID -o -s /bin/bash $UNAME 28 | USER $UNAME -------------------------------------------------------------------------------- /docker/generate_container.sh: -------------------------------------------------------------------------------- 1 | for ((argpos=1; argpos<$#; argpos++)); do 2 | if [ "${!argpos}" == "--container_name" ]; then 3 | argpos_plus1=$((argpos+1)) 4 | container_name=${!argpos_plus1} 5 | fi 6 | if [ "${!argpos}" == "--image_name" ]; then 7 | argpos_plus1=$((argpos+1)) 8 | image_name=${!argpos_plus1} 9 | fi 10 | if [ "${!argpos}" == "--port_jupyter" ]; then 11 | argpos_plus1=$((argpos+1)) 12 | port_jupyter=${!argpos_plus1} 13 | fi 14 | if [ "${!argpos}" == "--port_tensorboard" ]; then 15 | argpos_plus1=$((argpos+1)) 16 | port_tensorboard=${!argpos_plus1} 17 | fi 18 | done 19 | 20 | echo "Container Name: " $container_name 21 | echo "Image Name: " $image_name 22 | echo "Jupyter Port #: " $port_jupyter 23 | echo "Tensorboard Port #: " $port_tensorboard 24 | 25 | # --gpus '"device=0,1"' 26 | docker run --gpus '"device=all"' -td --ipc=host --name $container_name\ 27 | -v ~/repo:/repo\ 28 | -v /etc/passwd:/etc/passwd\ 29 | -v /etc/localtime:/etc/localtime:ro\ 30 | -e TZ=Asia/Seoul\ 31 | -p $port_jupyter:$port_jupyter -p $port_tensorboard:$port_tensorboard $image_name 32 | -------------------------------------------------------------------------------- /docker/generate_image.sh: -------------------------------------------------------------------------------- 1 | for ((argpos=1; argpos<$#; argpos++)); do 2 | if [ "${!argpos}" == "--image_name" ]; then 3 | argpos_plus1=$((argpos+1)) 4 | image_name=${!argpos_plus1} 5 | fi 6 | done 7 | 8 | echo "Image_name: " $image_name 9 | 10 | docker build -t $image_name -f docker/Dockerfile --build-arg UNAME=$(whoami) --build-arg UID=$(id -u) --build-arg GID=$(id -g) . 
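# NOTE: run this script from the repository root (as the README usage example does), since the
# trailing "." build context must contain requirements.txt for the Dockerfile's COPY step to work.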
11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | transformers==4.9.1 3 | datasets==1.11.0 4 | tensorboardX==2.1 5 | wandb==0.12.6 6 | hydra-core==1.1.1 7 | scikit-learn -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | WANDB_PROJECT='template' 4 | WANDB_ENTITY='youngerous' 5 | 6 | CKPT_ROOT='/repo/pytorch-nlp-wandb-hydra-template/src/checkpoints/' 7 | EPOCH=1 8 | BATCH_SIZE=16 9 | GRADIENT_ACCUMULATION_STEP=1 10 | EARLY_STOP_TOLERANCE=10 11 | LR='5e-5' 12 | GPU='ddp' 13 | AMP='True' 14 | 15 | python src/main.py\ 16 | ckpt_root=$CKPT_ROOT\ 17 | epoch=$EPOCH\ 18 | batch_size=$BATCH_SIZE\ 19 | gradient_accumulation_step=$GRADIENT_ACCUMULATION_STEP\ 20 | early_stop_tolerance=$EARLY_STOP_TOLERANCE\ 21 | lr=$LR\ 22 | amp=$AMP\ 23 | +wandb.project=$WANDB_PROJECT\ 24 | +wandb.entity=$WANDB_ENTITY\ 25 | +gpu=$GPU\ 26 | 27 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/youngerous/pytorch-nlp-wandb-hydra-template/052201ac40237b754dc00bb575297d37a3b1471a/src/__init__.py -------------------------------------------------------------------------------- /src/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/youngerous/pytorch-nlp-wandb-hydra-template/052201ac40237b754dc00bb575297d37a3b1471a/src/base/__init__.py -------------------------------------------------------------------------------- /src/base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import math 4 | import os 5 | from typing import * 6 | 7 | import torch 8 | import torch.nn as nn 9 | import wandb 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from tqdm import tqdm 12 | from transformers import AdamW, get_linear_schedule_with_warmup 13 | from utils import EvalManager 14 | 15 | if torch.distributed.is_available(): 16 | from torch.distributed import ReduceOp 17 | 18 | logger = logging.getLogger() 19 | 20 | 21 | class BaseTrainer(object): 22 | def __init__(self, hparams, loaders, model): 23 | self.hparams = hparams 24 | self.distributed = self.hparams.gpu.distributed 25 | self.rank: int = self.hparams.gpu.rank 26 | self.main_process: bool = self.rank in [-1, 0] 27 | self.nprocs: int = torch.cuda.device_count() 28 | self.scaler = torch.cuda.amp.GradScaler() if self.hparams.amp else None 29 | if self.distributed: 30 | assert torch.cuda.is_available() 31 | self.device = f"cuda:{self.rank}" 32 | else: 33 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 34 | 35 | self.model = model 36 | self.model = model.to(self.device, non_blocking=True) 37 | if self.distributed: 38 | self.model = DDP(self.model, device_ids=[self.rank]) 39 | elif self.nprocs > 1: 40 | self.model = nn.DataParallel(self.model) 41 | 42 | self.trn_loader, self.dev_loader = loaders 43 | self.max_grad_norm = self.hparams.max_grad_norm 44 | self.gradient_accumulation_step = self.hparams.gradient_accumulation_step 45 | self.step_total = ( 46 | len(self.trn_loader) // self.gradient_accumulation_step * 
self.hparams.epoch 47 | ) 48 | 49 | # model saving options 50 | self.global_step = 0 51 | self.eval_step = ( 52 | int(self.step_total * self.hparams.eval_ratio) 53 | if self.hparams.eval_ratio > 0 54 | else self.step_total // self.hparams.epoch 55 | ) 56 | 57 | # early stopping options 58 | activate_earlystop = True if hparams.early_stop_tolerance > 0 else False 59 | self.eval_mgr = EvalManager( 60 | patience=hparams.early_stop_tolerance, activate_early_stop=activate_earlystop 61 | ) 62 | 63 | if self.main_process: 64 | self.hparams.ckpt_root = os.path.join(self.hparams.ckpt_root, wandb.run.id) 65 | self.log_step = hparams.log_step 66 | wandb.config.update(self.hparams) 67 | wandb.watch(self.model) 68 | wandb.run.summary["step_total"] = self.step_total 69 | 70 | def to_device(self, *tensors): 71 | bundle = [] 72 | for tensor in tensors: 73 | bundle.append(tensor.to(self.device, non_blocking=True)) 74 | return bundle if len(bundle) > 1 else bundle[0] 75 | 76 | def reduce_boolean_decision( 77 | self, 78 | decision: bool, 79 | reduce_op: Optional[Union[ReduceOp, str]] = ReduceOp.SUM, 80 | stop_option: str = "all", 81 | ) -> bool: 82 | """This function is partially modified from pytorch-lightning 83 | 84 | Args: 85 | decision (bool): Boolean value whether to early stop the process 86 | reduce_op (Optional[Union[ReduceOp, str]]): DDP reduce operator 87 | stop_option (str): Early stopping option according to each process decision 88 | 89 | Return: Reduced boolean value 90 | 91 | Ref: 92 | https://github.com/PyTorchLightning/pytorch-lightning/blob/939d56c6d69202318baf2fbf65ceda00c63363fd/pytorch_lightning/strategies/parallel.py#L113 93 | """ 94 | assert stop_option in ["all", "half", "strict"] 95 | divide_by_world_size = False 96 | group = torch.distributed.group.WORLD 97 | decision = torch.tensor(int(decision), device=self.device) 98 | 99 | if isinstance(reduce_op, str): 100 | if reduce_op.lower() in ("avg", "mean"): 101 | op = ReduceOp.SUM 102 | divide_by_world_size = True 103 | else: 104 | op = getattr(ReduceOp, reduce_op.upper()) 105 | else: 106 | op = reduce_op 107 | 108 | torch.distributed.barrier(group=group) 109 | torch.distributed.all_reduce(decision, op=op, group=group, async_op=False) 110 | if divide_by_world_size: 111 | decision = decision / torch.distributed.get_world_size(group) 112 | 113 | if stop_option == "all": # stop if every process calls stopping 114 | decision = bool(decision == self.hparams.gpu.world_size) 115 | elif stop_option == "half": # stop if more than half processes call stopping 116 | decision = bool(decision > int(self.hparams.gpu.world_size // 2)) 117 | elif stop_option == "strict": # stop if just one process calls stopping 118 | decision = bool(decision > 0) 119 | 120 | return decision 121 | 122 | def configure_optimizers(self): 123 | # optimizer 124 | decay_parameters = self.get_parameter_names(self.model, [torch.nn.LayerNorm]) 125 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 126 | optimizer_grouped_parameters = [ 127 | { 128 | "params": [p for n, p in self.model.named_parameters() if n in decay_parameters], 129 | "weight_decay": self.hparams.weight_decay, 130 | }, 131 | { 132 | "params": [ 133 | p for n, p in self.model.named_parameters() if n not in decay_parameters 134 | ], 135 | "weight_decay": 0.0, 136 | }, 137 | ] 138 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.lr) 139 | 140 | # lr scheduler with warmup 141 | self.warmup_steps = math.ceil(self.step_total * self.hparams.warmup_ratio) 142 | scheduler = 
get_linear_schedule_with_warmup( 143 | optimizer, 144 | num_warmup_steps=self.warmup_steps, 145 | num_training_steps=self.step_total, 146 | ) 147 | 148 | return optimizer, scheduler 149 | 150 | def get_parameter_names(self, model, forbidden_layer_types): 151 | """ 152 | Returns the names of the model parameters that are not inside a forbidden layer. 153 | """ 154 | result = [] 155 | for name, child in model.named_children(): 156 | result += [ 157 | f"{name}.{n}" 158 | for n in self.get_parameter_names(child, forbidden_layer_types) 159 | if not isinstance(child, tuple(forbidden_layer_types)) 160 | ] 161 | # Add model specific parameters (defined with nn.Parameter) since they are not in any child. 162 | result += list(model._parameters.keys()) 163 | return result 164 | 165 | def save_checkpoint( 166 | self, 167 | epoch: int, 168 | global_dev_loss: float, 169 | dev_loss: float, 170 | dev_acc: float, 171 | model: nn.Module, 172 | best=True, 173 | ) -> None: 174 | latest_pth = os.path.join(self.hparams.ckpt_root, "latest") 175 | os.makedirs(latest_pth, exist_ok=True) 176 | 177 | if best: 178 | logger.info(f"Saving best model ...") 179 | best_pth = os.path.join(self.hparams.ckpt_root, "best") 180 | os.makedirs(best_pth, exist_ok=True) 181 | 182 | # save best model 183 | for filename in glob.glob(os.path.join(self.hparams.ckpt_root, "best", "ckpt_*.pt")): 184 | os.remove(filename) # remove old checkpoint 185 | torch.save( 186 | model.state_dict(), 187 | os.path.join(best_pth, f"ckpt_step_{self.global_step}_loss_{dev_loss:.5f}.pt"), 188 | ) 189 | 190 | wandb.run.summary["best_step"] = self.global_step 191 | wandb.run.summary["best_epoch"] = epoch 192 | wandb.run.summary["best_dev_loss"] = dev_loss 193 | wandb.run.summary["best_dev_acc"] = dev_acc 194 | 195 | # save latest model 196 | logger.info(f"Saving latest model ...") 197 | for filename in glob.glob(os.path.join(self.hparams.ckpt_root, "latest", "ckpt_*.pt")): 198 | os.remove(filename) # remove old checkpoint 199 | torch.save( 200 | model.state_dict(), 201 | os.path.join(latest_pth, f"ckpt_step_{self.global_step}_loss_{dev_loss:.5f}.pt"), 202 | ) 203 | 204 | def fit(self): 205 | raise NotImplementedError 206 | 207 | def _train_epoch(self, *args, **kwargs): 208 | raise NotImplementedError 209 | 210 | @torch.no_grad() 211 | def validate(self, *args, **kwargs): 212 | raise NotImplementedError 213 | 214 | @torch.no_grad() 215 | def test(self, *args, **kwargs): 216 | raise NotImplementedError 217 | -------------------------------------------------------------------------------- /src/conf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/youngerous/pytorch-nlp-wandb-hydra-template/052201ac40237b754dc00bb575297d37a3b1471a/src/conf/__init__.py -------------------------------------------------------------------------------- /src/conf/cf_train.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import conf.defaults.cf_distributed as cf_distributed 3 | 4 | from dataclasses import dataclass 5 | from hydra.core.config_store import ConfigStore 6 | 7 | 8 | @dataclass 9 | class TrainConf: 10 | gpu: cf_distributed.Distributed 11 | wandb: Any 12 | 13 | test: bool = True 14 | amp: bool = True # torch >= 1.6.x 15 | ckpt_root: str = "/repo/pytorch-nlp-wandb-hydra-template/src/checkpoints/" 16 | 17 | seed: int = 42 18 | workers: int = 1 19 | log_step: int = 200 20 | eval_ratio: float = 0.0 # evaluation will be done at the 
end of epoch if set to 0.0 21 | early_stop_tolerance: int = -1 22 | 23 | epoch: int = 10 24 | batch_size: int = 16 # it will be divided by num_gpu in DDP 25 | lr: float = 5e-5 26 | weight_decay: float = 0.1 27 | warmup_ratio: float = 0.1 28 | max_grad_norm: float = 1.0 29 | gradient_accumulation_step: int = 1 30 | 31 | 32 | cf_distributed.register_configs() 33 | 34 | cs = ConfigStore.instance() 35 | cs.store(name="train", node=TrainConf) 36 | -------------------------------------------------------------------------------- /src/conf/defaults/cf_distributed.py: -------------------------------------------------------------------------------- 1 | from dataclasses import MISSING, dataclass 2 | 3 | from hydra.core.config_store import ConfigStore 4 | 5 | 6 | @dataclass 7 | class Distributed: 8 | distributed: bool = MISSING 9 | rank: int = MISSING 10 | 11 | 12 | @dataclass 13 | class DDP(Distributed): 14 | distributed: bool = True 15 | dist_backend: str = "nccl" 16 | dist_url: str = "tcp://127.0.0.1:3456" 17 | world_size: int = 1 18 | rank: int = 0 19 | 20 | 21 | @dataclass 22 | class DP(Distributed): 23 | distributed: bool = False 24 | rank: int = -1 25 | 26 | 27 | def register_configs() -> None: 28 | cs = ConfigStore.instance() 29 | cs.store(group="gpu", name="ddp", node=DDP) 30 | cs.store(group="gpu", name="dp", node=DP) 31 | -------------------------------------------------------------------------------- /src/conf/defaults/cf_wandb.py: -------------------------------------------------------------------------------- 1 | from dataclasses import MISSING, dataclass 2 | 3 | from hydra.core.config_store import ConfigStore 4 | 5 | 6 | @dataclass 7 | class DefaultWandB: 8 | project: str = MISSING 9 | entity: str = MISSING 10 | 11 | 12 | def register_configs() -> None: 13 | cs = ConfigStore.instance() 14 | cs.store(group="wandb", name="default", node=DefaultWandB) 15 | -------------------------------------------------------------------------------- /src/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.model_selection import train_test_split 3 | from torch.utils.data import DataLoader, Dataset 4 | from torch.utils.data.distributed import DistributedSampler 5 | 6 | from utils import SequentialDistributedSampler 7 | 8 | 9 | class IMDB(Dataset): 10 | def __init__(self, tok, text, label): 11 | self.tok = tok 12 | self.text = text 13 | self.label = label 14 | 15 | assert len(self.text) == len(self.label) 16 | print(f"Load {len(self.label)} data.") 17 | 18 | def __getitem__(self, idx): 19 | src = self.tok(self.text[idx], truncation=True, padding="max_length", return_tensors="pt") 20 | return { 21 | "input_ids": src["input_ids"], 22 | "token_type_ids": src["token_type_ids"], 23 | "attention_mask": src["attention_mask"], 24 | "labels": torch.tensor(self.label[idx]), 25 | } 26 | 27 | def __len__(self): 28 | return len(self.label) 29 | 30 | 31 | def get_trn_dev_loader(dset, tok, batch_size, workers, distributed=False) -> DataLoader: 32 | """ 33 | Return: 34 | Tuple[DataLoader] 35 | """ 36 | # trn 20000, dev 5000 37 | trn_text, dev_text, trn_label, dev_label = train_test_split( 38 | dset["text"], dset["label"], test_size=0.2 39 | ) 40 | trn_dset = IMDB(tok, trn_text, trn_label) 41 | dev_dset = IMDB(tok, dev_text, dev_label) 42 | 43 | shuffle_flag = True 44 | trn_sampler, dev_sampler = None, None 45 | if distributed: 46 | trn_sampler = DistributedSampler(trn_dset) 47 | dev_sampler = SequentialDistributedSampler(dev_dset) 48 | shuffle_flag 
= False 49 | 50 | trn_loader = DataLoader( 51 | dataset=trn_dset, 52 | batch_size=batch_size, 53 | sampler=trn_sampler, 54 | shuffle=shuffle_flag, 55 | num_workers=workers, 56 | pin_memory=True, 57 | drop_last=True, 58 | ) 59 | dev_loader = DataLoader( 60 | dataset=dev_dset, 61 | batch_size=batch_size, 62 | sampler=dev_sampler, 63 | shuffle=False, 64 | num_workers=workers, 65 | pin_memory=True, 66 | drop_last=False, 67 | ) 68 | return (trn_loader, dev_loader) 69 | 70 | 71 | def get_tst_loader(dset, tok, batch_size, workers, distributed=False) -> DataLoader: 72 | """ 73 | Return: 74 | DataLoader 75 | """ 76 | tst_dset = IMDB(tok, dset["text"], dset["label"]) 77 | 78 | tst_sampler = None 79 | if distributed: 80 | tst_sampler = SequentialDistributedSampler(tst_dset) 81 | 82 | return DataLoader( 83 | dataset=tst_dset, 84 | batch_size=batch_size, 85 | sampler=tst_sampler, 86 | shuffle=False, 87 | num_workers=workers, 88 | pin_memory=True, 89 | drop_last=False, 90 | ) 91 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | from pprint import pformat 5 | 6 | import hydra 7 | import torch 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | import wandb 11 | from datasets import load_dataset 12 | from transformers import AutoTokenizer, BertForSequenceClassification 13 | 14 | from conf.cf_train import * 15 | from loader import get_trn_dev_loader, get_tst_loader 16 | from trainer import Trainer 17 | from utils import fix_seed, setup_logger 18 | 19 | logger = logging.getLogger() 20 | 21 | 22 | def worker(rank, hparams, ngpus_per_node: int): 23 | fix_seed(hparams.seed) 24 | batch_size = hparams.batch_size 25 | if hparams.gpu.distributed: 26 | hparams.gpu.rank = hparams.gpu.rank * ngpus_per_node + rank 27 | print(f"Use GPU {hparams.gpu.rank} for training") 28 | dist.init_process_group( 29 | backend=hparams.gpu.dist_backend, 30 | init_method=hparams.gpu.dist_url, 31 | world_size=hparams.gpu.world_size, 32 | rank=hparams.gpu.rank, 33 | ) 34 | batch_size = hparams.batch_size // ngpus_per_node 35 | 36 | # get tokenizer 37 | if hparams.gpu.distributed: 38 | if rank != 0: 39 | dist.barrier() 40 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) 41 | if rank == 0: 42 | dist.barrier() 43 | else: 44 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) 45 | 46 | # get dataloaders 47 | loaders = get_trn_dev_loader( 48 | dset=load_dataset("imdb", split="train"), 49 | tok=tokenizer, 50 | batch_size=batch_size, 51 | workers=hparams.workers, 52 | distributed=hparams.gpu.distributed, 53 | ) 54 | 55 | # get model 56 | if hparams.gpu.distributed: 57 | torch.cuda.set_device(rank) 58 | torch.cuda.empty_cache() 59 | if rank != 0: 60 | dist.barrier() 61 | model = BertForSequenceClassification.from_pretrained("bert-base-uncased") 62 | if rank == 0: 63 | dist.barrier() 64 | else: 65 | torch.cuda.empty_cache() 66 | model = BertForSequenceClassification.from_pretrained("bert-base-uncased") 67 | 68 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 69 | if rank in [-1, 0]: 70 | wandb.init( 71 | project=hparams.wandb.project, 72 | entity=hparams.wandb.entity, 73 | config={"ngpus": ngpus_per_node, "num_params": num_params}, 74 | ) 75 | wandb.run.name = 
f"ep_{hparams.epoch}_bsz_{int(hparams.batch_size)*int(hparams.gradient_accumulation_step)}_lr_{hparams.lr}_wrmup_{hparams.warmup_ratio}_accum_{hparams.gradient_accumulation_step}_amp_{hparams.amp}_ddp_{hparams.gpu.distributed}" 76 | print(f"# Model Parameters: {num_params}") 77 | print(f"# WandB Run Name: {wandb.run.name}") 78 | print(f"# WandB Save Directory: {wandb.run.dir}") 79 | print(f"# Checkpoint Save Directory: {hparams.ckpt_root}") 80 | 81 | # initialize global logger 82 | run_root = os.path.join(hparams.ckpt_root, wandb.run.id) 83 | os.makedirs(run_root, exist_ok=True) 84 | setup_logger(logger, os.path.join(run_root, "experiment.log")) 85 | logger.info("\n%s", pformat(dict(hparams))) 86 | 87 | # training phase 88 | trainer = Trainer(hparams, tokenizer, loaders, model) 89 | trainer.fit() 90 | 91 | # testing phase 92 | if rank in [-1, 0]: 93 | if hparams.test: 94 | state_dict = torch.load( 95 | glob.glob(os.path.join(hparams.ckpt_root, "best", f"ckpt_*.pt"))[0] 96 | ) 97 | test_loader = get_tst_loader( 98 | dset=load_dataset("imdb", split="test"), 99 | tok=tokenizer, 100 | batch_size=batch_size, 101 | workers=hparams.workers, 102 | distributed=False, 103 | ) 104 | trainer.test(test_loader, state_dict) 105 | 106 | report = { 107 | "best_epoch": wandb.run.summary["best_epoch"], 108 | "best_step": wandb.run.summary["best_step"], 109 | "best_dev_acc": wandb.run.summary["best_dev_acc"], 110 | "best_dev_loss": wandb.run.summary["best_dev_loss"], 111 | "tst_acc": wandb.run.summary["tst_acc"], 112 | "tst_acc": wandb.run.summary["tst_loss"], 113 | "early_stopped": wandb.run.summary["early_stopped"], 114 | } 115 | logger.info("\n%s", pformat(report)) 116 | wandb.finish() 117 | 118 | 119 | @hydra.main(config_path="conf", config_name="train") 120 | def main(cfg): 121 | ngpus_per_node = torch.cuda.device_count() 122 | 123 | if cfg.gpu.distributed: 124 | cfg.gpu.world_size = ngpus_per_node * cfg.gpu.world_size 125 | mp.spawn(worker, nprocs=ngpus_per_node, args=(cfg, ngpus_per_node)) 126 | else: 127 | worker(cfg.gpu.rank, cfg, ngpus_per_node) 128 | return 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /src/sweep.yaml: -------------------------------------------------------------------------------- 1 | # Hyperparameter tuning example 2 | # Config docs: https://docs.wandb.ai/guides/sweeps/configuration 3 | # Hydra + Sweep: https://github.com/wandb/client/issues/1427 4 | command: 5 | - ${env} 6 | - ${interpreter} 7 | - ${program} 8 | - ${args_no_hyphens} 9 | name: project-hparam-search 10 | description: detailed explanation 11 | program: src/main.py 12 | method: grid 13 | metric: 14 | name: dev.loss 15 | goal: minimize 16 | parameters: 17 | amp: 18 | values: [True, False] 19 | lr: 20 | values: [5e-5, 1e-5, 5e-4] 21 | batch_size: 22 | values: [8, 16] 23 | test: 24 | value: True 25 | epoch: 26 | value: 5 27 | +wandb.project: 28 | value: template 29 | +wandb.entity: 30 | value: youngerous 31 | +gpu: 32 | values: [dp, ddp] -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from typing import * 4 | 5 | warnings.filterwarnings("ignore") 6 | 7 | from typing import Tuple 8 | 9 | import torch 10 | import torch.nn.utils as torch_utils 11 | import wandb 12 | from datasets import load_metric 13 | from torch import Tensor as T 14 | from tqdm import 
tqdm 15 | 16 | from base.base_trainer import BaseTrainer 17 | from utils import AverageMeter 18 | 19 | logger = logging.getLogger() 20 | 21 | 22 | class Trainer(BaseTrainer): 23 | """ 24 | This trainer inherits BaseTrainer. See base_trainer.py 25 | """ 26 | 27 | def __init__(self, hparams, tokenizer, loaders, model): 28 | super(Trainer, self).__init__(hparams, loaders, model) 29 | self.tokenizer = tokenizer 30 | self.accuracy = load_metric("accuracy") 31 | 32 | # dataloader and distributed sampler 33 | if self.distributed: 34 | self.train_sampler = self.trn_loader.sampler 35 | 36 | # optimizer, scheduler 37 | self.optimizer, self.scheduler = self.configure_optimizers() 38 | if self.main_process: 39 | wandb.run.summary["step_warmup"] = self.warmup_steps 40 | 41 | def fit(self) -> dict: 42 | # this zero gradient update is needed to avoid a warning message in warmup setting 43 | self.optimizer.zero_grad() 44 | self.optimizer.step() 45 | for epoch in tqdm(range(self.hparams.epoch), desc="epoch", disable=not self.main_process): 46 | if self.distributed: 47 | self.train_sampler.set_epoch(epoch) 48 | self._train_epoch(epoch) 49 | if self.eval_mgr.early_stop: 50 | break 51 | 52 | if self.main_process: 53 | wandb.run.summary["early_stopped"] = True if self.eval_mgr.early_stop else False 54 | 55 | def _train_epoch(self, epoch: int) -> None: 56 | if self.main_process: 57 | train_loss = AverageMeter() 58 | train_acc = AverageMeter() 59 | 60 | for step, batch in tqdm( 61 | enumerate(self.trn_loader), 62 | desc="trn_steps", 63 | total=len(self.trn_loader), 64 | disable=not self.main_process, 65 | ): 66 | self.model.train() 67 | 68 | batch_input = self._set_batch_input(batch) 69 | 70 | loss, logit = self._compute_loss(batch_input) 71 | pred = torch.argmax(logit, dim=1) 72 | loss = self._aggregate_loss(loss) 73 | self._update_loss(loss, step) 74 | acc = self.accuracy.compute( 75 | references=batch_input["labels"].data, predictions=pred.data 76 | ) 77 | 78 | if (step + 1) % self.gradient_accumulation_step != 0: 79 | continue 80 | 81 | # train logging 82 | if self.main_process: 83 | self._logging_train(epoch, train_loss, loss, train_acc, acc) 84 | 85 | # validate and logging 86 | if self.global_step != 0 and self.global_step % self.eval_step == 0: 87 | dev_loss, dev_acc = self.validate(epoch) 88 | is_best = self.eval_mgr(dev_loss, self.global_step, self.main_process) 89 | global_dev_loss = self.eval_mgr.global_dev_loss 90 | if self.main_process: 91 | wandb.log({"dev": {"loss": dev_loss}}, step=self.global_step) 92 | self.save_checkpoint( 93 | epoch, global_dev_loss, dev_loss, dev_acc, self.model, best=is_best 94 | ) 95 | 96 | if self.eval_mgr.activate_early_stop: 97 | if self.distributed: # sync early stop with all processes in ddp 98 | self.eval_mgr.early_stop = self.reduce_boolean_decision( 99 | self.eval_mgr.early_stop, stop_option="all" 100 | ) 101 | if self.eval_mgr.early_stop: 102 | if self.main_process: 103 | logger.info("### Every process called early stopping ###") 104 | break 105 | 106 | @torch.no_grad() 107 | def validate(self, epoch: int) -> float: 108 | dev_loss = AverageMeter() 109 | dev_acc = AverageMeter() 110 | 111 | self.model.eval() 112 | for step, batch in tqdm( 113 | enumerate(self.dev_loader), 114 | desc="dev_steps", 115 | total=len(self.dev_loader), 116 | disable=not self.main_process, 117 | ): 118 | # load to machine 119 | input_ids = batch["input_ids"].squeeze(1) 120 | token_type_ids = batch["token_type_ids"].squeeze(1) 121 | attention_mask = batch["attention_mask"].squeeze(1) 122 | 
labels = batch["labels"] 123 | 124 | input_ids = input_ids.to(self.device, non_blocking=True) 125 | token_type_ids = token_type_ids.to(self.device, non_blocking=True) 126 | attention_mask = attention_mask.to(self.device, non_blocking=True) 127 | labels = labels.to(self.device, non_blocking=True) 128 | 129 | # compute loss 130 | output = self.model( 131 | input_ids=input_ids, 132 | token_type_ids=token_type_ids, 133 | attention_mask=attention_mask, 134 | labels=labels, 135 | ) 136 | loss = output.loss.mean() 137 | dev_loss.update(loss.item()) 138 | 139 | pred = torch.argmax(output.logits, dim=1) 140 | acc = self.accuracy.compute(references=labels.data, predictions=pred.data) 141 | dev_acc.update(acc["accuracy"]) 142 | 143 | return dev_loss.avg, dev_acc.avg 144 | 145 | @torch.no_grad() 146 | def test(self, test_loader, state_dict) -> dict: 147 | test_loss = AverageMeter() 148 | test_acc = AverageMeter() 149 | 150 | self.model.load_state_dict(state_dict) 151 | self.model.eval() 152 | for step, batch in tqdm(enumerate(test_loader), desc="tst_steps", total=len(test_loader)): 153 | # load to machine 154 | input_ids = batch["input_ids"].squeeze(1) 155 | token_type_ids = batch["token_type_ids"].squeeze(1) 156 | attention_mask = batch["attention_mask"].squeeze(1) 157 | labels = batch["labels"] 158 | 159 | input_ids = input_ids.to(self.device, non_blocking=True) 160 | token_type_ids = token_type_ids.to(self.device, non_blocking=True) 161 | attention_mask = attention_mask.to(self.device, non_blocking=True) 162 | labels = labels.to(self.device, non_blocking=True) 163 | 164 | # compute loss 165 | output = self.model( 166 | input_ids=input_ids, 167 | token_type_ids=token_type_ids, 168 | attention_mask=attention_mask, 169 | labels=labels, 170 | ) 171 | loss = output.loss.mean() 172 | test_loss.update(loss.item()) 173 | 174 | pred = torch.argmax(output.logits, dim=1) 175 | acc = self.accuracy.compute(references=labels.data, predictions=pred.data) 176 | test_acc.update(acc["accuracy"]) 177 | 178 | wandb.log({"tst": {"loss": test_loss.avg, "acc": test_acc.avg}}) 179 | wandb.run.summary["tst_loss"] = test_loss.avg 180 | wandb.run.summary["tst_acc"] = test_acc.avg 181 | logger.info(f"[TST] tst loss: {test_loss.avg:.5f} | tst acc: {test_acc.avg:.5f}") 182 | 183 | def _set_batch_input(self, batch: T) -> dict: 184 | input_ids = batch["input_ids"].squeeze(1) 185 | token_type_ids = batch["token_type_ids"].squeeze(1) 186 | attention_mask = batch["attention_mask"].squeeze(1) 187 | labels = batch["labels"] 188 | input_ids, token_type_ids, attention_mask, labels = self.to_device( 189 | input_ids, token_type_ids, attention_mask, labels 190 | ) 191 | 192 | return { 193 | "input_ids": input_ids, 194 | "token_type_ids": token_type_ids, 195 | "attention_mask": attention_mask, 196 | "labels": labels, 197 | } 198 | 199 | def _compute_loss(self, batch_input: dict) -> Tuple[T, T]: 200 | if self.hparams.amp: 201 | with torch.cuda.amp.autocast(): 202 | output = self.model( 203 | input_ids=batch_input["input_ids"], 204 | token_type_ids=batch_input["token_type_ids"], 205 | attention_mask=batch_input["attention_mask"], 206 | labels=batch_input["labels"], 207 | ) 208 | else: 209 | output = self.model( 210 | input_ids=batch_input["input_ids"], 211 | token_type_ids=batch_input["token_type_ids"], 212 | attention_mask=batch_input["attention_mask"], 213 | labels=batch_input["labels"], 214 | ) 215 | 216 | return output.loss, output.logits 217 | 218 | def _aggregate_loss(self, loss: T) -> T: 219 | loss = loss / 
self.gradient_accumulation_step 220 | if not self.distributed: 221 | loss = loss.mean() 222 | return loss 223 | 224 | def _update_loss(self, loss: T, step: int) -> None: 225 | if self.hparams.amp: 226 | self.scaler.scale(loss).backward() 227 | if (step + 1) % self.gradient_accumulation_step == 0: 228 | self.scaler.unscale_(self.optimizer) 229 | torch_utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 230 | self.scaler.step(self.optimizer) 231 | self.scheduler.step() 232 | self.scaler.update() 233 | self.optimizer.zero_grad() # when accumulating, only after step() 234 | self.global_step += 1 235 | else: 236 | loss.backward() 237 | if (step + 1) % self.gradient_accumulation_step == 0: 238 | torch_utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 239 | self.optimizer.step() 240 | self.scheduler.step() 241 | self.optimizer.zero_grad() 242 | self.global_step += 1 243 | 244 | def _logging_train( 245 | self, 246 | epoch: int, 247 | train_loss: AverageMeter, 248 | step_loss: T, 249 | train_acc: AverageMeter, 250 | step_acc: float, 251 | ) -> None: 252 | train_loss.update(step_loss.item()) 253 | train_acc.update(step_acc["accuracy"]) 254 | if self.global_step != 0 and self.global_step % self.log_step == 0: 255 | wandb.log( 256 | { 257 | "train": {"loss": train_loss.avg, "acc": train_acc.avg}, 258 | "lr": self.optimizer.param_groups[0]["lr"], 259 | "epoch": epoch, 260 | }, 261 | step=self.global_step, 262 | ) 263 | logger.info( 264 | f"[TRN] Epoch: {epoch} | Global step: {self.global_step} | Train loss: {step_loss.item():.5f} | LR: {self.optimizer.param_groups[0]['lr']:.5f}" 265 | ) 266 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from torch.utils.data import Sampler 9 | 10 | logger = logging.getLogger() 11 | 12 | 13 | def setup_logger(logger, save_dir): 14 | logger.setLevel(logging.INFO) 15 | if logger.hasHandlers(): 16 | logger.handlers.clear() 17 | log_formatter = logging.Formatter( 18 | "[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s" 19 | ) 20 | console = logging.FileHandler(save_dir) 21 | console.setFormatter(log_formatter) 22 | logger.addHandler(console) 23 | 24 | 25 | def fix_seed(seed: int) -> None: 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | np.random.seed(seed) 32 | random.seed(seed) 33 | 34 | 35 | class AverageMeter: 36 | def __init__(self): 37 | self.reset() 38 | 39 | def reset(self): 40 | self.val = 0 41 | self.avg = 0 42 | self.sum = 0 43 | self.count = 0 44 | 45 | def update(self, val: float, n: int = 1): 46 | self.val = val 47 | self.sum += val * n 48 | self.count += n 49 | self.avg = self.sum / self.count 50 | 51 | 52 | class SequentialDistributedSampler(Sampler): 53 | """ 54 | Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end. 55 | Even though we only use this sampler for eval and predict (no training), which means that the model params won't 56 | have to be synced (i.e. 
will not hang for synchronization even if varied number of forward passes), we still add 57 | extra samples to the sampler to make it evenly divisible (like in `DistributedSampler`) to make it easy to `gather` 58 | or `reduce` resulting tensors at the end of the loop. 59 | Ref: https://github.com/huggingface/transformers/blob/6bef764506b2b53732ee5315b80f89e6c007b584/src/transformers/trainer_pt_utils.py#L185 60 | """ 61 | 62 | def __init__(self, dataset, num_replicas=None, rank=None): 63 | if num_replicas is None: 64 | if not dist.is_available(): 65 | raise RuntimeError("Requires distributed package to be available") 66 | num_replicas = dist.get_world_size() 67 | if rank is None: 68 | if not dist.is_available(): 69 | raise RuntimeError("Requires distributed package to be available") 70 | rank = dist.get_rank() 71 | self.dataset = dataset 72 | self.num_replicas = num_replicas 73 | self.rank = rank 74 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 75 | self.total_size = self.num_samples * self.num_replicas 76 | 77 | def __iter__(self): 78 | indices = list(range(len(self.dataset))) 79 | 80 | # add extra samples to make it evenly divisible 81 | indices += indices[: (self.total_size - len(indices))] 82 | assert ( 83 | len(indices) == self.total_size 84 | ), f"Indices length {len(indices)} and total size {self.total_size} mismatched" 85 | 86 | # subsample 87 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] 88 | assert ( 89 | len(indices) == self.num_samples 90 | ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched" 91 | 92 | return iter(indices) 93 | 94 | def __len__(self): 95 | return self.num_samples 96 | 97 | 98 | class EvalManager: 99 | """Evaluation manager including dev loss comparison and early stopping""" 100 | 101 | def __init__(self, patience=7, delta=0, trace_func=logger.info, activate_early_stop=True): 102 | """ 103 | Args: 104 | patience (int): How long to wait after last time validation loss improved 105 | delta (float): Minimum change in the monitored quantity to qualify as an improvement 106 | trace_func (function): trace print function 107 | activate_early_stop (bool): Whether to activate early stopping while training 108 | """ 109 | 110 | self.global_dev_loss = float("inf") 111 | self.delta = delta 112 | self.trace_func = trace_func 113 | self.activate_early_stop = activate_early_stop 114 | 115 | # early stopping options 116 | self.early_stop = False 117 | if self.activate_early_stop: 118 | self.patience = patience 119 | self.counter = 0 120 | 121 | def __call__(self, dev_loss: float, global_step: int, main_proccess: bool) -> bool: 122 | is_best = False 123 | if main_proccess: 124 | self.trace_func(f"[DEV] global step: {global_step} | dev loss: {dev_loss:.5f}") 125 | 126 | if dev_loss < self.global_dev_loss - self.delta: 127 | self.trace_func( 128 | f"Global dev loss decreased ({self.global_dev_loss:.5f} → {dev_loss:.5f})" 129 | ) 130 | is_best = True 131 | self.global_dev_loss = dev_loss 132 | if self.activate_early_stop: 133 | self.counter = 0 134 | 135 | else: 136 | if self.activate_early_stop: 137 | if main_proccess: 138 | self.trace_func( 139 | f"EarlyStopping counter: {self.counter} (Patience of every process: {self.patience})" 140 | ) 141 | if self.counter >= self.patience: 142 | self.early_stop = True 143 | self.counter += 1 144 | 145 | return is_best 146 | --------------------------------------------------------------------------------
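A usage note on SequentialDistributedSampler above: it pads the eval dataset to a multiple of the world size so that every rank produces an equally sized chunk of predictions, which can then be gathered and truncated back to the true dataset length. The Hugging Face utility file cited in its docstring pairs it with a distributed_concat helper for exactly this purpose; below is a minimal sketch of that idea (not part of this repository, names are illustrative):

    import torch
    import torch.distributed as dist

    def distributed_concat(tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        # every rank holds an equally sized chunk thanks to SequentialDistributedSampler
        output = [tensor.clone() for _ in range(dist.get_world_size())]
        dist.all_gather(output, tensor)
        concat = torch.cat(output, dim=0)
        # drop the wrapped-around padding samples that were appended only to make the split even
        return concat[:num_total_examples]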
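Note on src/sweep.yaml: it only defines the hyperparameter search; nothing in scripts/ launches it. A minimal launch sketch with the standard WandB CLI (assuming you are logged in, and replacing ENTITY/PROJECT/SWEEP_ID with your own values):

    $ wandb sweep src/sweep.yaml           # registers the sweep and prints its sweep ID
    $ wandb agent ENTITY/PROJECT/SWEEP_ID  # runs src/main.py with parameters sampled from the grid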