├── examples ├── imagenet │ ├── configs │ │ ├── __init__.py │ │ ├── resnet50_8x_cfg.py │ │ ├── mobilenet_v3_large_8x_cfg.py │ │ ├── resnet50_8x_mlu_cfg.py │ │ ├── resnet50_clip_grad_8x_cfg.py │ │ └── resnet50_16x_cfg.py │ ├── .gitignore │ ├── requirements.txt │ ├── validate.py │ ├── README.md │ └── imagenet_runner.py ├── mnist │ ├── configs │ │ ├── __init__.py │ │ ├── mnist_cpu_cfg.py │ │ ├── mnist_1x_cfg.py │ │ └── mnist_lr_cpu_cfg.py │ ├── .gitignore │ ├── requirements.txt │ ├── README.md │ ├── conv_net.py │ ├── validate.py │ └── mnist_runner.py └── linear_regression │ ├── .gitignore │ ├── README.md │ ├── dataset.py │ ├── linear_regression_cpu_cfg.py │ └── linear_regression_runner.py ├── tests ├── random_test │ ├── .gitignore │ └── random_test.py └── registry_test │ ├── losses │ ├── registry.py │ ├── loss_a.py │ └── __init__.py │ ├── models │ ├── registry.py │ ├── model_a.py │ └── __init__.py │ └── main.py ├── requirements.txt ├── easytorch ├── version.py ├── easyoptim │ ├── __init__.py │ ├── easy_lr_scheduler.py │ └── lamb.py ├── entry_points │ ├── __init__.py │ └── easytrain.py ├── core │ ├── __init__.py │ ├── data_loader.py │ ├── optimizer_builder.py │ ├── checkpoint.py │ ├── meter_pool.py │ └── runner.py ├── launcher │ ├── __init__.py │ ├── launcher.py │ └── dist_wrap.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── named_hook.py │ ├── logging.py │ ├── misc.py │ ├── dist.py │ ├── timer.py │ ├── registry.py │ ├── env.py │ └── data_prefetcher.py ├── config │ ├── __init__.py │ ├── config.py │ └── utils.py └── device.py ├── .gitignore ├── docker └── Dockerfile ├── .vscode └── settings.json ├── .github └── workflows │ ├── pylint.yml │ ├── publish-pip.yml │ └── build-docker.yml ├── setup.py ├── README_CN.md ├── README.md ├── LICENSE └── .pylintrc /examples/imagenet/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/mnist/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/random_test/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | -------------------------------------------------------------------------------- /examples/linear_regression/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | -------------------------------------------------------------------------------- /examples/mnist/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | mnist_data 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4 2 | tensorboard 3 | tqdm 4 | -------------------------------------------------------------------------------- /examples/imagenet/.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | datasets/* 3 | -------------------------------------------------------------------------------- /easytorch/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.3.3' 2 | __all__ = ['__version__'] 3 | -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .vscode 4 | 5 | /build 6 | /dist 7 | *.egg-info 8 | -------------------------------------------------------------------------------- /easytorch/easyoptim/__init__.py: -------------------------------------------------------------------------------- 1 | from .lamb import Lamb 2 | from . import easy_lr_scheduler 3 | -------------------------------------------------------------------------------- /easytorch/entry_points/__init__.py: -------------------------------------------------------------------------------- 1 | from .easytrain import easytrain 2 | 3 | __all__ = ['easytrain'] 4 | -------------------------------------------------------------------------------- /examples/imagenet/requirements.txt: -------------------------------------------------------------------------------- 1 | easy-torch>=1.3 2 | torch>=1.4 3 | torchvision>=0.5.0 4 | tensorboard 5 | tqdm 6 | -------------------------------------------------------------------------------- /examples/mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | easy-torch>=1.3 2 | torch>=1.4 3 | torchvision>=0.5.0 4 | tensorboard 5 | tqdm 6 | -------------------------------------------------------------------------------- /tests/registry_test/losses/registry.py: -------------------------------------------------------------------------------- 1 | from easytorch.utils.registry import Registry 2 | 3 | LOSS_REGISTRY = Registry('Loss') 4 | -------------------------------------------------------------------------------- /tests/registry_test/models/registry.py: -------------------------------------------------------------------------------- 1 | from easytorch.utils.registry import Registry 2 | 3 | MODEL_REGISTRY = Registry('Model') 4 | -------------------------------------------------------------------------------- /easytorch/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .runner import Runner 2 | from .meter_pool import AvgMeter, MeterPool 3 | 4 | 5 | __all__ = [ 6 | 'Runner', 'AvgMeter', 'MeterPool' 7 | ] 8 | -------------------------------------------------------------------------------- /easytorch/launcher/__init__.py: -------------------------------------------------------------------------------- 1 | from .launcher import launch_runner, launch_training 2 | from .dist_wrap import dist_wrap 3 | 4 | __all__ = ['launch_runner', 'launch_training', 'dist_wrap'] 5 | -------------------------------------------------------------------------------- /examples/linear_regression/README.md: -------------------------------------------------------------------------------- 1 | # EasyTorch Example - Linear Regression 2 | 3 | ## Train 4 | 5 | * CPU 6 | 7 | ```shell 8 | easytrain -c linear_regression_cpu_cfg.py 9 | ``` 10 | -------------------------------------------------------------------------------- /tests/registry_test/models/model_a.py: -------------------------------------------------------------------------------- 1 | from .registry import MODEL_REGISTRY 2 | 3 | 4 | @MODEL_REGISTRY.register() 5 | class ModelA: 6 | def __init__(self, param_1, param_2) -> None: 7 | print('Init ModelA, param_1: {}, param_2: {}'.format(param_1, param_2)) 8 | -------------------------------------------------------------------------------- /tests/registry_test/losses/loss_a.py:
-------------------------------------------------------------------------------- 1 | from .registry import LOSS_REGISTRY 2 | 3 | 4 | @LOSS_REGISTRY.register(name='A_LOSS') 5 | class ALoss: 6 | def __init__(self, param_1, param_2) -> None: 7 | print('Init ALoss, param_1: {}, param_2: {}'.format(param_1, param_2)) 8 | -------------------------------------------------------------------------------- /tests/registry_test/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from easytorch.utils.registry import scan_modules 4 | 5 | from .registry import MODEL_REGISTRY 6 | 7 | __all__ = ['MODEL_REGISTRY'] 8 | 9 | scan_modules(os.getcwd(), __file__, ['__init__.py', 'builder.py']) 10 | -------------------------------------------------------------------------------- /tests/registry_test/main.py: -------------------------------------------------------------------------------- 1 | from losses import LOSS_REGISTRY 2 | from models import MODEL_REGISTRY 3 | 4 | model_a = MODEL_REGISTRY.build('ModelA', {'param_1': 'mp1', 'param_2': 'mp2'}) 5 | a_loss = LOSS_REGISTRY.build('A_LOSS', {'param_1': 'lp1', 'param_2': 'lp2'}) 6 | l1_loss = LOSS_REGISTRY.build('L1_LOSS') 7 | print(l1_loss) 8 | -------------------------------------------------------------------------------- /easytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config, import_config 2 | from .core import Runner, AvgMeter, MeterPool 3 | from .launcher import launch_runner, launch_training 4 | from .version import __version__ 5 | 6 | __all__ = [ 7 | 'Config', 'import_config', 'Runner', 'AvgMeter', 'MeterPool', 'launch_runner', 'launch_training', '__version__' 8 | ] 9 | -------------------------------------------------------------------------------- /tests/registry_test/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch import nn 4 | from easytorch.utils.registry import scan_modules 5 | 6 | from .registry import LOSS_REGISTRY 7 | 8 | __all__ = ['LOSS_REGISTRY'] 9 | 10 | scan_modules(os.getcwd(), __file__, ['__init__.py', 'builder.py']) 11 | 12 | LOSS_REGISTRY.register(nn.L1Loss, 'L1_LOSS') 13 | LOSS_REGISTRY.register(nn.MSELoss, 'L2_LOSS') 14 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG IMAGE_TAG 2 | 3 | FROM cnstark/pytorch:${IMAGE_TAG} 4 | 5 | ARG EASYTORCH_VERSION 6 | ENV EASYTORCH_VERSION ${EASYTORCH_VERSION} 7 | 8 | ADD . 
/tmp/easytorch 9 | 10 | RUN cd /tmp/easytorch && \ 11 | pip install pip --upgrade && \ 12 | pip install -r requirements.txt && \ 13 | rm -rf .eggs && \ 14 | python setup.py install && \ 15 | rm -rf /tmp/easytorch 16 | -------------------------------------------------------------------------------- /examples/mnist/README.md: -------------------------------------------------------------------------------- 1 | # EasyTorch Example - MNIST Classification 2 | 3 | ## Train 4 | 5 | * CPU 6 | 7 | ```shell 8 | easytrain -c configs/mnist_cpu_cfg.py 9 | ``` 10 | 11 | * GPU (1x) 12 | 13 | ```shell 14 | easytrain -c configs/mnist_1x_cfg.py --devices 0 15 | ``` 16 | 17 | ## Validate 18 | 19 | * CPU 20 | 21 | ```shell 22 | python validate.py -c configs/mnist_cpu_cfg.py 23 | ``` 24 | 25 | * GPU (1x) 26 | 27 | ```shell 28 | python validate.py -c configs/mnist_1x_cfg.py --devices 0 29 | ``` 30 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.trimTrailingWhitespace": true, 3 | "editor.rulers": [ 4 | 80, 5 | 120 6 | ], 7 | "editor.renderWhitespace": "all", 8 | "editor.renderControlCharacters": true, 9 | "python.formatting.provider": "yapf", 10 | "python.formatting.yapfArgs": [ 11 | "--style", 12 | "{based_on_style: google, indent_width: 4, column_limit: 120}" 13 | ], 14 | "python.linting.enabled": true, 15 | "python.linting.pylintEnabled": true, 16 | } 17 | -------------------------------------------------------------------------------- /examples/linear_regression/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class LinearDataset(Dataset): 6 | """LinearDataset 7 | """ 8 | 9 | def __init__(self, k: float, b: float, num: int): 10 | self.num = num 11 | self.x = torch.unsqueeze(torch.linspace(-1, 1, self.num), dim=1) 12 | self.y = k * self.x + b + torch.rand(self.x.size()) - 0.5 13 | 14 | def __getitem__(self, index): 15 | return self.x[index], self.y[index] 16 | 17 | def __len__(self): 18 | return self.num 19 | -------------------------------------------------------------------------------- /easytorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .env import set_visible_devices, set_tf32_mode, setup_determinacy, set_env 2 | from .dist import get_rank, get_local_rank, get_world_size, is_rank, is_master, master_only 3 | from .logging import get_logger 4 | from .named_hook import NamedForwardHook, NamedBackwardHook 5 | from .timer import Timer, TimePredictor 6 | 7 | __all__ = [ 8 | 'set_visible_devices', 'set_tf32_mode', 'setup_determinacy', 'set_env', 'get_rank', 'get_local_rank', 'get_world_size', 'is_rank', 9 | 'is_master', 'master_only', 'get_logger', 'NamedForwardHook', 'NamedBackwardHook', 'Timer', 'TimePredictor' 10 | ] 11 | -------------------------------------------------------------------------------- /easytorch/entry_points/easytrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from argparse import ArgumentParser 4 | 5 | from ..launcher import launch_training 6 | 7 | 8 | def parse_args(): 9 | parser = ArgumentParser(description='Welcome to EasyTorch!') 10 | parser.add_argument('-c', '--cfg', help='training config', required=True) 11 | parser.add_argument('--node-rank', default=0, type=int, help='node rank 
for distributed training') 12 | parser.add_argument('--devices', help='visible devices', type=str) 13 | return parser.parse_args() 14 | 15 | 16 | def easytrain(): 17 | # work dir 18 | path = os.getcwd() 19 | sys.path.append(path) 20 | 21 | # parse arguments 22 | args = parse_args() 23 | 24 | # train 25 | launch_training(args.cfg, args.devices, args.node_rank) 26 | -------------------------------------------------------------------------------- /easytorch/utils/named_hook.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class NamedHook(metaclass=ABCMeta): 9 | def __init__(self, name: str): 10 | self.name = name 11 | 12 | @abstractmethod 13 | def __call__(self, module: nn.Module, *args, **kwargs): 14 | pass 15 | 16 | 17 | class NamedForwardHook(NamedHook, metaclass=ABCMeta): 18 | @abstractmethod 19 | def __call__(self, module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]): 20 | pass 21 | 22 | 23 | class NamedBackwardHook(NamedHook, metaclass=ABCMeta): 24 | @abstractmethod 25 | def __call__(self, module: nn.Module, input_grads: Tuple[torch.Tensor], output_grads: Tuple[torch.Tensor]): 26 | pass 27 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: PyLint 2 | 3 | env: 4 | FAIL_UNDER: "9.9" 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [3.9] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install pylint 27 | pip install torch==1.9.1+cpu torchvision==0.10.1+cpu torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 28 | pip install -r requirements.txt 29 | - name: Lint 30 | run: | 31 | pylint --fail-under=${FAIL_UNDER} easytorch examples tests setup.py 32 | -------------------------------------------------------------------------------- /examples/linear_regression/linear_regression_cpu_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from linear_regression_runner import LinearRegressionRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'linear_regression' 9 | CFG.RUNNER = LinearRegressionRunner 10 | CFG.DEVICE = 'cpu' 11 | 12 | CFG.MODEL = Config() 13 | CFG.MODEL.NAME = 'linear' 14 | 15 | CFG.TRAIN = Config() 16 | 17 | CFG.TRAIN.NUM_EPOCHS = 10000 18 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 19 | 'checkpoints', 20 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 21 | ) 22 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 23 | 24 | CFG.TRAIN.OPTIM = Config() 25 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 26 | CFG.TRAIN.OPTIM.PARAM = { 27 | 'lr': 0.001, 28 | 'momentum': 0.1, 29 | } 30 | 31 | CFG.TRAIN.DATA = Config() 32 | CFG.TRAIN.DATA.BATCH_SIZE = 10 33 | CFG.TRAIN.DATA.K = 10 34 | CFG.TRAIN.DATA.B = 6 35 | CFG.TRAIN.DATA.NUM = 100 36 | CFG.TRAIN.DATA.SHUFFLE = True 37 | -------------------------------------------------------------------------------- /examples/mnist/conv_net.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class ConvNet(nn.Module): 5 | """Simple ConvNet for MNIST classification. 6 | """ 7 | 8 | def __init__(self): 9 | super().__init__() 10 | self.conv_block = nn.Sequential( 11 | nn.Conv2d(1, 10, kernel_size=5), 12 | nn.MaxPool2d(2), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(10, 20, kernel_size=5), 15 | nn.Dropout2d(), 16 | nn.MaxPool2d(2), 17 | nn.ReLU(inplace=True), 18 | ) 19 | 20 | self.fc_block = nn.Sequential( 21 | nn.Linear(320, 50), 22 | nn.ReLU(inplace=True), 23 | nn.Dropout(), 24 | nn.Linear(50, 10), 25 | nn.LogSoftmax(dim=1) 26 | ) 27 | 28 | def forward(self, x): 29 | y = self.conv_block(x) 30 | y = y.view(-1, 320) 31 | y = self.fc_block(y) 32 | 33 | return y 34 | -------------------------------------------------------------------------------- /examples/mnist/configs/mnist_cpu_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from mnist_runner import MNISTRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'mnist' 9 | CFG.RUNNER = MNISTRunner 10 | CFG.DEVICE = 'cpu' 11 | 12 | CFG.MODEL = Config() 13 | CFG.MODEL.NAME = 'conv_net' 14 | 15 | CFG.TRAIN = Config() 16 | 17 | CFG.TRAIN.NUM_EPOCHS = 30 18 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 19 | 'checkpoints', 20 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 21 | ) 22 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 23 | 24 | CFG.TRAIN.OPTIM = Config() 25 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 26 | CFG.TRAIN.OPTIM.PARAM = { 27 | 'lr': 0.002, 28 | 'momentum': 0.1, 29 | } 30 | 31 | CFG.TRAIN.DATA = Config() 32 | CFG.TRAIN.DATA.BATCH_SIZE = 4 33 | CFG.TRAIN.DATA.DIR = 'mnist_data' 34 | CFG.TRAIN.DATA.SHUFFLE = True 35 | 36 | CFG.VAL = Config() 37 | 38 | CFG.VAL.INTERVAL = 1 39 | 40 | CFG.VAL.DATA = Config() 41 | CFG.VAL.DATA.DIR = 'mnist_data' 42 | -------------------------------------------------------------------------------- /examples/mnist/configs/mnist_1x_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from mnist_runner import MNISTRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'mnist' 9 | CFG.RUNNER = MNISTRunner 10 | CFG.DEVICE = 'gpu' 11 | CFG.DEVICE_NUM = 1 12 | 13 | CFG.MODEL = Config() 14 | CFG.MODEL.NAME = 'conv_net' 15 | 16 | CFG.TRAIN = Config() 17 | 18 | CFG.TRAIN.NUM_EPOCHS = 30 19 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 20 | 'checkpoints', 21 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 22 | ) 23 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 24 | 25 | CFG.TRAIN.OPTIM = Config() 26 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 27 | CFG.TRAIN.OPTIM.PARAM = { 28 | 'lr': 0.002, 29 | 'momentum': 0.1, 30 | } 31 | 32 | CFG.TRAIN.DATA = Config() 33 | CFG.TRAIN.DATA.BATCH_SIZE = 4 34 | CFG.TRAIN.DATA.DIR = 'mnist_data' 35 | CFG.TRAIN.DATA.SHUFFLE = True 36 | 37 | CFG.VAL = Config() 38 | 39 | CFG.VAL.INTERVAL = 1 40 | 41 | CFG.VAL.DATA = Config() 42 | CFG.VAL.DATA.DIR = 'mnist_data' 43 | -------------------------------------------------------------------------------- /.github/workflows/publish-pip.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Publish 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | runs-on: ubuntu-latest 8 | if: startsWith(github.event.ref, 'refs/tags') 9 | 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.8 13 | uses: actions/setup-python@v1 
14 | with: 15 | python-version: 3.8 16 | - name: Upgrade pip 17 | run: pip install pip --upgrade 18 | - name: Install PyTorch (cpu) 19 | run: pip install torch==1.9.1+cpu torchvision==0.10.1+cpu torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 20 | - name: Install dependencies 21 | run: pip install -r requirements.txt 22 | - name: Clean build artifacts 23 | run: rm -rf .eggs 24 | - name: Build for distribution 25 | run: python setup.py bdist_wheel 26 | - name: Publish distribution to PyPI 27 | uses: pypa/gh-action-pypi-publish@master 28 | with: 29 | password: ${{ secrets.PYPI_EASY_TORCH }} 30 | -------------------------------------------------------------------------------- /examples/mnist/validate.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from easytorch import launch_runner, Runner 4 | 5 | 6 | def parse_args(): 7 | parser = ArgumentParser(description='Welcome to EasyTorch!') 8 | parser.add_argument('-c', '--cfg', help='training config', required=True) 9 | parser.add_argument('--ckpt', help='ckpt path. if it is None, load default ckpt in ckpt save dir', type=str) 10 | parser.add_argument('--device-type', help='device type', type=str, default='gpu') 11 | parser.add_argument('--devices', help='visible devices', type=str) 12 | return parser.parse_args() 13 | 14 | 15 | def main(cfg: dict, runner: Runner, ckpt: str = None): 16 | # init logger 17 | runner.init_logger(logger_name='easytorch-inference', log_file_name='validate_result') 18 | 19 | runner.load_model(ckpt_path=ckpt) 20 | 21 | runner.validate(cfg) 22 | 23 | 24 | if __name__ == '__main__': 25 | args = parse_args() 26 | launch_runner(args.cfg, main, (args.ckpt, ), device_type=args.device_type, devices=args.devices) 27 | -------------------------------------------------------------------------------- /examples/imagenet/validate.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from easytorch import launch_runner, Runner 4 | 5 | 6 | def parse_args(): 7 | parser = ArgumentParser(description='Welcome to EasyTorch!') 8 | parser.add_argument('-c', '--cfg', help='training config', required=True) 9 | parser.add_argument('--ckpt', help='ckpt path. 
if it is None, load default ckpt in ckpt save dir', type=str) 10 | parser.add_argument('--device-type', help='device type', type=str, default='gpu') 11 | parser.add_argument('--devices', help='visible devices', type=str) 12 | return parser.parse_args() 13 | 14 | 15 | def main(cfg: dict, runner: Runner, ckpt: str = None): 16 | # init logger 17 | runner.init_logger(logger_name='easytorch-inference', log_file_name='validate_result') 18 | 19 | runner.load_model(ckpt_path=ckpt) 20 | 21 | runner.validate(cfg) 22 | 23 | 24 | if __name__ == '__main__': 25 | args = parse_args() 26 | launch_runner(args.cfg, main, (args.ckpt, ), device_type=args.device_type, devices=args.devices) 27 | -------------------------------------------------------------------------------- /examples/mnist/configs/mnist_lr_cpu_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from mnist_runner import MNISTRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'mnist, lr scheduler' 9 | CFG.RUNNER = MNISTRunner 10 | CFG.DEVICE = 'cpu' 11 | 12 | CFG.MODEL = Config() 13 | CFG.MODEL.NAME = 'conv_net' 14 | 15 | CFG.TRAIN = Config() 16 | 17 | CFG.TRAIN.NUM_EPOCHS = 30 18 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 19 | 'checkpoints', 20 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 21 | ) 22 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 23 | 24 | CFG.TRAIN.OPTIM = Config() 25 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 26 | CFG.TRAIN.OPTIM.PARAM = { 27 | 'lr': 0.002, 28 | 'momentum': 0.1, 29 | } 30 | 31 | CFG.TRAIN.LR_SCHEDULER = Config() 32 | CFG.TRAIN.LR_SCHEDULER.TYPE = 'CosineAnnealingLR' 33 | CFG.TRAIN.LR_SCHEDULER.PARAM = { 34 | 'T_max': CFG.TRAIN.NUM_EPOCHS, 35 | 'eta_min': 1e-6 36 | } 37 | 38 | CFG.TRAIN.DATA = Config() 39 | CFG.TRAIN.DATA.BATCH_SIZE = 4 40 | CFG.TRAIN.DATA.DIR = 'mnist_data' 41 | CFG.TRAIN.DATA.SHUFFLE = True 42 | 43 | CFG.VAL = Config() 44 | 45 | CFG.VAL.INTERVAL = 1 46 | 47 | CFG.VAL.DATA = Config() 48 | CFG.VAL.DATA.DIR = 'mnist_data' 49 | -------------------------------------------------------------------------------- /examples/imagenet/README.md: -------------------------------------------------------------------------------- 1 | # EasyTorch Example - ImageNet Classification 2 | 3 | Adapted from [https://github.com/pytorch/examples/tree/main/imagenet](https://github.com/pytorch/examples/tree/main/imagenet) 4 | 5 | ## Train 6 | 7 | ### Resnet50 (One node) 8 | 9 | ```shell 10 | easytrain -c configs/resnet50_8x_cfg.py 11 | ``` 12 | 13 | ### MobileNet V3 Large (One node) 14 | 15 | ```shell 16 | easytrain -c configs/mobilenet_v3_large_8x_cfg.py 17 | ``` 18 | 19 | ### Resnet50 (Multi node) 20 | 21 | Modify `CFG.DIST_INIT_METHOD='tcp://{ip_of_node_0}:{free_port}'` in `configs/resnet50_16x_cfg.py`. 22 | 23 | e.g. 24 | 25 | ```python 26 | CFG.DIST_INIT_METHOD='tcp://192.168.1.2:55555' 27 | ``` 28 | 29 | * Node 0: 30 | 31 | ```shell 32 | easytrain -c configs/resnet50_16x_cfg.py 33 | ``` 34 | 35 | * Node 1: 36 | 37 | ```shell 38 | easytrain -c configs/resnet50_16x_cfg.py --node-rank 1 39 | ``` 40 | 41 | ### Other models 42 | 43 | To train other models or modify hyperparameters, customize the config yourself. 
44 | 45 | ## Validate 46 | 47 | ### Resnet50 48 | 49 | ```shell 50 | # last 51 | python validate.py -c configs/resnet50_8x_cfg.py --devices 0 52 | 53 | # best 54 | python validate.py -c configs/resnet50_8x_cfg.py --devices 0 --ckpt /path/to/ckpt_dir/resnet50_best_val_acc@1.pt 55 | ``` 56 | 57 | ### MobileNet V3 Large 58 | 59 | ```shell 60 | python validate.py -c configs/mobilenet_v3_large_8x_cfg.py --devices 0 61 | ``` 62 | -------------------------------------------------------------------------------- /easytorch/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .dist import is_master 4 | 5 | logger_initialized = set([]) 6 | 7 | 8 | def get_logger(name: str, log_file: str = None, log_level: int = logging.INFO, file_mode: str = 'w') -> logging.Logger: 9 | """Return a logger with the specified name, creating it if necessary. 10 | 11 | Notes: 12 | If current process is master process, return `Logger(log_level)` with FileHandler. 13 | If current process is not master process, return `Logger(logging.ERROR)` 14 | 15 | Args: 16 | name (str): specified name of logger 17 | log_file (str): logger file name 18 | log_level (int): logger level 19 | file_mode (str): logger file mode 20 | 21 | Returns: 22 | logger (logging.Logger) 23 | """ 24 | 25 | logger = logging.getLogger(name) 26 | logger.propagate = False 27 | 28 | if name in logger_initialized: 29 | return logger 30 | 31 | logger_handlers = [logging.StreamHandler()] 32 | 33 | if is_master() and log_file is not None: 34 | logger_handlers.append(logging.FileHandler(log_file, file_mode)) 35 | 36 | formatter = logging.Formatter( 37 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 38 | ) 39 | 40 | for handler in logger_handlers: 41 | handler.setFormatter(formatter) 42 | handler.setLevel(log_level) 43 | logger.addHandler(handler) 44 | 45 | if is_master(): 46 | logger.setLevel(log_level) 47 | else: 48 | logger.setLevel(logging.ERROR) 49 | 50 | logger_initialized.add(name) 51 | 52 | return logger 53 | -------------------------------------------------------------------------------- /.github/workflows/build-docker.yml: -------------------------------------------------------------------------------- 1 | name: Build Docker Image 2 | 3 | on: push 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | if: startsWith(github.event.ref, 'refs/tags') 9 | 10 | strategy: 11 | matrix: 12 | base-image-tag: 13 | - "1.4.0-py3.8.13-cuda10.1-ubuntu18.04" 14 | - "1.5.1-py3.8.13-cuda10.2-ubuntu18.04" 15 | - "1.6.0-py3.8.13-cuda10.2-ubuntu18.04" 16 | - "1.7.1-py3.9.12-cuda11.0-ubuntu18.04" 17 | - "1.8.1-py3.9.12-cuda11.1-ubuntu20.04" 18 | - "1.9.1-py3.9.12-cuda11.1-ubuntu20.04" 19 | - "1.10.2-py3.9.12-cuda11.3.1-ubuntu20.04" 20 | - "1.11.0-py3.9.12-cuda11.3.1-ubuntu20.04" 21 | - "1.12.1-py3.9.12-cuda11.6.2-ubuntu20.04" 22 | 23 | steps: 24 | - uses: actions/checkout@v2 25 | 26 | - name: Get version 27 | id: get_version 28 | run: | 29 | echo "EASYTORCH_VERSION=`cd easytorch && python -c '''from version import __version__; print(__version__)'''`" >> $GITHUB_ENV 30 | echo "IMAGE_TAG=${{ matrix.base-image-tag }}" >> $GITHUB_ENV 31 | 32 | - name: Login DockerHub 33 | run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password=${{ secrets.DOCKER_PASSWORD }} 34 | 35 | - name: Build docker image 36 | run: | 37 | docker build \ 38 | --build-arg IMAGE_TAG=${IMAGE_TAG} \ 39 | -t cnstark/easytorch:${EASYTORCH_VERSION}-${IMAGE_TAG} \ 40 | -f docker/Dockerfile \ 41 | . 
42 | 43 | - name: Push docker image 44 | run: docker push cnstark/easytorch:${EASYTORCH_VERSION}-${IMAGE_TAG} 45 | -------------------------------------------------------------------------------- /examples/imagenet/configs/resnet50_8x_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from imagenet_runner import ImagenetRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'imagenet resnet50' 9 | CFG.RUNNER = ImagenetRunner 10 | CFG.DEVICE = 'gpu' 11 | CFG.DEVICE_NUM = 8 12 | 13 | CFG.MODEL = Config() 14 | CFG.MODEL.NAME = 'resnet50' 15 | 16 | CFG.TRAIN = Config() 17 | 18 | CFG.TRAIN.NUM_EPOCHS = 90 19 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 20 | 'checkpoints', 21 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 22 | ) 23 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 24 | 25 | CFG.TRAIN.OPTIM = Config() 26 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 27 | CFG.TRAIN.OPTIM.PARAM = { 28 | 'lr': 0.1, 29 | 'momentum': 0.9, 30 | 'weight_decay': 1e-4 31 | } 32 | 33 | CFG.TRAIN.LR_SCHEDULER = Config() 34 | CFG.TRAIN.LR_SCHEDULER.TYPE = 'StepLR' 35 | CFG.TRAIN.LR_SCHEDULER.PARAM = { 36 | 'step_size': 30, 37 | 'gamma': 0.1 38 | } 39 | 40 | IMAGENET_PATH = 'datasets/imagenet/jpegs' 41 | 42 | CFG.TRAIN.DATA = Config() 43 | CFG.TRAIN.DATA.BATCH_SIZE = 32 44 | CFG.TRAIN.DATA.NUM_WORKERS = 4 45 | CFG.TRAIN.DATA.SHUFFLE = True 46 | 47 | CFG.TRAIN.DATA.DIR = os.path.join(IMAGENET_PATH, 'train') 48 | CFG.TRAIN.DATA.CROP_SIZE = 224 49 | CFG.TRAIN.DATA.NORMALIZE = { 50 | 'mean': [0.485, 0.456, 0.406], 51 | 'std': [0.229, 0.224, 0.225] 52 | } 53 | 54 | CFG.VAL = Config() 55 | 56 | CFG.VAL.INTERVAL = 1 57 | 58 | CFG.VAL.DATA = Config() 59 | CFG.VAL.DATA.BATCH_SIZE = 32 60 | CFG.VAL.DATA.DIR = os.path.join(IMAGENET_PATH, 'val') 61 | CFG.VAL.DATA.CROP_SIZE = 224 62 | CFG.VAL.DATA.RESIZE = 256 63 | CFG.VAL.DATA.NORMALIZE = { 64 | 'mean': [0.485, 0.456, 0.406], 65 | 'std': [0.229, 0.224, 0.225] 66 | } 67 | -------------------------------------------------------------------------------- /examples/imagenet/configs/mobilenet_v3_large_8x_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from imagenet_runner import ImagenetRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'imagenet mobilenet_v3_large' 9 | CFG.RUNNER = ImagenetRunner 10 | CFG.DEVICE = 'gpu' 11 | CFG.DEVICE_NUM = 8 12 | 13 | CFG.MODEL = Config() 14 | CFG.MODEL.NAME = 'mobilenet_v3_large' 15 | 16 | CFG.TRAIN = Config() 17 | 18 | CFG.TRAIN.NUM_EPOCHS = 90 19 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 20 | 'checkpoints', 21 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 22 | ) 23 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 24 | 25 | CFG.TRAIN.OPTIM = Config() 26 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 27 | CFG.TRAIN.OPTIM.PARAM = { 28 | 'lr': 0.1, 29 | 'momentum': 0.9, 30 | 'weight_decay': 1e-4 31 | } 32 | 33 | CFG.TRAIN.LR_SCHEDULER = Config() 34 | CFG.TRAIN.LR_SCHEDULER.TYPE = 'StepLR' 35 | CFG.TRAIN.LR_SCHEDULER.PARAM = { 36 | 'step_size': 30, 37 | 'gamma': 0.1 38 | } 39 | 40 | IMAGENET_PATH = 'datasets/imagenet/jpegs' 41 | 42 | CFG.TRAIN.DATA = Config() 43 | CFG.TRAIN.DATA.BATCH_SIZE = 32 44 | CFG.TRAIN.DATA.NUM_WORKERS = 4 45 | CFG.TRAIN.DATA.SHUFFLE = True 46 | 47 | CFG.TRAIN.DATA.DIR = os.path.join(IMAGENET_PATH, 'train') 48 | CFG.TRAIN.DATA.CROP_SIZE = 224 49 | CFG.TRAIN.DATA.NORMALIZE = { 50 | 'mean': [0.485, 0.456, 0.406], 51 | 'std': [0.229, 0.224, 0.225] 52 | } 53 | 54 | CFG.VAL = 
Config() 55 | 56 | CFG.VAL.INTERVAL = 1 57 | 58 | CFG.VAL.DATA = Config() 59 | CFG.VAL.DATA.BATCH_SIZE = 32 60 | CFG.VAL.DATA.DIR = os.path.join(IMAGENET_PATH, 'val') 61 | CFG.VAL.DATA.CROP_SIZE = 224 62 | CFG.VAL.DATA.RESIZE = 256 63 | CFG.VAL.DATA.NORMALIZE = { 64 | 'mean': [0.485, 0.456, 0.406], 65 | 'std': [0.229, 0.224, 0.225] 66 | } 67 | -------------------------------------------------------------------------------- /examples/imagenet/configs/resnet50_8x_mlu_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from imagenet_runner import ImagenetRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'imagenet resnet50' 9 | CFG.RUNNER = ImagenetRunner 10 | CFG.DEVICE = 'mlu' 11 | CFG.DEVICE_NUM = 8 12 | CFG.DIST_BACKEND = 'cncl' 13 | 14 | CFG.MODEL = Config() 15 | CFG.MODEL.NAME = 'resnet50' 16 | 17 | CFG.TRAIN = Config() 18 | 19 | CFG.TRAIN.NUM_EPOCHS = 90 20 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 21 | 'checkpoints', 22 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 23 | ) 24 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 25 | 26 | CFG.TRAIN.OPTIM = Config() 27 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 28 | CFG.TRAIN.OPTIM.PARAM = { 29 | 'lr': 0.1, 30 | 'momentum': 0.9, 31 | 'weight_decay': 1e-4 32 | } 33 | 34 | CFG.TRAIN.LR_SCHEDULER = Config() 35 | CFG.TRAIN.LR_SCHEDULER.TYPE = 'StepLR' 36 | CFG.TRAIN.LR_SCHEDULER.PARAM = { 37 | 'step_size': 30, 38 | 'gamma': 0.1 39 | } 40 | 41 | IMAGENET_PATH = 'datasets/imagenet/jpegs' 42 | 43 | CFG.TRAIN.DATA = Config() 44 | CFG.TRAIN.DATA.BATCH_SIZE = 32 45 | CFG.TRAIN.DATA.NUM_WORKERS = 4 46 | CFG.TRAIN.DATA.SHUFFLE = True 47 | 48 | CFG.TRAIN.DATA.DIR = os.path.join(IMAGENET_PATH, 'train') 49 | CFG.TRAIN.DATA.CROP_SIZE = 224 50 | CFG.TRAIN.DATA.NORMALIZE = { 51 | 'mean': [0.485, 0.456, 0.406], 52 | 'std': [0.229, 0.224, 0.225] 53 | } 54 | 55 | CFG.VAL = Config() 56 | 57 | CFG.VAL.INTERVAL = 1 58 | 59 | CFG.VAL.DATA = Config() 60 | CFG.VAL.DATA.BATCH_SIZE = 32 61 | CFG.VAL.DATA.DIR = os.path.join(IMAGENET_PATH, 'val') 62 | CFG.VAL.DATA.CROP_SIZE = 224 63 | CFG.VAL.DATA.RESIZE = 256 64 | CFG.VAL.DATA.NORMALIZE = { 65 | 'mean': [0.485, 0.456, 0.406], 66 | 'std': [0.229, 0.224, 0.225] 67 | } 68 | -------------------------------------------------------------------------------- /easytorch/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple, Union 3 | 4 | 5 | def scan_dir(dir_path: str, suffix: Union[str, Tuple[str]] = None, recursive: bool = False, full_path: bool = False): 6 | """Scan a directory to find the interested files. 7 | Args: 8 | dir_path (str): Path of the directory. 9 | suffix (str | tuple(str), optional): File suffix that we are 10 | interested in. Default: None. 11 | recursive (bool, optional): If set to True, recursively scan the 12 | directory. Default: False. 13 | full_path (bool, optional): If set to True, include the dir_path. 14 | Default: False. 15 | Returns: 16 | A generator for all the interested files with relative paths. 
17 | """ 18 | 19 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 20 | raise TypeError('"suffix" must be a string or tuple of strings') 21 | 22 | root = dir_path 23 | 24 | def _scan_dir(dir_path, suffix, recursive): 25 | for entry in os.scandir(dir_path): 26 | if not entry.name.startswith('.') and entry.is_file(): 27 | if full_path: 28 | return_path = entry.path 29 | else: 30 | return_path = os.path.relpath(entry.path, root) 31 | 32 | if suffix is None: 33 | yield return_path 34 | elif return_path.endswith(suffix): 35 | yield return_path 36 | else: 37 | if recursive: 38 | yield from _scan_dir(entry.path, suffix=suffix, recursive=recursive) 39 | else: 40 | continue 41 | 42 | return _scan_dir(dir_path, suffix=suffix, recursive=recursive) 43 | -------------------------------------------------------------------------------- /examples/imagenet/configs/resnet50_clip_grad_8x_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from imagenet_runner import ImagenetRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'imagenet resnet50' 9 | CFG.RUNNER = ImagenetRunner 10 | CFG.DEVICE = 'gpu' 11 | CFG.DEVICE_NUM = 8 12 | 13 | CFG.MODEL = Config() 14 | CFG.MODEL.NAME = 'resnet50' 15 | 16 | CFG.TRAIN = Config() 17 | 18 | CFG.TRAIN.NUM_EPOCHS = 90 19 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 20 | 'checkpoints', 21 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 22 | ) 23 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 24 | 25 | CFG.TRAIN.OPTIM = Config() 26 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 27 | CFG.TRAIN.OPTIM.PARAM = { 28 | 'lr': 0.1, 29 | 'momentum': 0.9, 30 | 'weight_decay': 1e-4 31 | } 32 | 33 | CFG.TRAIN.LR_SCHEDULER = Config() 34 | CFG.TRAIN.LR_SCHEDULER.TYPE = 'StepLR' 35 | CFG.TRAIN.LR_SCHEDULER.PARAM = { 36 | 'step_size': 30, 37 | 'gamma': 0.1 38 | } 39 | 40 | CFG.TRAIN.CLIP_GRAD_PARAM = { 41 | 'max_norm': 1.0 42 | } 43 | 44 | IMAGENET_PATH = 'datasets/imagenet/jpegs' 45 | 46 | CFG.TRAIN.DATA = Config() 47 | CFG.TRAIN.DATA.BATCH_SIZE = 32 48 | CFG.TRAIN.DATA.NUM_WORKERS = 4 49 | CFG.TRAIN.DATA.SHUFFLE = True 50 | 51 | CFG.TRAIN.DATA.DIR = os.path.join(IMAGENET_PATH, 'train') 52 | CFG.TRAIN.DATA.CROP_SIZE = 224 53 | CFG.TRAIN.DATA.NORMALIZE = { 54 | 'mean': [0.485, 0.456, 0.406], 55 | 'std': [0.229, 0.224, 0.225] 56 | } 57 | 58 | CFG.VAL = Config() 59 | 60 | CFG.VAL.INTERVAL = 1 61 | 62 | CFG.VAL.DATA = Config() 63 | CFG.VAL.DATA.BATCH_SIZE = 32 64 | CFG.VAL.DATA.DIR = os.path.join(IMAGENET_PATH, 'val') 65 | CFG.VAL.DATA.CROP_SIZE = 224 66 | CFG.VAL.DATA.RESIZE = 256 67 | CFG.VAL.DATA.NORMALIZE = { 68 | 'mean': [0.485, 0.456, 0.406], 69 | 'std': [0.229, 0.224, 0.225] 70 | } 71 | -------------------------------------------------------------------------------- /examples/imagenet/configs/resnet50_16x_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easytorch import Config 3 | 4 | from imagenet_runner import ImagenetRunner 5 | 6 | CFG = Config() 7 | 8 | CFG.DESC = 'imagenet resnet50' 9 | CFG.RUNNER = ImagenetRunner 10 | CFG.DEVICE = 'gpu' 11 | CFG.DEVICE_NUM = 8 12 | CFG.DIST_NODE_NUM = 2 13 | CFG.DIST_BACKEND = 'nccl' 14 | CFG.DIST_INIT_METHOD='tcp://{ip_of_node_0}:{free_port}' 15 | 16 | CFG.MODEL = Config() 17 | CFG.MODEL.NAME = 'resnet50' 18 | 19 | CFG.TRAIN = Config() 20 | 21 | CFG.TRAIN.NUM_EPOCHS = 90 22 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 23 | 'checkpoints', 24 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 25 | ) 
26 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 27 | 28 | CFG.TRAIN.OPTIM = Config() 29 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 30 | CFG.TRAIN.OPTIM.PARAM = { 31 | 'lr': 0.1, 32 | 'momentum': 0.9, 33 | 'weight_decay': 1e-4 34 | } 35 | 36 | CFG.TRAIN.LR_SCHEDULER = Config() 37 | CFG.TRAIN.LR_SCHEDULER.TYPE = 'StepLR' 38 | CFG.TRAIN.LR_SCHEDULER.PARAM = { 39 | 'step_size': 30, 40 | 'gamma': 0.1 41 | } 42 | 43 | IMAGENET_PATH = 'datasets/imagenet/jpegs' 44 | 45 | CFG.TRAIN.DATA = Config() 46 | CFG.TRAIN.DATA.BATCH_SIZE = 16 47 | CFG.TRAIN.DATA.NUM_WORKERS = 4 48 | CFG.TRAIN.DATA.SHUFFLE = True 49 | 50 | CFG.TRAIN.DATA.DIR = os.path.join(IMAGENET_PATH, 'train') 51 | CFG.TRAIN.DATA.CROP_SIZE = 224 52 | CFG.TRAIN.DATA.NORMALIZE = { 53 | 'mean': [0.485, 0.456, 0.406], 54 | 'std': [0.229, 0.224, 0.225] 55 | } 56 | 57 | CFG.VAL = Config() 58 | 59 | CFG.VAL.INTERVAL = 1 60 | 61 | CFG.VAL.DATA = Config() 62 | CFG.VAL.DATA.BATCH_SIZE = 16 63 | CFG.VAL.DATA.DIR = os.path.join(IMAGENET_PATH, 'val') 64 | CFG.VAL.DATA.CROP_SIZE = 224 65 | CFG.VAL.DATA.RESIZE = 256 66 | CFG.VAL.DATA.NORMALIZE = { 67 | 'mean': [0.485, 0.456, 0.406], 68 | 'std': [0.229, 0.224, 0.225] 69 | } 70 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages, setup 3 | 4 | 5 | def readme(): 6 | with open('README.md', encoding='utf-8') as f: 7 | content = f.read() 8 | return content 9 | 10 | 11 | def get_version(): 12 | version_file = 'easytorch/version.py' 13 | with open(version_file, 'r', encoding='utf-8') as f: 14 | exec(compile(f.read(), version_file, 'exec')) 15 | return locals()['__version__'] 16 | 17 | 18 | def get_requirements(filename='requirements.txt'): 19 | here = os.path.dirname(os.path.realpath(__file__)) 20 | with open(os.path.join(here, filename), 'r') as f: 21 | requires = [line.replace('\n', '') for line in f.readlines()] 22 | return requires 23 | 24 | 25 | if __name__ == '__main__': 26 | setup( 27 | name='easy-torch', 28 | version=get_version(), 29 | description='Simple and powerful pytorch framework.', 30 | long_description=readme(), 31 | long_description_content_type='text/markdown', 32 | author='Yuhao Wang', 33 | author_email='yuhaow97@gmail.com', 34 | keywords='pytorch, deep learning', 35 | url='https://github.com/cnstark/easytorch', 36 | include_package_data=True, 37 | packages=find_packages(exclude=('tests',)), 38 | classifiers=[ 39 | 'License :: OSI Approved :: Apache Software License', 40 | 'Operating System :: OS Independent', 41 | 'Programming Language :: Python :: 3', 42 | 'Programming Language :: Python :: 3.7', 43 | 'Programming Language :: Python :: 3.8', 44 | 'Programming Language :: Python :: 3.9', 45 | 'Topic :: Utilities' 46 | ], 47 | entry_points={ 48 | 'console_scripts': ['easytrain=easytorch.entry_points:easytrain'], 49 | }, 50 | license='Apache License 2.0', 51 | install_requires=get_requirements(), 52 | zip_safe=False 53 | ) 54 | -------------------------------------------------------------------------------- /examples/linear_regression/linear_regression_runner.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from easytorch import Runner 4 | from easytorch.device import to_device 5 | 6 | from dataset import LinearDataset 7 | 8 | 9 | class LinearRegressionRunner(Runner): 10 | """LinearRegressionRunner 11 | """ 12 | 13 | def init_training(self, cfg): 14 | """Initialize 
training. 15 | 16 | Including loss, training meters, etc. 17 | 18 | Args: 19 | cfg (dict): config 20 | """ 21 | 22 | super().init_training(cfg) 23 | 24 | self.loss = nn.MSELoss() 25 | self.loss = to_device(self.loss) 26 | 27 | self.register_epoch_meter('train_loss', 'train', '{:.2f}') 28 | 29 | @staticmethod 30 | def define_model(cfg: dict) -> nn.Module: 31 | """Define model. 32 | 33 | Args: 34 | cfg (dict): config 35 | 36 | Returns: 37 | model (nn.Module) 38 | """ 39 | 40 | return nn.Linear(1, 1) 41 | 42 | @staticmethod 43 | def build_train_dataset(cfg: dict): 44 | """Build linear regression train dataset 45 | 46 | Args: 47 | cfg (dict): config 48 | 49 | Returns: 50 | train dataset (Dataset) 51 | """ 52 | 53 | return LinearDataset( 54 | cfg['TRAIN']['DATA']['K'], 55 | cfg['TRAIN']['DATA']['B'], 56 | cfg['TRAIN']['DATA']['NUM'], 57 | ) 58 | 59 | def train_iters(self, epoch, iter_index, data): 60 | """Training details. 61 | 62 | Args: 63 | epoch (int): current epoch. 64 | iter_index (int): current iter. 65 | data (torch.Tensor or tuple): Data provided by DataLoader 66 | 67 | Returns: 68 | loss (torch.Tensor) 69 | """ 70 | 71 | x, y = data 72 | x = to_device(x) 73 | y = to_device(y) 74 | 75 | output = self.model(x) 76 | loss = self.loss(output, y) 77 | self.update_epoch_meter('train_loss', loss.item()) 78 | return loss 79 | 80 | def on_training_end(self): 81 | """Print result on training end. 82 | """ 83 | 84 | super().on_training_end() 85 | self.logger.info('Result: k: {}, b: {}'.format(self.model.weight.item(), self.model.bias.item())) 86 | -------------------------------------------------------------------------------- /easytorch/config/__init__.py: -------------------------------------------------------------------------------- 1 | """Everything is based on config. 2 | 3 | `Config` is the set of all configurations. `Config` is implemented by `dict`; we recommend using `Config`. 4 | 5 | Look at the following example: 6 | 7 | cfg.py 8 | 9 | ```python 10 | import os 11 | from easytorch import Config 12 | 13 | from my_runner import MyRunner 14 | 15 | CFG = Config() 16 | 17 | CFG.DESC = 'my net' # customized description 18 | CFG.RUNNER = MyRunner 19 | CFG.GPU_NUM = 1 20 | 21 | CFG.MODEL = Config() 22 | CFG.MODEL.NAME = 'my_net' 23 | 24 | CFG.TRAIN = Config() 25 | 26 | CFG.TRAIN.NUM_EPOCHS = 100 27 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 28 | 'checkpoints', 29 | '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 30 | ) 31 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 32 | 33 | CFG.TRAIN.OPTIM = Config() 34 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 35 | CFG.TRAIN.OPTIM.PARAM = { 36 | 'lr': 0.002, 37 | 'momentum': 0.1, 38 | } 39 | 40 | CFG.TRAIN.DATA = Config() 41 | CFG.TRAIN.DATA.BATCH_SIZE = 4 42 | CFG.TRAIN.DATA.DIR = './my_data' 43 | CFG.TRAIN.DATA.SHUFFLE = True 44 | CFG.TRAIN.DATA.PIN_MEMORY = True 45 | CFG.TRAIN.DATA.PREFETCH = True 46 | 47 | CFG.VAL = Config() 48 | 49 | CFG.VAL.INTERVAL = 1 50 | 51 | CFG.VAL.DATA = Config() 52 | CFG.VAL.DATA.DIR = './my_data' 53 | 54 | CFG._TRAINING_INDEPENDENT = [ 55 | 'OTHER_CONFIG' 56 | ] 57 | 58 | ``` 59 | 60 | All configurations consist of two parts: 61 | 1. Training dependent configuration: changing this will affect the training results. 62 | 2. Training independent configuration: changing this will not affect the training results. 63 | 64 | Notes: 65 | The MD5 of all training dependent configurations will be calculated, 66 | and this MD5 value will be the sub directory name of the checkpoint save directory. 67 | If the MD5 value is `098f6bcd4621d373cade4e832627b4f6`, 68 | the real checkpoint save directory is `{CFG.TRAIN.CKPT_SAVE_DIR}/098f6bcd4621d373cade4e832627b4f6` 69 | 70 | Notes: 71 | Each configuration is training dependent by default, 72 | unless its key is in `TRAINING_INDEPENDENT_KEYS` or `CFG._TRAINING_INDEPENDENT` 73 | """ 74 | from .config import Config 75 | from .utils import config_str, config_md5, save_config_str, copy_config_file, import_config, convert_config, \ 76 | get_ckpt_save_dir, init_cfg 77 | 78 | 79 | __all__ = [ 80 | 'Config', 'config_str', 'config_md5', 'save_config_str', 'copy_config_file', 81 | 'import_config', 'convert_config', 'get_ckpt_save_dir', 'init_cfg' 82 | ] 83 | -------------------------------------------------------------------------------- /tests/random_test/random_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.utils.data import Dataset 8 | 9 | from easytorch import Config, Runner, get_rank, launch_training 10 | 11 | 12 | class FakeDataset(Dataset): 13 | """FakeDataset 14 | """ 15 | 16 | def __init__(self, num: int, min_: int, max_: int): 17 | self.num = num 18 | self.min = min_ 19 | self.max = max_ 20 | 21 | def __getitem__(self, index): 22 | return index, \ 23 | random.randint(self.min, self.max), \ 24 | np.random.randint(self.min, self.max + 1), \ 25 | torch.randint(self.min, self.max + 1, (1,)).item() 26 | 27 | def __len__(self): 28 | return self.num 29 | 30 | 31 | class DDPTestRunner(Runner): 32 | """DDPTestRunner 33 | """ 34 | 35 | @staticmethod 36 | def define_model(cfg: Dict) -> nn.Module: 37 | return nn.Conv2d(3, 3, 3) 38 | 39 | @staticmethod 40 | def build_train_dataset(cfg: Dict): 41 | return FakeDataset(cfg['TRAIN']['DATA']['NUM'], cfg['TRAIN']['DATA']['MIN'], cfg['TRAIN']['DATA']['MAX']) 42 | 43 | def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor: 44 | print('rank: {:d}, epoch: {:d}, iter: {:d}, data: {}'.format(get_rank(), epoch, iter_index, data)) 45 | if torch.distributed.is_initialized(): 46 | torch.distributed.barrier() 47 | 48 | 49 | def build_cfg(): 50 | CFG = Config() 51 | 52 | CFG.DESC = 'ddp test' 53 | CFG.RUNNER = DDPTestRunner 54 | CFG.GPU_NUM = 8 55 | 56 | CFG.ENV = Config() 57 | CFG.ENV.TF32 = False 58 | CFG.ENV.SEED = 6 59 | 60 | CFG.MODEL = Config() 61 | CFG.MODEL.NAME = 'conv' 62 | 63 | CFG.TRAIN = Config() 64 | 65 | CFG.TRAIN.NUM_EPOCHS = 5 66 | CFG.TRAIN.CKPT_SAVE_DIR = 'checkpoints' 67 | 68 | CFG.TRAIN.CKPT_SAVE_STRATEGY = None 69 | 70 | CFG.TRAIN.OPTIM = Config() 71 | CFG.TRAIN.OPTIM.TYPE = 'SGD' 72 | CFG.TRAIN.OPTIM.PARAM = { 73 | 'lr': 0.002, 74 | 'momentum': 0.1, 75 | } 76 | 77 | CFG.TRAIN.DATA = Config() 78 | CFG.TRAIN.DATA.NUM = 100 79 | CFG.TRAIN.DATA.MIN = 0 80 | CFG.TRAIN.DATA.MAX = 100 81 | CFG.TRAIN.DATA.BATCH_SIZE = 4 82 | CFG.TRAIN.DATA.NUM_WORKERS = 2 83 | CFG.TRAIN.DATA.SHUFFLE = True 84 | 85 | return CFG 86 | 87 | 88 | if __name__ == '__main__': 89 | cfg_ = build_cfg() 90 | 91 | launch_training(cfg_, devices='0,1,2,3,4,5,6,7') 92 | -------------------------------------------------------------------------------- /easytorch/utils/dist.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | 5 | from ..device import get_device_count 6 | 7 | # default master rank 8 | MASTER_RANK = 0 9 | 10 | 11 | def 
get_rank() -> int: 12 | """Get the rank of the current process group. 13 | 14 | If DDP is initialized, return `torch.distributed.get_rank()`. 15 | Otherwise return 0 16 | 17 | Returns: 18 | rank (int) 19 | """ 20 | 21 | if torch.distributed.is_initialized(): 22 | return torch.distributed.get_rank() 23 | else: 24 | return 0 25 | 26 | 27 | def get_local_rank() -> int: 28 | """Get the local rank of the current process group across multiple compute nodes. 29 | 30 | Returns: 31 | local_rank (int) 32 | """ 33 | 34 | return get_rank() % get_device_count() if get_device_count() != 0 else 0 35 | 36 | 37 | def get_world_size() -> int: 38 | """Get the number of processes in the current process group. 39 | 40 | If DDP is initialized, return ```torch.distributed.get_world_size()```. 41 | Otherwise return 1 42 | 43 | Returns: 44 | world_size (int) 45 | """ 46 | 47 | if torch.distributed.is_initialized(): 48 | return torch.distributed.get_world_size() 49 | else: 50 | return 1 51 | 52 | 53 | def is_rank(rank: int) -> bool: 54 | """Check whether the rank of the current process group is equal to ```rank```. 55 | 56 | Notes: 57 | ```rank``` must be less than ```world_size``` 58 | 59 | Args: 60 | rank (int): rank 61 | 62 | Returns: 63 | result (bool) 64 | """ 65 | 66 | if rank >= get_world_size(): 67 | raise ValueError('Rank is out of range') 68 | 69 | return get_rank() == rank 70 | 71 | 72 | def is_master() -> bool: 73 | """Check whether the current process is the master process. 74 | 75 | The rank of the master process is ```MASTER_RANK``` 76 | 77 | Returns: 78 | result (bool) 79 | """ 80 | 81 | return is_rank(MASTER_RANK) 82 | 83 | 84 | def master_only(func): 85 | """A function decorator that ensures the function is only executed in the master process. 86 | 87 | Examples: 88 | @master_only 89 | def func(x): 90 | return 2 ** x 91 | 92 | Args: 93 | func: function 94 | 95 | Returns: 96 | wrapper func 97 | """ 98 | 99 | @functools.wraps(func) 100 | def wrapper(*args, **kwargs): 101 | if is_master(): 102 | return func(*args, **kwargs) 103 | 104 | return wrapper 105 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # EasyTorch 2 | 3 | [![LICENSE](https://img.shields.io/github/license/cnstark/easytorch.svg)](https://github.com/cnstark/easytorch/blob/master/LICENSE) 4 | [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/cnstark/easytorch.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/cnstark/easytorch/context:python) 5 | [![gitee mirror](https://github.com/cnstark/easytorch/actions/workflows/git-mirror.yml/badge.svg)](https://gitee.com/cnstark/easytorch) 6 | 7 | [English](README.md) **|** [简体中文](README_CN.md) 8 | 9 | --- 10 | 11 | EasyTorch is an open-source neural network framework based on PyTorch. It encapsulates the functionality commonly used in PyTorch projects and helps users build deep learning projects quickly. 12 | 13 | ## :sparkles: Highlights 14 | 15 | * :computer: **Minimal code**. EasyTorch encapsulates the general neural network training pipeline; users only need to implement `Dataset`, `Model`, and the key training/inference code to build a deep learning project. 16 | * :wrench: **Everything based on config**. Users control the training mode and hyperparameters through config files. EasyTorch automatically generates a unique result storage directory based on the MD5 of the config file content, so adjusting hyperparameters is no longer messy. 17 | * :flashlight: **Support for all devices**. EasyTorch supports CPU, GPU, and distributed GPU training (single node multi-GPU and multi-node multi-GPU), selected via config parameters without modifying any code. 18 | * :page_with_curl: **Persistent training logs**. Supports the `logging` module and `Tensorboard`, wrapped in a unified interface; customized training logs can be saved through simple interface calls. 19 | 20 | ## :cd: Requirements 21 | 22 | ### Operating System 23 | 24 | * [Linux](https://pytorch.org/get-started/locally/#linux-prerequisites) 25 | * [Windows](https://pytorch.org/get-started/locally/#windows-prerequisites) 26 | * [MacOS](https://pytorch.org/get-started/locally/#mac-prerequisites) 27 | 28 | Ubuntu 16.04 and later are recommended. 29 | 30 | ### Python 31 | 32 | python >= 3.6 (>= 3.9 recommended) 33 | 34 | [Miniconda](https://docs.conda.io/en/latest/miniconda.html) or [Anaconda](https://www.anaconda.com/) is recommended. 35 | 36 | ### PyTorch and CUDA 37 | 38 | [pytorch](https://pytorch.org/) >= 1.4 (>= 1.9 recommended). 39 | To use CUDA, install a PyTorch package compiled for the corresponding CUDA version. 40 | 41 | Note: to use Ampere GPUs, PyTorch >= 1.7 and CUDA >= 11.0 are required. 42 | 43 | ## :dart: Get Started 44 | 45 | ### Install EasyTorch 46 | 47 | ```shell 48 | pip install easy-torch 49 | ``` 50 | 51 | ### Initialize a Project 52 | 53 | TODO 54 | 55 | ## :pushpin: Examples 56 | 57 | * [Linear Regression](examples/linear_regression) 58 | * [MNIST Handwritten Digit Classification](examples/mnist) 59 | * [ImageNet Image Classification](examples/imagenet) 60 | 61 | *More examples are under development* 62 | 63 | The mature open-source project [BasicTS](https://github.com/zezhishao/BasicTS) is a recommended reference. 64 | 65 | ## :rocket: Citation 66 | 67 | ### BibTex Citation 68 | 69 | If EasyTorch is helpful for your research or work, please consider citing it. 70 | The BibTex entry is as follows (requires the `url` package). 71 | 72 | ``` latex 73 | @misc{wang2020easytorch, 74 | author = {Yuhao Wang}, 75 | title = {{EasyTorch}: Simple and powerful pytorch framework.}, 76 | howpublished = {\url{https://github.com/cnstark/easytorch}}, 77 | year = {2020} 78 | } 79 | ``` 80 | 81 | ### README Badge 82 | 83 | If your project uses EasyTorch, you can add the EasyTorch badge [![EasyTorch](https://img.shields.io/badge/Developing%20with-EasyTorch-2077ff.svg)](https://github.com/cnstark/easytorch) to your README: 84 | 85 | ``` 86 | [![EasyTorch](https://img.shields.io/badge/Developing%20with-EasyTorch-2077ff.svg)](https://github.com/cnstark/easytorch) 87 | ``` 88 | 89 | ***(Full documentation is coming soon)*** 90 | -------------------------------------------------------------------------------- /easytorch/core/data_loader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch.utils.data.distributed import DistributedSampler 5 | 6 | from ..utils import get_rank, get_world_size 7 | from ..utils.data_prefetcher import DataLoaderX 8 | 9 | 10 | def build_data_loader(dataset: Dataset, data_cfg: Dict): 11 | """Build dataloader from `data_cfg` 12 | `data_cfg` is part of config which defines fields about data, such as `CFG.TRAIN.DATA` 13 | 14 | structure of `data_cfg` is 15 | { 16 | 'BATCH_SIZE': (int, optional) batch size of data loader (default: ``1``), 17 | 'SHUFFLE': (bool, optional) data reshuffled option (default: ``False``), 18 | 'NUM_WORKERS': (int, optional) num workers for data loader (default: ``0``), 19 | 'PIN_MEMORY': (bool, optional) pin_memory option (default: ``False``), 20 | 'PREFETCH': (bool, optional) set to ``True`` to use `DataLoaderX` (default: ``False``), 21 | } 22 | 23 | Args: 24 | dataset (Dataset): dataset defined by user 25 | data_cfg (Dict): data config 26 | 27 | Returns: 28 | data loader 29 | """ 30 | 31 | return (DataLoaderX if data_cfg.get('PREFETCH', False) else DataLoader)( 32 | dataset, 33 | collate_fn=data_cfg.get('COLLATE_FN', None), 34 | batch_size=data_cfg.get('BATCH_SIZE', 1), 35 | shuffle=data_cfg.get('SHUFFLE', False), 36 | num_workers=data_cfg.get('NUM_WORKERS', 0), 37 | pin_memory=data_cfg.get('PIN_MEMORY', False) 38 | ) 39 | 40 | 41 | def build_data_loader_ddp(dataset: Dataset, data_cfg: Dict): 42 | """Build ddp dataloader from `data_cfg` 43 | `data_cfg` is part of config which defines fields about data, such as `CFG.TRAIN.DATA` 44 | 45 | structure of `data_cfg` is 46 | { 47 | 'BATCH_SIZE': (int, optional) batch size of data loader (default: ``1``), 48 | 'SHUFFLE': (bool, optional) data reshuffled option (default: 
``False``), 49 | 'NUM_WORKERS': (int, optional) num workers for data loader (default: ``0``), 50 | 'PIN_MEMORY': (bool, optional) pin_memory option (default: ``False``), 51 | 'PREFETCH': (bool, optional) set to ``True`` to use `BackgroundGenerator` (default: ``False``) 52 | need to install `prefetch_generator`, see https://pypi.org/project/prefetch_generator/ 53 | } 54 | 55 | Args: 56 | dataset (Dataset): dataset defined by user 57 | data_cfg (Dict): data config 58 | 59 | Returns: 60 | data loader 61 | """ 62 | 63 | ddp_sampler = DistributedSampler( 64 | dataset, 65 | get_world_size(), 66 | get_rank(), 67 | shuffle=data_cfg.get('SHUFFLE', False) 68 | ) 69 | return (DataLoaderX if data_cfg.get('PREFETCH', False) else DataLoader)( 70 | dataset, 71 | collate_fn=data_cfg.get('COLLATE_FN', None), 72 | batch_size=data_cfg.get('BATCH_SIZE', 1), 73 | shuffle=False, 74 | sampler=ddp_sampler, 75 | num_workers=data_cfg.get('NUM_WORKERS', 0), 76 | pin_memory=data_cfg.get('PIN_MEMORY', False) 77 | ) 78 | -------------------------------------------------------------------------------- /easytorch/device.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | try: 7 | __import__('torch_mlu') 8 | except ImportError: 9 | pass 10 | 11 | __all__ = [ 12 | 'get_device_type', 'set_device_type', 'get_device_count', 'set_device', 'to_device', 'set_device_manual_seed' 13 | ] 14 | 15 | _DEVICE_TYPE = 'gpu' 16 | 17 | 18 | def get_device_type() -> str: 19 | return _DEVICE_TYPE 20 | 21 | 22 | def set_device_type(device_type: str): 23 | global _DEVICE_TYPE 24 | if device_type not in ['gpu', 'mlu', 'cpu']: 25 | raise ValueError('Unknown device type!') 26 | _DEVICE_TYPE = device_type 27 | 28 | 29 | def get_device_count() -> int: 30 | if _DEVICE_TYPE == 'gpu': 31 | return torch.cuda.device_count() 32 | elif _DEVICE_TYPE == 'mlu': 33 | return torch.mlu.device_count() 34 | elif _DEVICE_TYPE == 'cpu': 35 | return 0 36 | else: 37 | raise ValueError('Unknown device type!') 38 | 39 | 40 | def set_device(device_id: int): 41 | if _DEVICE_TYPE == 'gpu': 42 | torch.cuda.set_device(device_id) 43 | elif _DEVICE_TYPE == 'mlu': 44 | torch.mlu.set_device(device_id) 45 | else: 46 | raise ValueError('Unknown device type!') 47 | 48 | 49 | def to_device(src: Union[torch.Tensor, nn.Module], device_id: int = None, 50 | non_blocking: bool = False) -> Union[torch.Tensor, nn.Module]: 51 | kwargs = {'non_blocking': non_blocking} if isinstance(src, torch.Tensor) else {} 52 | if _DEVICE_TYPE == 'gpu': 53 | if device_id is None: 54 | return src.cuda(**kwargs) 55 | else: 56 | return src.to('cuda:{:d}'.format(device_id), **kwargs) 57 | elif _DEVICE_TYPE == 'mlu': 58 | if device_id is None: 59 | return src.mlu(**kwargs) 60 | else: 61 | return src.to('mlu:{:d}'.format(device_id), **kwargs) 62 | elif _DEVICE_TYPE == 'cpu': 63 | return src.cpu() 64 | else: 65 | raise ValueError('Unknown device type!') 66 | 67 | 68 | def init_stream(): 69 | if _DEVICE_TYPE == 'gpu': 70 | return torch.cuda.Stream() 71 | elif _DEVICE_TYPE == 'mlu': 72 | return torch.mlu.Stream() 73 | else: 74 | raise ValueError('Unknown device type!') 75 | 76 | 77 | def stream(st): 78 | if _DEVICE_TYPE == 'gpu': 79 | return torch.cuda.stream(st) 80 | elif _DEVICE_TYPE == 'mlu': 81 | return torch.mlu.stream(st) 82 | else: 83 | raise ValueError('Unknown device type!') 84 | 85 | 86 | def current_stream(): 87 | if _DEVICE_TYPE == 'gpu': 88 | return torch.cuda.current_stream() 89 | elif 
_DEVICE_TYPE == 'mlu': 90 | return torch.mlu.current_stream() 91 | else: 92 | raise ValueError('Unknown device type!') 93 | 94 | 95 | def set_device_manual_seed(seed: int): 96 | torch.manual_seed(seed) 97 | if _DEVICE_TYPE == 'gpu': 98 | torch.cuda.manual_seed(seed) 99 | torch.cuda.manual_seed_all(seed) 100 | elif _DEVICE_TYPE == 'mlu': 101 | torch.mlu.manual_seed(seed) 102 | torch.mlu.manual_seed_all(seed) 103 | -------------------------------------------------------------------------------- /easytorch/core/optimizer_builder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from torch import nn, optim 4 | from torch.optim import lr_scheduler 5 | 6 | from .. import easyoptim 7 | from ..easyoptim import easy_lr_scheduler 8 | 9 | 10 | def build_optim(optim_cfg: Dict, model: nn.Module) -> optim.Optimizer: 11 | """Build optimizer from `optim_cfg` 12 | `optim_cfg` is part of config which defines fields about optimizer 13 | 14 | structure of `optim_cfg` is 15 | { 16 | 'TYPE': (str or type) optimizer name or type, such as ``Adam``, ``SGD``, 17 | or custom optimizer type. 18 | 'PARAM': (Dict) optimizer init params except first param `params` 19 | } 20 | 21 | Note: 22 | Optimizer is initialized by reflection, please ensure optim_cfg['TYPE'] is in `torch.optim`; if it is not found there, it will be searched in `easytorch.easyoptim`. 23 | 24 | Examples: 25 | optim_cfg = { 26 | 'TYPE': 'Adam', 27 | 'PARAM': { 28 | 'lr': 1e-3, 29 | 'betas': (0.9, 0.99), 30 | 'eps': 1e-8, 31 | 'weight_decay': 0 32 | } 33 | } 34 | An `Adam` optimizer will be built. 35 | 36 | Args: 37 | optim_cfg (Dict): optimizer config 38 | model (nn.Module): model defined by user 39 | 40 | Returns: 41 | optimizer (optim.Optimizer) 42 | """ 43 | 44 | if isinstance(optim_cfg['TYPE'], type): 45 | optim_type = optim_cfg['TYPE'] 46 | else: 47 | if hasattr(optim, optim_cfg['TYPE']): 48 | optim_type = getattr(optim, optim_cfg['TYPE']) 49 | else: 50 | optim_type = getattr(easyoptim, optim_cfg['TYPE']) 51 | optim_param = optim_cfg['PARAM'].copy() 52 | optimizer = optim_type(model.parameters(), **optim_param) 53 | return optimizer 54 | 55 | 56 | def build_lr_scheduler(lr_scheduler_cfg: Dict, optimizer: optim.Optimizer) -> lr_scheduler._LRScheduler: 57 | """Build lr_scheduler from `lr_scheduler_cfg` 58 | `lr_scheduler_cfg` is part of config which defines fields about lr_scheduler 59 | 60 | structure of `lr_scheduler_cfg` is 61 | { 62 | 'TYPE': (str or type) lr_scheduler name or type, such as ``MultiStepLR``, ``CosineAnnealingLR``, 63 | or custom lr_scheduler type 64 | 'PARAM': (Dict) lr_scheduler init params except first param `optimizer` 65 | } 66 | 67 | Note: 68 | LRScheduler is initialized by reflection, please ensure 69 | lr_scheduler_cfg['TYPE'] is in `torch.optim.lr_scheduler` or `easytorch.easyoptim.easy_lr_scheduler`; 70 | if the `type` is not found in `torch.optim.lr_scheduler`, 71 | it will continue to be searched in `easytorch.easyoptim.easy_lr_scheduler` 72 | 73 | Examples: 74 | lr_scheduler_cfg = { 75 | 'TYPE': 'MultiStepLR', 76 | 'PARAM': { 77 | 'milestones': [100, 200, 300], 78 | 'gamma': 0.1 79 | } 80 | } 81 | A `MultiStepLR` lr_scheduler will be built.
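A custom scheduler class can also be passed directly as 'TYPE', in which case no name lookup is performed (a sketch; `MyScheduler` is a placeholder for a user-defined `_LRScheduler` subclass): lr_scheduler_cfg = {'TYPE': MyScheduler, 'PARAM': {...}}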
82 | 83 | Args: 84 | lr_scheduler_cfg (Dict): lr_scheduler config 85 | optimizer (optim.Optimizer): optimizer 86 | 87 | Returns: 88 | LRScheduler 89 | """ 90 | 91 | if isinstance(lr_scheduler_cfg['TYPE'], type): 92 | scheduler_type = lr_scheduler_cfg['TYPE'] 93 | else: 94 | if hasattr(lr_scheduler, lr_scheduler_cfg['TYPE']): 95 | scheduler_type = getattr(lr_scheduler, lr_scheduler_cfg['TYPE']) 96 | else: 97 | scheduler_type = getattr(easy_lr_scheduler, lr_scheduler_cfg['TYPE']) 98 | scheduler_param = lr_scheduler_cfg['PARAM'].copy() 99 | scheduler_param['optimizer'] = optimizer 100 | scheduler = scheduler_type(**scheduler_param) 101 | return scheduler 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EasyTorch 2 | 3 | [![LICENSE](https://img.shields.io/github/license/cnstark/easytorch.svg)](https://github.com/cnstark/easytorch/blob/master/LICENSE) 4 | [![PyPI](https://img.shields.io/pypi/v/easy-torch)](https://pypi.org/project/easy-torch/) 5 | [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/cnstark/easytorch.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/cnstark/easytorch/context:python) 6 | [![python lint](https://github.com/cnstark/easytorch/actions/workflows/pylint.yml/badge.svg)](https://github.com/cnstark/easytorch/blob/master/.github/workflows/pylint.yml) 7 | 8 | [English](README.md) **|** [简体中文](README_CN.md) 9 | 10 | EasyTorch is an open source neural network framework based on PyTorch, which encapsulates common functions in PyTorch projects to help users quickly build deep learning projects. 11 | 12 | ## :sparkles: Highlight Characteristics 13 | 14 | * :computer: **Minimum Code**. EasyTorch encapsulates the general neural network training pipeline. Users only need to implement key code such as `Dataset`, `Model`, and training/inference to build deep learning projects. 15 | * :wrench: **Everything Based on Config**. Users control the training mode and hyperparameters through the config file. EasyTorch automatically generates a unique result storage directory according to the MD5 of the config file content, which helps users adjust hyperparameters more conveniently. 16 | * :flashlight: **Support All Devices**. EasyTorch supports CPU, GPU and GPU distributed training (single node multiple GPUs and multiple nodes). Users can select the device through config parameters, without modifying any code. 17 | * :page_with_curl: **Save Training Log**. Supports the `logging` system and `Tensorboard`, wrapped in a unified interface; users can save customized training logs by calling simple interfaces. 18 | 19 | ## :cd: Dependencies 20 | 21 | ### OS 22 | 23 | * [Linux](https://pytorch.org/get-started/locally/#linux-prerequisites) 24 | * [Windows](https://pytorch.org/get-started/locally/#windows-prerequisites) 25 | * [MacOS](https://pytorch.org/get-started/locally/#mac-prerequisites) 26 | 27 | Ubuntu 16.04 and later systems are recommended. 28 | 29 | ### Python 30 | 31 | python >= 3.6 (recommended >= 3.9) 32 | 33 | [Miniconda](https://docs.conda.io/en/latest/miniconda.html) or [Anaconda](https://www.anaconda.com/) are recommended. 34 | 35 | ### PyTorch and CUDA 36 | 37 | [pytorch](https://pytorch.org/) >= 1.4 (recommended >= 1.9). 38 | To use CUDA, please install the PyTorch package compiled with the corresponding CUDA version.
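To check which PyTorch and CUDA versions are active in the current environment, you can run, for example:

```shell
python -c "import torch; print(torch.__version__, torch.version.cuda)"
```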
39 | 40 | Note: To use Ampere-architecture GPUs, PyTorch >= 1.7 and CUDA >= 11.0 are required. 41 | 42 | ## :dart: Get Started 43 | 44 | ### Installation 45 | 46 | ```shell 47 | pip install easy-torch 48 | ``` 49 | 50 | ### Initialize Project 51 | 52 | TODO 53 | 54 | ## :pushpin: Examples 55 | 56 | * [Linear Regression](examples/linear_regression) 57 | * [MNIST Digit Recognition](examples/mnist) 58 | * [ImageNet Image Classification](examples/imagenet) 59 | 60 | *More examples are on the way* 61 | 62 | It is recommended to refer to the excellent open source project [BasicTS](https://github.com/zezhishao/BasicTS). 63 | 64 | ## :rocket: Citations 65 | 66 | ### BibTex Citations 67 | 68 | If EasyTorch helps your research or work, please consider citing EasyTorch. 69 | The BibTex reference item is as follows (it requires the `url` LaTeX package). 70 | 71 | ``` latex 72 | @misc{wang2020easytorch, 73 | author = {Yuhao Wang}, 74 | title = {{EasyTorch}: Simple and powerful pytorch framework.}, 75 | howpublished = {\url{https://github.com/cnstark/easytorch}}, 76 | year = {2020} 77 | } 78 | ``` 79 | 80 | ### README Badge 81 | 82 | If your project is using EasyTorch, please consider adding the EasyTorch badge [![EasyTorch](https://img.shields.io/badge/Developing%20with-EasyTorch-2077ff.svg)](https://github.com/cnstark/easytorch) to your README. 83 | 84 | ``` 85 | [![EasyTorch](https://img.shields.io/badge/Developing%20with-EasyTorch-2077ff.svg)](https://github.com/cnstark/easytorch) 86 | ``` 87 | 88 | ***(Full documentation is coming soon)*** 89 | -------------------------------------------------------------------------------- /easytorch/utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Union 3 | 4 | 5 | class Timer: 6 | """Timer with multiple record 7 | 8 | Examples: 9 | >>> timer = Timer() 10 | >>> time.sleep(1) 11 | >>> timer.record('one') 12 | >>> time.sleep(2) 13 | >>> timer.record('two') 14 | >>> timer.print() 15 | Start:: [diff: 0.000000, total: 0.000000] 16 | one:: [diff: 1.002618, total: 1.002618] 17 | two:: [diff: 2.003077, total: 3.005695] 18 | >>> print(timer.get(2)) 19 | (2.0030770301818848, 3.005695104598999) 20 | >>> print(timer.get(1)) 21 | (1.0026180744171143, 1.0026180744171143) 22 | >>> print(timer.get(2, 0)) 23 | (3.005695104598999, 3.005695104598999) 24 | """ 25 | 26 | def __init__(self): 27 | self._record_dict = {'Start': time.time()} 28 | self._record_names = ['Start'] 29 | 30 | def record(self, name: str = None): 31 | """Record a checkpoint 32 | 33 | Args: 34 | name (str): checkpoint name (default is Record_i, i is index) 35 | """ 36 | 37 | if name is None: 38 | name = 'Record_{:d}'.format(len(self._record_names)) 39 | elif self._record_dict.get(name) is not None: 40 | raise ValueError('Name \'{}\' already exists'.format(name)) 41 | 42 | self._record_dict[name] = time.time() 43 | self._record_names.append(name) 44 | 45 | def print(self): 46 | """Print all checkpoints of this timer 47 | """ 48 | start_time_record = last_time_record = self._record_dict['Start'] 49 | for name in self._record_names: 50 | time_record = self._record_dict[name] 51 | time_diff = time_record - last_time_record 52 | time_total = time_record - start_time_record 53 | last_time_record = time_record 54 | print('{}:: [diff: {:2f}, total: {:2f}]'.format(name, time_diff, time_total)) 55 | 56 | def get(self, end: Union[str, int], start: Union[str, int] = None): 57 | """Get the time from `start` to `end` (diff), 58 | and the
time from timer initialization to `end` (total). 59 | 60 | Notes: 61 | If `start` is None, it defaults to the checkpoint immediately before `end`. 62 | 63 | Args: 64 | end (Union[str, int]): end checkpoint name or index 65 | start (Union[str, int]): start checkpoint name or index 66 | 67 | Returns: 68 | (diff, total) 69 | """ 70 | 71 | # end 72 | if isinstance(end, int): 73 | end_record_index = end 74 | end_record_name = self._record_names[end_record_index] 75 | else: 76 | end_record_name = end 77 | end_record_index = self._record_names.index(end_record_name) 78 | end_record_time = self._record_dict[end_record_name] 79 | 80 | # start 81 | if start is None: 82 | start_record_index = max(end_record_index - 1, 0) 83 | start_record_name = self._record_names[start_record_index] 84 | elif isinstance(start, int): 85 | start_record_name = self._record_names[start] 86 | else: 87 | start_record_name = start 88 | start_record_time = self._record_dict[start_record_name] 89 | 90 | return end_record_time - start_record_time, end_record_time - self._record_dict['Start'] 91 | 92 | 93 | class TimePredictor: 94 | """TimePredictor 95 | """ 96 | 97 | def __init__(self, start_step: int, end_step: int): 98 | self.start_step = start_step 99 | self.end_step = end_step 100 | self.start_time = time.time() 101 | 102 | def get_remaining_time(self, step: int) -> float: 103 | # note: this extrapolates the expected *total* duration from the elapsed time, assuming linear progress 104 | now_time = time.time() 105 | return (now_time - self.start_time) * (self.end_step - self.start_step) / (step - self.start_step) 106 | 107 | def get_expected_end_time(self, step: int) -> float: 108 | return self.start_time + self.get_remaining_time(step) 109 | -------------------------------------------------------------------------------- /examples/mnist/mnist_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | import torchvision 6 | 7 | from easytorch import Runner 8 | from easytorch.device import to_device 9 | 10 | from conv_net import ConvNet 11 | 12 | 13 | class MNISTRunner(Runner): 14 | """MNISTRunner 15 | """ 16 | 17 | def init_training(self, cfg: Dict): 18 | """Initialize training. 19 | 20 | Including loss, training meters, etc. 21 | 22 | Args: 23 | cfg (Dict): config 24 | """ 25 | 26 | super().init_training(cfg) 27 | 28 | self.loss = nn.NLLLoss() 29 | self.loss = to_device(self.loss) 30 | 31 | self.register_epoch_meter('train_loss', 'train', '{:.2f}') 32 | 33 | def init_validation(self, cfg: Dict): 34 | """Initialize validation. 35 | 36 | Including validation meters, etc. 37 | 38 | Args: 39 | cfg (Dict): config 40 | """ 41 | 42 | super().init_validation(cfg) 43 | 44 | self.register_epoch_meter('val_acc', 'val', '{:.2f}%') 45 | 46 | @staticmethod 47 | def define_model(cfg: Dict) -> nn.Module: 48 | """Define model. 49 | 50 | If you have multiple models, insert the name and class into the dict below, 51 | and select it through `config`.
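For example, a sketch of the relevant config fields ('conv_net' is the key registered in the dict below; `PARAM` is forwarded to the model constructor as keyword arguments): CFG.MODEL.NAME = 'conv_net'; CFG.MODEL.PARAM = {}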
52 | 53 | Args: 54 | cfg (Dict): config 55 | 56 | Returns: 57 | model (nn.Module) 58 | """ 59 | 60 | return { 61 | 'conv_net': ConvNet 62 | }[cfg['MODEL']['NAME']](**cfg['MODEL'].get('PARAM', {})) 63 | 64 | @staticmethod 65 | def build_train_dataset(cfg: Dict): 66 | """Build MNIST train dataset 67 | 68 | Args: 69 | cfg (Dict): config 70 | 71 | Returns: 72 | train dataset (Dataset) 73 | """ 74 | 75 | return torchvision.datasets.MNIST( 76 | cfg['TRAIN']['DATA']['DIR'], train=True, download=True, 77 | transform=torchvision.transforms.Compose([ 78 | torchvision.transforms.ToTensor(), 79 | torchvision.transforms.Normalize( 80 | (0.1307,), (0.3081,)) 81 | ]) 82 | ) 83 | 84 | @staticmethod 85 | def build_val_dataset(cfg: Dict): 86 | """Build MNIST val dataset 87 | 88 | Args: 89 | cfg (Dict): config 90 | 91 | Returns: 92 | train dataset (Dataset) 93 | """ 94 | 95 | return torchvision.datasets.MNIST( 96 | cfg['VAL']['DATA']['DIR'], train=False, download=True, 97 | transform=torchvision.transforms.Compose([ 98 | torchvision.transforms.ToTensor(), 99 | torchvision.transforms.Normalize( 100 | (0.1307,), (0.3081,)) 101 | ]) 102 | ) 103 | 104 | def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor: 105 | """Training details. 106 | 107 | Args: 108 | epoch (int): current epoch. 109 | iter_index (int): current iter. 110 | data (torch.Tensor or tuple): Data provided by DataLoader 111 | 112 | Returns: 113 | loss (torch.Tensor) 114 | """ 115 | 116 | input_, target_ = data 117 | input_ = to_device(input_) 118 | target_ = to_device(target_) 119 | 120 | output = self.model(input_) 121 | loss = self.loss(output, target_) 122 | self.update_epoch_meter('train_loss', loss.item()) 123 | return loss 124 | 125 | def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]): 126 | """Validation details. 127 | 128 | Args: 129 | iter_index (int): current iter. 130 | data (torch.Tensor or tuple): Data provided by DataLoader 131 | """ 132 | 133 | input_, target_ = data 134 | input_ = to_device(input_) 135 | target_ = to_device(target_) 136 | 137 | output = self.model(input_) 138 | pred = output.data.max(1, keepdim=True)[1] 139 | self.update_epoch_meter('val_acc', 100 * pred.eq(target_.data.view_as(pred)).sum()) 140 | -------------------------------------------------------------------------------- /easytorch/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | # Modified from: https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/registry.py 3 | # pyre-ignore-all-errors[2,3] 4 | import os 5 | import importlib 6 | from copy import deepcopy 7 | import platform 8 | from typing import Any, Dict, Iterable, Iterator, Tuple, List 9 | 10 | from .misc import scan_dir 11 | 12 | 13 | __all__ = ['Registry', 'scan_modules'] 14 | 15 | 16 | class Registry(Iterable[Tuple[str, Any]]): 17 | """ 18 | The registry that provides name -> object mapping, to support third-party 19 | users' custom modules. 20 | To create a registry (e.g. a backbone registry): 21 | .. code-block:: python 22 | BACKBONE_REGISTRY = Registry('BACKBONE') 23 | To register an object: 24 | .. code-block:: python 25 | @BACKBONE_REGISTRY.register() 26 | class MyBackbone(): 27 | ... 28 | Or: 29 | .. 
code-block:: python 30 | BACKBONE_REGISTRY.register(MyBackbone) 31 | """ 32 | 33 | def __init__(self, name: str) -> None: 34 | """ 35 | Args: 36 | name (str): the name of this registry 37 | """ 38 | self._name: str = name 39 | self._obj_map: Dict[str, Any] = {} 40 | 41 | def _do_register(self, name: str, obj: Any) -> None: 42 | if name in self._obj_map: 43 | raise ValueError('An object named \'{}\' was already registered in \'{}\' registry!'.format( 44 | name, self._name 45 | )) 46 | 47 | self._obj_map[name] = obj 48 | 49 | def register(self, obj: Any = None, name: str = None) -> Any: 50 | """ 51 | Register the given object under the name `obj.__name__`. 52 | Can be used as either a decorator or not. See docstring of this class for usage. 53 | """ 54 | 55 | if obj is None: 56 | # used as a decorator 57 | def deco(func_or_class: Any) -> Any: 58 | self._do_register(func_or_class.__name__ if name is None else name, func_or_class) 59 | return func_or_class 60 | 61 | return deco 62 | 63 | # used as a function call 64 | self._do_register(obj.__name__ if name is None else name, obj) 65 | 66 | def get(self, name: str) -> Any: 67 | ret = self._obj_map.get(name) 68 | if ret is None: 69 | raise KeyError( 70 | 'No object named \'{}\' found in \'{}\' registry!'.format(name, self._name) 71 | ) 72 | return ret 73 | 74 | def build(self, name: str, params: Dict[str, Any] = None): 75 | if params is None: 76 | params = {} 77 | else: 78 | params = deepcopy(params) 79 | return self.get(name)(**params) 80 | 81 | def __contains__(self, name: str) -> bool: 82 | return name in self._obj_map 83 | 84 | def __repr__(self) -> str: 85 | return 'Registry of {}:\n{}'.format(self._name, self._obj_map) 86 | 87 | def __iter__(self) -> Iterator[Tuple[str, Any]]: 88 | return iter(self._obj_map.items()) 89 | 90 | # pyre-fixme[4]: Attribute must be annotated. 91 | __str__ = __repr__ 92 | 93 | 94 | def scan_modules(work_dir: str, file_dir: str, exclude_files: List[str] = None): 95 | """ 96 | automatically scan and import modules for registry 97 | """ 98 | if platform.system().lower() == 'windows': 99 | # On Windows systems, os.getcwd() (i.e., work_dir) will get an uppercase drive letter, such as C:\\Users\\... 100 | # However, the drive letter obtained by __file__ (i.e., file_dir) is lowercase, such as c:\\Users\\... 101 | file_dir = file_dir[0].upper() + file_dir[1:] 102 | module_dir = os.path.dirname(os.path.abspath(file_dir)) 103 | import_prefix = module_dir[module_dir.find(work_dir) + len(work_dir) + 1:].replace('/', '.').replace('\\', '.') 104 | 105 | if exclude_files is None: 106 | exclude_files = [] 107 | 108 | model_file_names = [ 109 | v[:v.find('.py')].replace('/', '.').replace('\\', '.') \ 110 | for v in scan_dir(module_dir, suffix='py', recursive=True) if v not in exclude_files 111 | ] 112 | 113 | # import all modules 114 | return [importlib.import_module(f'{import_prefix}.{file_name}') for file_name in model_file_names] 115 | -------------------------------------------------------------------------------- /easytorch/launcher/launcher.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from typing import Callable, Dict, Union, Tuple 3 | 4 | from ..config import init_cfg 5 | from ..utils import set_visible_devices, get_logger, is_master 6 | from ..device import set_device_type 7 | from .dist_wrap import dist_wrap 8 | 9 | def training_func(cfg: Dict): 10 | """Start training 11 | 12 | 1. Init runner defined by `cfg` 13 | 2. Init logger 14 | 3.
Call `train()` method in the runner 15 | 16 | Args: 17 | cfg (Dict): Easytorch config. 18 | """ 19 | 20 | # init runner 21 | logger = get_logger('easytorch-launcher') 22 | if is_master(): 23 | logger.info('Initializing runner "{}"'.format(cfg['RUNNER'])) 24 | runner = cfg['RUNNER'](cfg) 25 | 26 | # init logger (after making ckpt save dir) 27 | runner.init_logger(logger_name='easytorch-training', log_file_name='training_log') 28 | 29 | # train 30 | try: 31 | runner.train(cfg) 32 | except BaseException as e: 33 | # log exception to file 34 | runner.logger.error(traceback.format_exc()) 35 | raise e 36 | 37 | 38 | def launch_training(cfg: Union[Dict, str], devices: str = None, node_rank: int = 0): 39 | """Launch training process defined by `cfg`. 40 | 41 | Support distributed data parallel training when the number of available GPUs is greater than one. 42 | Nccl backend is used by default. 43 | 44 | Notes: 45 | If `GPU_NUM` in `cfg` is greater than `0`, easytorch will run in GPU mode; 46 | If `GPU_NUM` in `cfg` is `0`, easytorch will run in CPU mode. 47 | In order to ensure the consistency of training results, the number of available GPUs 48 | must be equal to `GPU_NUM` in GPU mode. 49 | 50 | Args: 51 | cfg (Union[Dict, str]): Easytorch config. 52 | devices (str): set ``CUDA_VISIBLE_DEVICES`` environment variable. 53 | node_rank (int): Rank of the current node. 54 | """ 55 | 56 | logger = get_logger('easytorch-launcher') 57 | logger.info('Launching EasyTorch training.') 58 | 59 | cfg = init_cfg(cfg, node_rank == 0) 60 | 61 | if cfg.get('DEVICE') is not None: 62 | set_device_type(cfg['DEVICE']) 63 | device_num = cfg.get('DEVICE_NUM', 0) 64 | elif cfg.get('GPU_NUM', 0) != 0 or cfg.get('MLU_NUM', 0) != 0: 65 | if cfg.get('GPU_NUM', 0) != 0 and cfg.get('MLU_NUM', 0) == 0: 66 | set_device_type('gpu') 67 | device_num = cfg.get('GPU_NUM', 0) 68 | elif cfg.get('GPU_NUM', 0) == 0 and cfg.get('MLU_NUM', 0) != 0: 69 | set_device_type('mlu') 70 | device_num = cfg.get('MLU_NUM', 0) 71 | else: 72 | raise ValueError('At least one of `CFG.GPU_NUM` and `CFG.MLU_NUM` is 0.') 73 | set_visible_devices(devices) 74 | else: 75 | set_device_type('cpu') 76 | device_num = 0 77 | 78 | train_dist = dist_wrap( 79 | training_func, 80 | node_num=cfg.get('DIST_NODE_NUM', 1), 81 | device_num=device_num, 82 | node_rank=node_rank, 83 | dist_backend=cfg.get('DIST_BACKEND'), 84 | init_method=cfg.get('DIST_INIT_METHOD') 85 | ) 86 | train_dist(cfg) 87 | 88 | 89 | def launch_runner(cfg: Union[Dict, str], fn: Callable, args: Tuple = (), device_type: str = 'gpu', devices: str = None): 90 | """Launch runner defined by `cfg`, and call `fn`. 91 | 92 | Args: 93 | cfg (Union[Dict, str]): Easytorch config. 94 | fn (Callable): Function is called after init runner. 95 | The function is called as ``fn(cfg, runner, *args)``, where ``cfg`` is 96 | the Easytorch config and ``runner`` is the runner defined by ``cfg`` and 97 | ``args`` is the passed through tuple of arguments. 98 | args (tuple): Arguments passed to ``fn``. 99 | device_type (str): Device type. Valid values are ['cpu', 'gpu', 'mlu']. 100 | devices (str): set ``CUDA_VISIBLE_DEVICES`` environment variable. 
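Examples: A minimal sketch (`my_cfg.py` and `validate` are placeholder names for user code): >>> def validate(cfg, runner): ... ... # use the initialized runner here >>> launch_runner('my_cfg.py', validate, devices='0')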
101 | """ 102 | 103 | logger = get_logger('easytorch-launcher') 104 | logger.info('Launching EasyTorch runner.') 105 | 106 | cfg = init_cfg(cfg, True) 107 | 108 | set_device_type(device_type) 109 | 110 | if device_type != 'cpu': 111 | set_visible_devices(devices) 112 | 113 | # init runner 114 | runner = cfg['RUNNER'](cfg) 115 | 116 | # call fn 117 | return fn(cfg, runner, *args) 118 | -------------------------------------------------------------------------------- /examples/imagenet/imagenet_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Tuple, Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torch.utils.data import Dataset 6 | from torchvision import models, datasets, transforms 7 | 8 | from easytorch import Runner 9 | from easytorch.device import to_device 10 | 11 | 12 | def accuracy(output, target, topk=(1,)): 13 | """Computes the accuracy over the k top predictions for the specified values of k""" 14 | with torch.no_grad(): 15 | maxk = max(topk) 16 | batch_size = target.size(0) 17 | 18 | _, pred = output.topk(maxk, 1, True, True) 19 | pred = pred.t() 20 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 21 | 22 | res = [] 23 | for k in topk: 24 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 25 | res.append(correct_k.mul_(100.0 / batch_size)) 26 | return res 27 | 28 | 29 | class ImagenetRunner(Runner): 30 | """ImagenetRunner 31 | """ 32 | 33 | def __init__(self, cfg: Dict): 34 | super().__init__(cfg) 35 | 36 | self.criterion = nn.CrossEntropyLoss() 37 | self.criterion = to_device(self.criterion) 38 | 39 | def init_training(self, cfg: Dict): 40 | super().init_training(cfg) 41 | 42 | self.register_epoch_meter('train/loss', 'train', '{:.4e}') 43 | self.register_epoch_meter('train/acc@1', 'train', '{:6.2f}') 44 | self.register_epoch_meter('train/acc@5', 'train', '{:6.2f}') 45 | 46 | def init_validation(self, cfg: Dict): 47 | super().init_validation(cfg) 48 | 49 | self.register_epoch_meter('val/loss', 'val', '{:.4e}') 50 | self.register_epoch_meter('val/acc@1', 'val', '{:6.2f}') 51 | self.register_epoch_meter('val/acc@5', 'val', '{:6.2f}') 52 | 53 | @staticmethod 54 | def define_model(cfg: Dict) -> nn.Module: 55 | return models.__dict__[cfg['MODEL']['NAME']]() 56 | 57 | @staticmethod 58 | def build_train_dataset(cfg: Dict) -> Dataset: 59 | normalize = transforms.Normalize(**cfg['TRAIN']['DATA']['NORMALIZE']) 60 | return datasets.ImageFolder( 61 | cfg['TRAIN']['DATA']['DIR'], 62 | transforms.Compose([ 63 | transforms.RandomResizedCrop(cfg['TRAIN']['DATA']['CROP_SIZE']), 64 | transforms.RandomHorizontalFlip(), 65 | transforms.ToTensor(), 66 | normalize, 67 | ])) 68 | 69 | @staticmethod 70 | def build_val_dataset(cfg: Dict): 71 | normalize = transforms.Normalize(**cfg['VAL']['DATA']['NORMALIZE']) 72 | return datasets.ImageFolder( 73 | cfg['VAL']['DATA']['DIR'], 74 | transforms.Compose([ 75 | transforms.Resize(cfg['VAL']['DATA']['RESIZE']), 76 | transforms.CenterCrop(cfg['VAL']['DATA']['CROP_SIZE']), 77 | transforms.ToTensor(), 78 | normalize, 79 | ])) 80 | 81 | def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor: 82 | images, target = data 83 | 84 | images = to_device(images) 85 | target = to_device(target) 86 | 87 | output = self.model(images) 88 | 89 | loss = self.criterion(output, target) 90 | 91 | # measure accuracy and record loss 92 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 93 | 94 | 
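# weight each meter update by the batch size so the epoch-level average is per-sample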
self.update_epoch_meter('train/loss', loss.item(), images.size(0)) 95 | self.update_epoch_meter('train/acc@1', acc1[0], images.size(0)) 96 | self.update_epoch_meter('train/acc@5', acc5[0], images.size(0)) 97 | 98 | return loss 99 | 100 | def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]): 101 | images, target = data 102 | 103 | images = to_device(images) 104 | target = to_device(target) 105 | 106 | output = self.model(images) 107 | 108 | loss = self.criterion(output, target) 109 | 110 | # measure accuracy and record loss 111 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 112 | 113 | self.update_epoch_meter('val/loss', loss.item(), images.size(0)) 114 | self.update_epoch_meter('val/acc@1', acc1[0], images.size(0)) 115 | self.update_epoch_meter('val/acc@5', acc5[0], images.size(0)) 116 | 117 | def on_validating_end(self, train_epoch: Optional[int]): 118 | # `None` means validation mode 119 | if train_epoch is not None: 120 | self.save_best_model(train_epoch, 'val/acc@1', greater_best=True) 121 | -------------------------------------------------------------------------------- /easytorch/core/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | from logging import Logger 5 | from typing import Dict, List, Tuple, Union 6 | 7 | import torch 8 | 9 | from ..utils import get_logger 10 | from ..device import to_device 11 | 12 | 13 | DEFAULT_LOGGER = get_logger('easytorch-checkpoint') 14 | 15 | 16 | def get_last_ckpt_path(ckpt_save_dir: str, name_pattern: str = r'^.+_[\d]*.pt$') -> str: 17 | r"""Get last checkpoint path in `ckpt_save_dir` 18 | checkpoint files will be sorted by name 19 | 20 | Args: 21 | ckpt_save_dir (str): checkpoint save directory 22 | name_pattern (str): re pattern for checkpoint file name, default is r'^.+_[\d]*.pt$' 23 | 24 | Returns: 25 | checkpoint path (str): last checkpoint path in `ckpt_save_dir` 26 | """ 27 | 28 | ckpt_list = [f for f in os.listdir(ckpt_save_dir) if re.search(name_pattern, f) is not None] 29 | ckpt_list.sort() 30 | return os.path.join(ckpt_save_dir, ckpt_list[-1]) 31 | 32 | 33 | def load_ckpt(ckpt_save_dir: str, ckpt_path: str = None, logger: Logger = DEFAULT_LOGGER) -> Dict: 34 | """Load checkpoint 35 | if param `ckpt_path` is None, load the last checkpoint in `ckpt_save_dir`, 36 | else load checkpoint from `ckpt_path` 37 | 38 | Args: 39 | ckpt_save_dir (str): checkpoint save directory 40 | ckpt_path (str): checkpoint path, default is None 41 | logger (Logger): logger, default is Logger('easytorch') 42 | 43 | Returns: 44 | checkpoint dict loaded from file 45 | """ 46 | 47 | if ckpt_path is None: 48 | ckpt_path = get_last_ckpt_path(ckpt_save_dir) 49 | 50 | logger.info('Loading Checkpoint from \'{}\''.format(ckpt_path)) 51 | return torch.load(ckpt_path, map_location=lambda storage, loc: to_device(storage)) 52 | 53 | 54 | def save_ckpt(ckpt: Dict, ckpt_path: str, logger: Logger = DEFAULT_LOGGER): 55 | """Save checkpoint 56 | 57 | Args: 58 | ckpt (Dict): saved checkpoint dict 59 | ckpt_path (str): checkpoint save path 60 | logger (Logger): logger, default is Logger('easytorch') 61 | """ 62 | 63 | torch.save(ckpt, ckpt_path) 64 | logger.info('Checkpoint {} saved'.format(ckpt_path)) 65 | 66 | 67 | def need_to_remove_last_ckpt(last_epoch: int, ckpt_save_strategy: Union[int, List, Tuple]) -> bool: 68 | """Judging whether to remove last checkpoint by `ckpt_save_strategy` 69 | 70 | `ckpt_save_strategy` should be None, an int value, a list or a tuple 
71 | if `ckpt_save_strategy` is None, remove last checkpoint file every epoch 72 | if `ckpt_save_strategy` is an int value `n`, save a checkpoint every n epochs, 73 | remove last checkpoint file when last_epoch % ckpt_save_strategy != 0 74 | if `ckpt_save_strategy` is a list or a tuple `l`, save checkpoint when epoch in `l`, 75 | remove last checkpoint file when last_epoch not in ckpt_save_strategy 76 | 77 | Args: 78 | last_epoch (int): last epoch num 79 | ckpt_save_strategy (Union[int, List, Tuple]): checkpoint save strategy 80 | 81 | Returns: 82 | last checkpoint delete flag (bool): `True` means delete last checkpoint 83 | """ 84 | 85 | if ckpt_save_strategy is None: 86 | return True 87 | elif isinstance(ckpt_save_strategy, int) and last_epoch % ckpt_save_strategy != 0: 88 | return True 89 | elif isinstance(ckpt_save_strategy, (list, tuple)) and last_epoch not in ckpt_save_strategy: 90 | return True 91 | else: 92 | return False 93 | 94 | 95 | def backup_last_ckpt(last_ckpt_path: str, epoch: int, ckpt_save_strategy: Union[int, List, Tuple]): 96 | """Backup last checkpoint when last checkpoint needs to be removed (determined by need_to_remove_last_ckpt()) 97 | if last checkpoint file name is `a.pt`, rename `a.pt` to `a.pt.bak` 98 | 99 | Args: 100 | last_ckpt_path (str): last checkpoint file path 101 | epoch (int): current epoch num 102 | ckpt_save_strategy (Union[int, List, Tuple]): checkpoint save strategy 103 | """ 104 | 105 | last_epoch = epoch - 1 106 | 107 | # rename last ckpt to .bak 108 | if need_to_remove_last_ckpt(last_epoch, ckpt_save_strategy) and last_epoch != 0: 109 | os.rename(last_ckpt_path, last_ckpt_path + '.bak') 110 | 111 | 112 | def clear_ckpt(ckpt_save_dir: str, name_pattern: str = '*.pt.bak'): 113 | """Clear all backed up checkpoint files 114 | 115 | Args: 116 | ckpt_save_dir (str): checkpoint save directory 117 | name_pattern (str): backed up checkpoint file name pattern 118 | """ 119 | 120 | ckpt_list = glob.glob(os.path.join(ckpt_save_dir, name_pattern)) 121 | for ckpt in ckpt_list: 122 | os.remove(ckpt) 123 | -------------------------------------------------------------------------------- /easytorch/utils/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import Dict 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from .logging import get_logger 9 | from .dist import get_rank 10 | from ..device import get_device_type, set_device_manual_seed 11 | 12 | 13 | def set_visible_devices(devices: str): 14 | """Set environment variable `CUDA_VISIBLE_DEVICES` (or `MLU_VISIBLE_DEVICES` on MLU) to select devices. 15 | 16 | Examples: 17 | set_visible_devices('0,1,2,3') 18 | 19 | Args: 20 | devices (str): environment variable `CUDA_VISIBLE_DEVICES` value 21 | """ 22 | 23 | logger = get_logger('easytorch-env') 24 | if devices is not None: 25 | os.environ[{ 26 | 'gpu': 'CUDA_VISIBLE_DEVICES', 27 | 'mlu': 'MLU_VISIBLE_DEVICES' 28 | }[get_device_type()]] = devices 29 | logger.info('Use devices {}.'.format(devices)) 30 | else: 31 | logger.info('Use all devices.') 32 | 33 | 34 | def set_tf32_mode(tf32_mode: bool): 35 | """Set tf32 mode on Ampere gpu when torch version >= 1.7.0 and cuda version >= 11.0. 36 | See https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere 37 | 38 | Args: 39 | tf32_mode (bool): set to ``True`` to enable tf32 mode.
40 | """ 41 | 42 | logger = get_logger('easytorch-env') 43 | if get_device_type() == 'gpu': 44 | if torch.__version__ >= '1.7.0': 45 | if tf32_mode: 46 | logger.info('Enable TF32 mode') 47 | else: 48 | # disable tf32 mode on Ampere gpu 49 | torch.backends.cuda.matmul.allow_tf32 = False 50 | torch.backends.cudnn.allow_tf32 = False 51 | logger.info('Disable TF32 mode') 52 | else: 53 | if tf32_mode: 54 | raise RuntimeError('Torch version {} does not support tf32'.format(torch.__version__)) 55 | else: 56 | if tf32_mode: 57 | raise RuntimeError('Device {} does not support tf32.'.format(get_device_type())) 58 | 59 | 60 | def setup_determinacy(seed: int, deterministic: bool = False, cudnn_enabled: bool = True, 61 | cudnn_benchmark: bool = True, cudnn_deterministic: bool = False): 62 | """Setup determinacy. 63 | 64 | Including `python`, `random`, `numpy`, `torch` 65 | 66 | Args: 67 | seed (int): random seed. 68 | deterministic (bool): Use deterministic algorithms. 69 | See https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html. 70 | cudnn_enabled (bool): Enable cudnn. 71 | See https://pytorch.org/docs/stable/backends.html 72 | cudnn_benchmark (bool): Enable cudnn benchmark. 73 | See https://pytorch.org/docs/stable/backends.html 74 | cudnn_deterministic (bool): Enable cudnn deterministic algorithms. 75 | See https://pytorch.org/docs/stable/backends.html 76 | """ 77 | 78 | logger = get_logger('easytorch-env') 79 | 80 | os.environ['PYTHONHASHSEED'] = str(seed) 81 | random.seed(seed) 82 | np.random.seed(seed) 83 | 84 | set_device_manual_seed(seed) 85 | 86 | if deterministic: 87 | if get_device_type() == 'gpu': 88 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 89 | 90 | if torch.__version__ < '1.7.0': 91 | pass 92 | elif torch.__version__ < '1.8.0': 93 | torch.set_deterministic(True) 94 | else: 95 | torch.use_deterministic_algorithms(True) 96 | logger.info('Use deterministic algorithms.') 97 | 98 | if get_device_type() == 'gpu': 99 | if not cudnn_enabled: 100 | torch.backends.cudnn.enabled = False 101 | logger.info('Unset cudnn enabled.') 102 | if not cudnn_benchmark: 103 | torch.backends.cudnn.benchmark = False 104 | logger.info('Unset cudnn benchmark.') 105 | if cudnn_deterministic: 106 | torch.backends.cudnn.deterministic = True 107 | logger.info('Set cudnn deterministic.') 108 | 109 | 110 | def set_env(env_cfg: Dict): 111 | """Setup runtime env, include tf32, seed and determinacy. 112 | 113 | env config template: 114 | ``` 115 | CFG.ENV = Config() 116 | CFG.ENV.TF32 = False 117 | CFG.ENV.SEED = 42 118 | CFG.ENV.DETERMINISTIC = True 119 | CFG.ENV.CUDNN = Config() 120 | CFG.ENV.CUDNN.ENABLED = False 121 | CFG.ENV.CUDNN.BENCHMARK = False 122 | CFG.ENV.CUDNN.DETERMINISTIC = True 123 | ``` 124 | 125 | Args: 126 | env_cfg (Dict): env config. 
127 | """ 128 | 129 | # tf32 130 | set_tf32_mode(env_cfg.get('TF32', False)) 131 | 132 | # determinacy 133 | seed = env_cfg.get('SEED') 134 | if seed is not None: 135 | # each rank has different seed in distributed mode 136 | setup_determinacy( 137 | seed + get_rank(), 138 | env_cfg.get('DETERMINISTIC', False), 139 | env_cfg.get('CUDNN.ENABLED', True), 140 | env_cfg.get('CUDNN.BENCHMARK', True), 141 | env_cfg.get('CUDNN.DETERMINISTIC', False) 142 | ) 143 | -------------------------------------------------------------------------------- /easytorch/config/config.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/makinacorpus/easydict/blob/master/easydict/__init__.py 2 | from typing import overload 3 | 4 | 5 | class Config(dict): 6 | """ 7 | Get attributes 8 | 9 | >>> d = Config({'foo':3}) 10 | >>> d['foo'] 11 | 3 12 | >>> d.foo 13 | 3 14 | >>> d.bar 15 | Traceback (most recent call last): 16 | ... 17 | AttributeError: 'Config' object has no attribute 'bar' 18 | 19 | Works recursively 20 | 21 | >>> d = Config({'foo':3, 'bar':{'x':1, 'y':2}}) 22 | >>> isinstance(d.bar, dict) 23 | True 24 | >>> d.bar.x 25 | 1 26 | >>> d['bar.x'] 27 | 1 28 | >>> d.get('bar.x') 29 | 1 30 | >>> d.get('bar.z') 31 | None 32 | >>> d.get('bar.z', 3) 33 | 3 34 | >>> d.has('bar.x') 35 | True 36 | >>> d.has('bar.z') 37 | False 38 | 39 | Bullet-proof 40 | 41 | >>> Config({}) 42 | {} 43 | >>> Config(d={}) 44 | {} 45 | >>> Config(None) 46 | {} 47 | >>> d = {'a': 1} 48 | >>> Config(**d) 49 | {'a': 1} 50 | 51 | Set attributes 52 | 53 | >>> d = Config() 54 | >>> d.foo = 3 55 | >>> d.foo 56 | 3 57 | >>> d.bar = {'prop': 'value'} 58 | >>> d.bar.prop 59 | 'value' 60 | >>> d 61 | {'foo': 3, 'bar': {'prop': 'value'}} 62 | >>> d.bar.prop = 'newer' 63 | >>> d.bar.prop 64 | 'newer' 65 | 66 | 67 | Values extraction 68 | 69 | >>> d = Config({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 70 | >>> isinstance(d.bar, list) 71 | True 72 | >>> from operator import attrgetter 73 | >>> map(attrgetter('x'), d.bar) 74 | [1, 3] 75 | >>> map(attrgetter('y'), d.bar) 76 | [2, 4] 77 | >>> d = Config() 78 | >>> d.keys() 79 | [] 80 | >>> d = Config(foo=3, bar=dict(x=1, y=2)) 81 | >>> d.foo 82 | 3 83 | >>> d.bar.x 84 | 1 85 | 86 | Still like a dict though 87 | 88 | >>> o = Config({'clean':True}) 89 | >>> o.items() 90 | [('clean', True)] 91 | 92 | And like a class 93 | 94 | >>> class Flower(Config): 95 | ... power = 1 96 | ... 97 | >>> f = Flower() 98 | >>> f.power 99 | 1 100 | >>> f = Flower({'height': 12}) 101 | >>> f.height 102 | 12 103 | >>> f['power'] 104 | 1 105 | >>> sorted(f.keys()) 106 | ['height', 'power'] 107 | 108 | update and pop items 109 | >>> d = Config(a=1, b='2') 110 | >>> e = Config(c=3.0, a=9.0) 111 | >>> d.update(e) 112 | >>> d.c 113 | 3.0 114 | >>> d['c'] 115 | 3.0 116 | >>> d.get('c') 117 | 3.0 118 | >>> d.update(a=4, b=4) 119 | >>> d.b 120 | 4 121 | >>> d.pop('a') 122 | 4 123 | >>> d.a 124 | Traceback (most recent call last): 125 | ... 
126 | AttributeError: 'Config' object has no attribute 'a' 127 | """ 128 | 129 | # pylint: disable=super-init-not-called 130 | def __init__(self, d=None, **kwargs): 131 | if d is None: 132 | d = {} 133 | if kwargs: 134 | d.update(**kwargs) 135 | for k, v in d.items(): 136 | setattr(self, k, v) 137 | # Class attributes 138 | for k in self.__class__.__dict__: 139 | if not (k.startswith('__') and k.endswith('__')) and k not in ('has', 'get', 'update', 'pop'): 140 | setattr(self, k, getattr(self, k)) 141 | 142 | def __setattr__(self, name, value): 143 | if isinstance(value, (list, tuple)): 144 | v = [self.__class__(x) if isinstance(x, dict) else x for x in value] 145 | # Don't replace tuple with list 146 | if isinstance(value, tuple): 147 | v = tuple(v) 148 | value = v 149 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 150 | value = self.__class__(value) 151 | super().__setattr__(name, value) 152 | super().__setitem__(name, value) 153 | 154 | __setitem__ = __setattr__ 155 | 156 | def __getitem__(self, key): 157 | # Support `cfg['AA.BB.CC']` 158 | if isinstance(key, str): 159 | keys = key.split('.') 160 | else: 161 | keys = key 162 | value = super().__getitem__(keys[0]) 163 | if len(keys) > 1: 164 | return value.__getitem__(keys[1:]) 165 | else: 166 | return value 167 | 168 | def has(self, key): 169 | return self.get(key) is not None 170 | 171 | @overload 172 | def get(self, key): ... 173 | 174 | def get(self, key, default=None): 175 | # Support `cfg.get('AA.BB.CC')` and `cfg.get('AA.BB.CC', default_value)` 176 | try: 177 | return self[key] 178 | except KeyError: 179 | return default 180 | 181 | def update(self, e=None, **f): 182 | d = e or {} 183 | d.update(f) 184 | for k in d: 185 | setattr(self, k, d[k]) 186 | 187 | def pop(self, k, d=None): 188 | # Check for existence 189 | if hasattr(self, k): 190 | delattr(self, k) 191 | return super().pop(k, d) 192 | -------------------------------------------------------------------------------- /easytorch/utils/data_prefetcher.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import queue as Queue 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from .. import device 8 | 9 | 10 | class BackgroundGenerator(threading.Thread): 11 | """BackgroundGenerator 12 | """ 13 | 14 | def __init__(self, generator, max_prefetch=1): 15 | """ 16 | 17 | This function transforms a generator into a background-thread generator. 18 | :param generator: generator or genexp or any 19 | It can be used with any minibatch generator. 20 | 21 | It is quite lightweight, but not entirely weightless. 22 | Using global variables inside the generator is not recommended 23 | (it may raise GIL contention and zero out the benefit of having a background thread.) 24 | The ideal use case is when everything it requires is stored inside it and everything it outputs is 25 | passed through the queue. 26 | 27 | There's no restriction on doing weird stuff, reading/writing files, 28 | retrieving URLs [or whatever] whilst iterating. 29 | 30 | :param max_prefetch: defines how many iterations (at most) the background generator can 31 | keep stored at any moment of time. 32 | Whenever there are already max_prefetch batches stored in the queue, the background thread will halt until 33 | one of these batches is dequeued. 34 | 35 | !Default max_prefetch=1 is okay unless you deal with some weird file IO in your generator!
36 | 37 | Setting max_prefetch to -1 lets it store as many batches as it can, which will work slightly (if any) faster, 38 | but will require storing all batches in memory. 39 | If you use infinite generator with max_prefetch=-1, it will exceed the RAM size unless dequeued quickly enough. 40 | """ 41 | threading.Thread.__init__(self) 42 | self.queue = Queue.Queue(max_prefetch) 43 | self.generator = generator 44 | self.daemon = True 45 | self.start() 46 | 47 | def run(self): 48 | for item in self.generator: 49 | self.queue.put(item) 50 | self.queue.put(None) 51 | 52 | def next(self): 53 | next_item = self.queue.get() 54 | if next_item is None: 55 | raise StopIteration 56 | return next_item 57 | 58 | def __next__(self): 59 | return self.next() 60 | 61 | def __iter__(self): 62 | return self 63 | 64 | def __len__(self): 65 | return len(self.generator) 66 | 67 | 68 | class DataLoaderX(DataLoader): 69 | """Dataloader with prefetch. See https://github.com/justheuristic/prefetch_generator 70 | """ 71 | def __iter__(self): 72 | return BackgroundGenerator(super().__iter__()) 73 | 74 | 75 | def data_to_device(data): 76 | if isinstance(data, dict): 77 | for k, v in data.items(): 78 | if isinstance(v, torch.Tensor): 79 | data[k] = data_to_device(v) 80 | elif isinstance(data, list): 81 | for i, v in enumerate(data): 82 | if isinstance(v, torch.Tensor): 83 | data[i] = data_to_device(v) 84 | elif isinstance(data, tuple): 85 | data = tuple(data_to_device(list(data))) 86 | elif isinstance(data, torch.Tensor): 87 | data = device.to_device(data, non_blocking=True) 88 | return data 89 | 90 | 91 | class DevicePrefetcher: 92 | """Device Prefetcher 93 | """ 94 | 95 | def __init__(self, data_loader: DataLoader) -> None: 96 | self.data_loader = data_loader 97 | self.stream = torch.cuda.Stream() 98 | self.batch_data = None 99 | 100 | @staticmethod 101 | def data_to_device(data): 102 | if isinstance(data, dict): 103 | for k, v in data.items(): 104 | if isinstance(v, torch.Tensor): 105 | data[k] = device.to_device(v, non_blocking=True) 106 | elif isinstance(data, (list, tuple)): 107 | for i, v in enumerate(data): 108 | if isinstance(v, torch.Tensor): 109 | data[i] = device.to_device(v, non_blocking=True) 110 | elif isinstance(data, torch.Tensor): 111 | data = device.to_device(data, non_blocking=True) 112 | return data 113 | 114 | def preload(self): 115 | try: 116 | self.batch_data = next(self.data_loader_iter) 117 | # put tensors to gpu 118 | with device.stream(self.stream): 119 | self.batch_data = data_to_device(self.batch_data) 120 | except StopIteration: 121 | self.batch_data = None 122 | 123 | def next(self): 124 | if self.batch_data is None: 125 | raise StopIteration() 126 | 127 | device.current_stream().wait_stream(self.stream) 128 | batch = self.batch_data 129 | self.preload() 130 | return batch 131 | 132 | def reset(self): 133 | self.data_loader_iter = iter(self.data_loader) 134 | self.preload() 135 | 136 | def __next__(self): 137 | return self.next() 138 | 139 | def __iter__(self): 140 | self.reset() 141 | return self 142 | 143 | def __len__(self): 144 | return len(self.data_loader) 145 | -------------------------------------------------------------------------------- /easytorch/launcher/dist_wrap.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import random 3 | from typing import Callable, Dict, Union, Any, Optional 4 | 5 | import torch 6 | 7 | from ..utils import get_logger 8 | from ..device import get_device_type, set_device_type, 
get_device_count, set_device 9 | 10 | 11 | def dist_func(local_rank: int, dist_params: Dict[str, Any], func: Callable, *args): 12 | """Distributed function for `torch.multiprocessing.spawn` 13 | 14 | Args: 15 | local_rank (int): Local rank of current process group. 16 | dist_params (Dict[str, Any]): Other distributed parameters. 17 | func (Callable): A function. 18 | """ 19 | 20 | logger = get_logger('easytorch-launcher') 21 | 22 | rank = dist_params['device_num'] * dist_params['node_rank'] + local_rank 23 | logger.info( 24 | 'Launching in distributed mode. Distributed parameters: '\ 25 | 'world_size={:d}, node_rank={:d}, rank={:d}, local_rank={:d}, dist_backend={}, init_method={}'.format( 26 | dist_params['world_size'], dist_params['node_rank'], rank, local_rank, 27 | dist_params['dist_backend'], dist_params['init_method'] 28 | ) 29 | ) 30 | 31 | set_device_type(dist_params['device_type']) 32 | 33 | torch.distributed.init_process_group( 34 | backend=dist_params['dist_backend'], 35 | init_method=dist_params['init_method'], 36 | rank=rank, 37 | world_size=dist_params['world_size'] 38 | ) 39 | 40 | set_device(local_rank) 41 | 42 | try: 43 | args, kwargs = args 44 | func(*args, **kwargs) 45 | finally: 46 | # https://pytorch.org/docs/stable/distributed.html#shutdown 47 | torch.distributed.destroy_process_group() 48 | 49 | 50 | def dist_wrap(func: Callable, 51 | node_num: int = 1, 52 | device_num: int = 1, 53 | node_rank: int = 0, 54 | dist_backend: Optional[Union[str, torch.distributed.Backend]] = None, 55 | init_method: Optional[str] = None) -> Callable: 56 | """Convert a function to a distributed function. 57 | 58 | Usage: 59 | >>> def function(a, b): 60 | >>> ... 61 | >>> 62 | >>> function_dist = dist_wrap( 63 | >>> function, 64 | >>> node_num=node_num, 65 | >>> device_num=device_num, 66 | >>> node_rank=node_rank, 67 | >>> dist_backend=dist_backend, 68 | >>> init_method=init_method 69 | >>> ) 70 | >>> function_dist(a, b) 71 | 72 | Args: 73 | func (Callable): The function. 74 | node_num (int, optional): Number of nodes. Defaults to 1. 75 | device_num (int, optional): Number of devices per node. Defaults to 1. 76 | node_rank (int, optional): Rank of current node. Defaults to 0. 77 | dist_backend (Optional[Union[str, distributed.Backend]], optional): The backend of DDP. 78 | Defaults to None, which means using `nccl` as the backend. 79 | init_method (Optional[str], optional): URL specifying how to initialize the process group. 80 | Defaults to None, which means using `tcp://127.0.0.1:{random port}` as the init method. 81 | 82 | Returns: 83 | Callable: The converted function.
84 | """ 85 | 86 | if node_num < 1: 87 | raise ValueError('The node_num must be at least 1!') 88 | 89 | if device_num < 0: 90 | raise ValueError('The device_num must not be negative!') 91 | 92 | world_size = node_num * device_num 93 | 94 | if world_size == 0: 95 | # CPU mode 96 | return func 97 | else: 98 | # DEVICE mode 99 | if node_rank >= node_num: 100 | raise ValueError('The node_rank must be less than node_num!') 101 | 102 | if device_num != get_device_count(): 103 | raise RuntimeError('Device num mismatch: cfg.DEVICE_NUM = {:d}, ' \ 104 | 'but get_device_count() = {:d}'.format(device_num, get_device_count())) 105 | 106 | if world_size == 1: 107 | return func 108 | else: 109 | # Distributed Data Parallel 110 | dist_backend = 'nccl' if dist_backend is None else dist_backend 111 | 112 | if init_method is None: 113 | if node_num == 1: 114 | init_method = 'tcp://127.0.0.1:{:d}'.format(random.randint(50000, 65000)) 115 | else: 116 | raise ValueError('The init_method cannot be None when using multiple compute nodes') 117 | 118 | @functools.wraps(func) 119 | def wrapper(*args, **kwargs): 120 | dist_params = { 121 | 'device_type': get_device_type(), 122 | 'device_num': device_num, 123 | 'node_rank': node_rank, 124 | 'world_size': world_size, 125 | 'dist_backend': dist_backend, 126 | 'init_method': init_method 127 | } 128 | 129 | torch.multiprocessing.spawn( 130 | dist_func, 131 | args=(dist_params, func, args, kwargs), 132 | nprocs=device_num, 133 | join=True 134 | ) 135 | 136 | return wrapper 137 | -------------------------------------------------------------------------------- /easytorch/core/meter_pool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | 7 | class AvgMeter(object): 8 | """Average meter. 9 | """ 10 | 11 | def __init__(self): 12 | 13 | self._last = 0. 14 | self._sum = 0. 15 | self._count = 0 16 | 17 | def reset(self): 18 | """Reset counter. 19 | """ 20 | 21 | self._last = 0. 22 | self._sum = 0. 23 | self._count = 0 24 | 25 | def update(self, value: float, n: int = 1): 26 | """Update sum and count. 27 | 28 | Args: 29 | value (float): value. 30 | n (int): number. 31 | """ 32 | 33 | self._last = value 34 | self._sum += value * n 35 | self._count += n 36 | 37 | @property 38 | def avg(self) -> float: 39 | """Get average value. 40 | 41 | Returns: 42 | avg (float) 43 | """ 44 | 45 | return self._sum / self._count if self._count != 0 else 0 46 | 47 | @property 48 | def last(self) -> float: 49 | """Get last value. 50 | 51 | Returns: 52 | last (float) 53 | """ 54 | 55 | return self._last 56 | 57 | class MeterPool: 58 | """Meter container 59 | """ 60 | 61 | def __init__(self): 62 | self._pool = {} 63 | 64 | def register(self, name: str, meter_type: str, fmt: str = '{:f}', plt: bool = True): 65 | """Init an average meter and add it to meter pool. 66 | 67 | Args: 68 | name (str): meter name (must be unique). 69 | meter_type (str): meter type. 70 | fmt (str): meter output format. 71 | plt (bool): set `True` to plot it in tensorboard 72 | when calling `plt_meters`. 73 | """ 74 | 75 | if name in self._pool: 76 | raise ValueError(f'Meter {name} already exists.') 77 | 78 | self._pool[name] = { 79 | 'meter': AvgMeter(), 80 | 'index': len(self._pool.keys()), 81 | 'format': fmt, 82 | 'type': meter_type, 83 | 'plt': plt 84 | } 85 | 86 | def update(self, name: str, value: float, n: int = 1): 87 | """Update average meter.
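Records `value` as the latest reading and adds `value * n` to the running sum used by `avg`.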
88 | 89 | Args: 90 | name (str): meter name. 91 | value (float): value. 92 | n (int): count (weight of this update). 93 | """ 94 | 95 | self._pool[name]['meter'].update(value, n) 96 | 97 | def get_avg(self, name: str) -> float: 98 | """Get average value. 99 | 100 | Args: 101 | name (str): meter name. 102 | 103 | Returns: 104 | avg (float) 105 | """ 106 | 107 | return self._pool[name]['meter'].avg 108 | 109 | def print_meters(self, meter_type: str, logger: logging.Logger = None): 110 | """Print the specified type of meters. 111 | 112 | Args: 113 | meter_type (str): meter type 114 | logger (logging.Logger): logger 115 | """ 116 | 117 | print_list = [] 118 | for i in range(len(self._pool.keys())): 119 | for name, value in self._pool.items(): 120 | if value['index'] == i and value['type'] == meter_type: 121 | print_list.append( 122 | ('{}: ' + value['format']).format(name, value['meter'].avg) 123 | ) 124 | print_str = 'Result <{}>: [{}]'.format(meter_type, ', '.join(print_list)) 125 | if logger is None: 126 | print(print_str) 127 | else: 128 | logger.info(print_str) 129 | 130 | def plt_meters(self, meter_type: str, step: int, tensorboard_writer: SummaryWriter, value_type: str = 'avg'): 131 | """Plot the specified type of meters in tensorboard. 132 | 133 | Args: 134 | meter_type (str): meter type. 135 | step (int): Global step value to record. 136 | tensorboard_writer (SummaryWriter): tensorboard SummaryWriter 137 | value_type (str): 'avg' or 'last', the meter value to plot. 138 | """ 139 | assert value_type in ['avg', 'last'], "value_type must be 'avg' or 'last'" 140 | 141 | get_value = lambda meter: meter.avg if value_type == 'avg' else meter.last 142 | 143 | for name, value in self._pool.items(): 144 | if value['plt'] and value['type'] == meter_type: 145 | tensorboard_writer.add_scalar(name, get_value(value['meter']), global_step=step) 146 | tensorboard_writer.flush() 147 | 148 | def reset(self): 149 | """Reset all meters.
150 | """ 151 | 152 | for _, value in self._pool.items(): 153 | value['meter'].reset() 154 | 155 | 156 | class MeterPoolDDP(MeterPool): 157 | """MeterPoolDDP 158 | """ 159 | # TODO(Yuhao Wang): not support 160 | 161 | def to_tensor(self): 162 | tensor = torch.empty((len(self._pool.keys()), 2)) 163 | for i in range(len(self._pool.keys())): 164 | for _, value in self._pool.items(): 165 | if value['index'] == i: 166 | tensor[i][0] = float(value['meter'].count) 167 | tensor[i][1] = value['meter'].avg 168 | return tensor 169 | 170 | def update_tensor(self, tensor): 171 | if tensor.shape[0] != len(self._pool.keys()): 172 | raise ValueError('Invalid tensor shape!') 173 | for i in range(len(self._pool.keys())): 174 | for _, value in self._pool.items(): 175 | if value['index'] == i: 176 | value['meter'].update(tensor[i][1], tensor[i][0]) 177 | -------------------------------------------------------------------------------- /easytorch/config/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import types 4 | import copy 5 | import hashlib 6 | from typing import Dict, Set, List, Union 7 | 8 | from .config import Config 9 | 10 | __all__ = [ 11 | 'config_str', 'config_md5', 'save_config_str', 'copy_config_file', 12 | 'import_config', 'convert_config', 'get_ckpt_save_dir' 13 | ] 14 | 15 | TRAINING_INDEPENDENT_FLAG = '_TRAINING_INDEPENDENT' 16 | 17 | TRAINING_INDEPENDENT_KEYS = { 18 | 'DIST_BACKEND', 19 | 'DIST_INIT_METHOD', 20 | 'TRAIN.CKPT_SAVE_STRATEGY', 21 | 'TRAIN.DATA.NUM_WORKERS', 22 | 'TRAIN.DATA.PIN_MEMORY', 23 | 'TRAIN.DATA.PREFETCH', 24 | 'VAL' 25 | } 26 | 27 | 28 | def get_training_dependent_config(cfg: Dict, except_keys: Union[Set, List] = None) -> Dict: 29 | """Get training dependent config. 30 | Recursively traversal each key, 31 | if the key is in `TRAINING_INDEPENDENT_KEYS` or `CFG._TRAINING_INDEPENDENT`, pop it. 
28 | def get_training_dependent_config(cfg: Dict, except_keys: Union[Set, List] = None) -> Dict:
29 |     """Get training dependent config.
30 |     Recursively traverses each key;
31 |     if the key is in `TRAINING_INDEPENDENT_KEYS` or `CFG._TRAINING_INDEPENDENT`, it is popped.
32 | 
33 |     Args:
34 |         cfg (Dict): Config
35 |         except_keys (Union[Set, List]): the keys to be excluded
36 | 
37 |     Returns:
38 |         cfg (Dict): Training dependent configs
39 |     """
40 |     cfg_copy = copy.deepcopy(cfg)
41 | 
42 |     if except_keys is None:
43 |         except_keys = copy.deepcopy(TRAINING_INDEPENDENT_KEYS)
44 |         if cfg_copy.get(TRAINING_INDEPENDENT_FLAG) is not None:
45 |             except_keys.update(cfg_copy[TRAINING_INDEPENDENT_FLAG])
46 | 
47 |     # convert to set
48 |     if isinstance(except_keys, list):
49 |         except_keys = set(except_keys)
50 | 
51 |     if cfg_copy.get(TRAINING_INDEPENDENT_FLAG) is not None:
52 |         cfg_copy.pop(TRAINING_INDEPENDENT_FLAG)
53 | 
54 |     pop_list = []
55 |     dict_list = []
56 |     for k, v in cfg_copy.items():
57 |         if isinstance(v, dict):
58 |             sub_except_keys = set()
59 |             for except_key in except_keys:
60 |                 if k == except_key:
61 |                     pop_list.append(k)
62 |                 elif except_key.find(k) == 0 and except_key[len(k)] == '.':
63 |                     sub_except_keys.add(except_key[len(k) + 1:])
64 |             if len(sub_except_keys) != 0:
65 |                 new_v = get_training_dependent_config(v, sub_except_keys)
66 |                 dict_list.append((k, new_v))
67 |         else:
68 |             for except_key in except_keys:
69 |                 if k == except_key:
70 |                     pop_list.append(k)
71 | 
72 |     for dict_key, dict_value in dict_list:
73 |         cfg_copy[dict_key] = dict_value
74 | 
75 |     for pop_key in pop_list:
76 |         cfg_copy.pop(pop_key)
77 | 
78 |     return cfg_copy
79 | 
80 | 
81 | def config_str(cfg: Dict, indent: str = '') -> str:
82 |     """Get config string
83 | 
84 |     Args:
85 |         cfg (Dict): Config
86 |         indent (str): if ``cfg`` is a sub config, ``indent`` += ' '
87 | 
88 |     Returns:
89 |         Config string (str)
90 |     """
91 | 
92 |     s = ''
93 |     for k, v in cfg.items():
94 |         if isinstance(v, dict):
95 |             s += (indent + '{}:').format(k) + '\n'
96 |             s += config_str(v, indent + ' ')
97 |         elif isinstance(v, types.FunctionType):
98 |             s += (indent + '{}: {}').format(k, v.__name__) + '\n'
99 |         elif k == TRAINING_INDEPENDENT_FLAG:
100 |             pass
101 |         else:
102 |             s += (indent + '{}: {}').format(k, v) + '\n'
103 |     return s
104 | 
105 | 
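A sketch of the rendering (keys are illustrative; per the `indent + ' '` recursion above, each nesting level adds one space):

```python
print(config_str({'TRAIN': {'NUM_EPOCHS': 100, 'OPTIM': {'TYPE': 'SGD'}}}))
# TRAIN:
#  NUM_EPOCHS: 100
#  OPTIM:
#   TYPE: SGD
```
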
106 | def config_md5(cfg: Dict) -> str:
107 |     """Get MD5 value of config.
108 | 
109 |     Notes:
110 |         Only training dependent configurations participate in the MD5 calculation.
111 | 
112 |     Args:
113 |         cfg (Dict): Config
114 | 
115 |     Returns:
116 |         MD5 (str)
117 |     """
118 | 
119 |     cfg_excepted = get_training_dependent_config(cfg)
120 |     m = hashlib.md5()
121 |     m.update(config_str(cfg_excepted).encode('utf-8'))
122 |     return m.hexdigest()
123 | 
124 | 
125 | def save_config_str(cfg: Dict, file_path: str):
126 |     """Save config
127 | 
128 |     Args:
129 |         cfg (Dict): Config
130 |         file_path (str): file path
131 |     """
132 | 
133 |     with open(file_path, 'w') as f:
134 |         f.write(config_str(cfg))
135 | 
136 | 
137 | def copy_config_file(cfg_file_path: str, save_dir: str):
138 |     """Copy config file to `save_dir`
139 | 
140 |     Args:
141 |         cfg_file_path (str): config file path
142 |         save_dir (str): save directory
143 |     """
144 | 
145 |     if os.path.isfile(cfg_file_path) and os.path.isdir(save_dir):
146 |         cfg_file_name = os.path.basename(cfg_file_path)
147 |         shutil.copyfile(cfg_file_path, os.path.join(save_dir, cfg_file_name))
148 | 
149 | 
150 | def import_config(path: str, verbose: bool = True) -> Dict:
151 |     """Import config by path
152 | 
153 |     Examples:
154 |         ```
155 |         cfg = import_config('config/my_config.py')
156 |         ```
157 |         is equivalent to
158 |         ```
159 |         from config.my_config import CFG as cfg
160 |         ```
161 | 
162 |     Args:
163 |         path (str): Config path
164 |         verbose (bool): set to ``True`` to print config
165 | 
166 |     Returns:
167 |         cfg (Dict): `CFG` in config file
168 |     """
169 | 
170 |     if path.find('.py') != -1:
171 |         path = path[:path.find('.py')].replace('/', '.').replace('\\', '.')
172 |     cfg_name = path.split('.')[-1]
173 |     cfg = __import__(path, fromlist=[cfg_name]).CFG
174 | 
175 |     if verbose:
176 |         print(config_str(cfg))
177 |     return cfg
178 | 
179 | 
180 | def convert_config(cfg: Dict) -> Config:
181 |     """Convert cfg to `Config`; add MD5 to cfg.
182 | 
183 |     Args:
184 |         cfg (Dict): config.
185 |     """
186 | 
187 |     if not isinstance(cfg, Config):
188 |         cfg = Config(cfg)
189 |     if cfg.get('MD5') is None:
190 |         cfg['MD5'] = config_md5(cfg)
191 |     return cfg
192 | 
193 | 
194 | def get_ckpt_save_dir(cfg: Dict) -> str:
195 |     """Get real ckpt save dir with MD5.
196 | 
197 |     Args:
198 |         cfg (Dict): config.
199 | 
200 |     Returns:
201 |         str: Real ckpt save dir
202 |     """
203 | 
204 |     return os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], cfg['MD5'])
205 | 
206 | 
207 | def init_cfg(cfg: Union[Dict, str], save: bool = False):
208 |     """Initialize a config: import it if given as a path, convert it to
209 |     `Config` (adding the MD5), and optionally save it to the ckpt save dir.
210 |     """
211 |     if isinstance(cfg, str):
212 |         cfg_path = cfg
213 |         cfg = import_config(cfg, verbose=save)
214 |     else:
215 |         cfg_path = None
216 | 
217 |     # convert to `Config` and add MD5
218 |     cfg = convert_config(cfg)
219 | 
220 |     # save config
221 |     ckpt_save_dir = get_ckpt_save_dir(cfg)
222 |     if save and not os.path.isdir(ckpt_save_dir):
223 |         os.makedirs(ckpt_save_dir)
224 |         save_config_str(cfg, os.path.join(ckpt_save_dir, 'cfg.txt'))
225 |         if cfg_path is not None:
226 |             copy_config_file(cfg_path, ckpt_save_dir)
227 | 
228 |     return cfg
--------------------------------------------------------------------------------
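An end-to-end sketch of the helpers above (the path points at one of the repo's example configs and is assumed to be importable from the working directory):

```python
cfg = import_config('examples/mnist/configs/mnist_1x_cfg.py', verbose=False)
cfg = convert_config(cfg)       # now a Config with cfg['MD5'] filled in
print(get_ckpt_save_dir(cfg))   # <TRAIN.CKPT_SAVE_DIR>/<MD5>
```
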
/easytorch/easyoptim/easy_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | 
4 | from torch.optim.lr_scheduler import _LRScheduler
5 | 
6 | 
7 | class MultiCosineAnnealingWarmupLR(_LRScheduler):
8 |     r"""Set the learning rate of each parameter group using a cosine annealing
9 |     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
10 |     is the number of epochs since the last restart and :math:`T_{i}` is the number
11 |     of epochs between two warm restarts in SGDR:
12 | 
13 |     .. math::
14 |         \eta_t = \eta_{min} + \mathrm{lr\_mult} \times \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
15 |         \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
16 | 
17 |     When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
18 |     When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
19 | 
20 |     It has been proposed in
21 |     `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
22 | 
23 |     Args:
24 |         optimizer (Optimizer): Wrapped optimizer.
25 |         final_epoch (int): Number of total iterations
26 |         T_0 (list): Number of iterations for the restart
27 |         lr_mult (list): A factor multiplied with learning rate at iteration T_0,
28 |             must have the same shape as T_0
29 |         warmup_begin (int, optional): Number of iterations for the beginning warm up,
30 |             notice that the first decay period T_i will be reduced by this param. Default: 0
31 |         warmup_factor (float, optional): A factor that the learning rate will be multiplied at first epoch.
32 |             Default: 0.01
33 |         eta_min (float, optional): Minimum learning rate. Default: 0.
34 |         last_epoch (int, optional): The index of last epoch. Default: -1.
35 | 
36 |     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
37 |         https://arxiv.org/abs/1608.03983
38 |     """
39 | 
40 |     def __init__(self, optimizer, final_epoch, T_0=None, lr_mult=None, warmup_begin=0, warmup_factor=0.01,
41 |                  eta_min=0, last_epoch=-1, verbose=False):
42 |         if T_0 and not isinstance(T_0, list):
43 |             raise ValueError('Expected list object or None type T_0, but got {}'.format(type(T_0)))
44 |         if lr_mult and not isinstance(lr_mult, list):
45 |             raise ValueError('Expected list object or None type lr_mult, but got {}'.format(type(lr_mult)))
46 |         if not T_0 and lr_mult:
47 |             raise ValueError(
48 |                 'Expected T_0 and lr_mult to have the same length, but got NoneType and {}'.format(len(lr_mult))
49 |             )
50 |         if T_0 and not lr_mult:
51 |             raise ValueError('Expected T_0 and lr_mult to have the same length, but got {} and NoneType'.format(len(T_0)))
52 |         if T_0 and lr_mult and len(T_0) != len(lr_mult):
53 |             raise ValueError(
54 |                 'Expected T_0 and lr_mult to have the same length, but got {} and {}'.format(len(T_0), len(lr_mult))
55 |             )
56 | 
57 |         self.final_epoch = final_epoch
58 |         self.t_0_list = T_0 if T_0 else []
59 |         self.lr_mult_list = lr_mult if lr_mult else []
60 |         self.warmup_begin = warmup_begin
61 |         self.warmup_factor = warmup_factor
62 |         self.eta_min = eta_min
63 | 
64 |         # add number at beginning
65 |         self.t_0_list.insert(0, 0)
66 |         self.lr_mult_list.insert(0, 1)
67 | 
68 |         # calculate T_i according to given T_0 list
69 |         self.t_0_expand = self.t_0_list.copy()
70 |         self.t_0_expand.append(final_epoch)
71 | 
72 |         if self.warmup_begin > self.t_0_expand[1]:
73 |             raise ValueError('The warmup_begin iteration is larger than the first T_i; ' \
74 |                 'use a smaller warmup_begin or a larger T_0[0]')
75 | 
76 |         self.t_i_list = [self.t_0_expand[i+1] - self.t_0_expand[i] - 1 for i in range(len(self.t_0_expand)-1)]
77 |         self.t_i_list[0] = self.t_i_list[0] - self.warmup_begin  # subtract warmup at beginning
78 | 
79 | 
80 |         # initial T_i, lr_mult and T_cur
81 |         self.t_i = self.t_i_list[0]
82 |         self.lr_mult = self.lr_mult_list[0]
83 |         self.t_cur = 0
84 | 
85 |         super().__init__(optimizer, last_epoch)
86 | 
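A worked sketch of the `T_i` bookkeeping above, using the same numbers as the `step` docstring example below (`final_epoch=150`, `T_0=[20, 30, 50]`, `warmup_begin=5`):

```python
final_epoch, T_0, warmup_begin = 150, [20, 30, 50], 5
t_0_list = [0] + T_0                   # restart boundaries: [0, 20, 30, 50]
t_0_expand = t_0_list + [final_epoch]  # [0, 20, 30, 50, 150]
t_i_list = [t_0_expand[i + 1] - t_0_expand[i] - 1 for i in range(len(t_0_expand) - 1)]
t_i_list[0] -= warmup_begin            # first section shortened by the warmup
print(t_i_list)                        # [14, 9, 19, 99]
```
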
87 |     def get_lr(self):
88 |         if not self._get_lr_called_within_step:
89 |             warnings.warn('To get the last learning rate computed by the scheduler, '
90 |                           'please use `get_last_lr()`.', UserWarning)
91 | 
92 |         if self.last_epoch <= self.warmup_begin:
93 |             if self.warmup_begin != 0:
94 |                 lr = [(base_lr - self.warmup_factor * base_lr) * (self.last_epoch / self.warmup_begin) \
95 |                     + self.warmup_factor * base_lr for base_lr in self.base_lrs]
96 |             else:
97 |                 lr = list(self.base_lrs)
98 |         else:
99 |             lr = [self.eta_min + (self.lr_mult * base_lr - self.eta_min) \
100 |                 * (1 + math.cos(math.pi * self.t_cur / self.t_i)) / 2 for base_lr in self.base_lrs]
101 | 
102 |         return lr
103 | 
104 |     def step(self, epoch=None):
105 |         """Step could be called after every batch update
106 | 
107 |         Example:
108 |             >>> T_0 = [20, 30, 50]
109 |             >>> lr_mult = [0.8, 0.6, 0.5]
110 |             >>> scheduler = MultiCosineAnnealingWarmupLR(optimizer, final_epoch=150, T_0=T_0, lr_mult=lr_mult, \
111 |             >>>     warmup_begin=5, warmup_factor=0.001, eta_min=1e-7)
112 |             >>> iters = len(dataloader)
113 |             >>> for epoch in range(20):
114 |             >>>     for i, sample in enumerate(dataloader):
115 |             >>>         inputs, labels = sample['inputs'], sample['labels']
116 |             >>>         optimizer.zero_grad()
117 |             >>>         outputs = net(inputs)
118 |             >>>         loss = criterion(outputs, labels)
119 |             >>>         loss.backward()
120 |             >>>         optimizer.step()
121 |             >>>         scheduler.step()
122 |         """
123 | 
124 |         def locate(self, epoch: int):
125 |             """Return the index of the section of t_0_list in which ``epoch`` falls.
126 |             """
127 |             for i in range(1, len(self.t_0_list)):
128 |                 if epoch == 0:
129 |                     return 0
130 |                 elif self.t_0_list[i-1] <= epoch < self.t_0_list[i]:
131 |                     return i-1
132 |                 else:
133 |                     continue
134 |             return len(self.t_0_list) - 1
135 | 
136 |         if epoch is None and self.last_epoch < 0:
137 |             epoch = 0
138 | 
139 |         if epoch is None:
140 |             epoch = self.last_epoch + 1
141 |             section = locate(self, epoch)  # the section of the t_0_list
142 |             self.t_0 = self.t_0_list[section]
143 |             self.t_i = self.t_i_list[section]
144 |             self.t_cur = epoch - self.t_0
145 |             if epoch < self.t_0_expand[1]:
146 |                 self.t_cur = self.t_cur - self.warmup_begin  # subtract warmup_begin at first section
147 |             self.lr_mult = self.lr_mult_list[section]
148 |         else:
149 |             if epoch < 0:
150 |                 raise ValueError('Expected non-negative epoch, but got {}'.format(epoch))
151 |             section = locate(self, epoch)
152 |             self.t_0 = self.t_0_list[section]
153 |             self.t_i = self.t_i_list[section]
154 |             if epoch == 0:
155 |                 self.t_cur = 0
156 |             else:
157 |                 self.t_cur = epoch - self.t_0 - 1
158 |                 if epoch < self.t_0_expand[1]:
159 |                     self.t_cur = self.t_cur - self.warmup_begin
160 |             self.lr_mult = self.lr_mult_list[section]
161 | 
162 |         self.last_epoch = math.floor(epoch)
163 | 
164 |         class _EnableGetLRCall:
165 |             """_EnableGetLRCall
166 |             """
167 | 
168 |             def __init__(self, o):
169 |                 self.o = o
170 | 
171 |             def __enter__(self):
172 |                 self.o._get_lr_called_within_step = True
173 |                 return self
174 | 
175 |             def __exit__(self, type_, value, traceback):
176 |                 self.o._get_lr_called_within_step = False
177 |                 return False
178 | 
179 |         with _EnableGetLRCall(self):
180 |             for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
181 |                 param_group['lr'] = lr
182 | 
183 |         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
--------------------------------------------------------------------------------
/easytorch/easyoptim/lamb.py:
--------------------------------------------------------------------------------
1 | """ PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
2 | This optimizer code was adapted from the following (starting with latest)
3 | * https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
4 | *
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py 5 | * https://github.com/cybertronai/pytorch-lamb 6 | Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is 7 | similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. 8 | In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. 9 | Original copyrights for above sources are below. 10 | Modifications Copyright 2021 Ross Wightman 11 | """ 12 | # Copyright (c) 2021, Habana Labs Ltd. All rights reserved. 13 | 14 | # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. 15 | # 16 | # Licensed under the Apache License, Version 2.0 (the "License"); 17 | # you may not use this file except in compliance with the License. 18 | # You may obtain a copy of the License at 19 | # 20 | # http://www.apache.org/licenses/LICENSE-2.0 21 | # 22 | # Unless required by applicable law or agreed to in writing, software 23 | # distributed under the License is distributed on an "AS IS" BASIS, 24 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 | # See the License for the specific language governing permissions and 26 | # limitations under the License. 27 | 28 | # MIT License 29 | # 30 | # Copyright (c) 2019 cybertronai 31 | # 32 | # Permission is hereby granted, free of charge, to any person obtaining a copy 33 | # of this software and associated documentation files (the "Software"), to deal 34 | # in the Software without restriction, including without limitation the rights 35 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | # copies of the Software, and to permit persons to whom the Software is 37 | # furnished to do so, subject to the following conditions: 38 | # 39 | # The above copyright notice and this permission notice shall be included in all 40 | # copies or substantial portions of the Software. 41 | # 42 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 45 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | # SOFTWARE. 49 | import math 50 | 51 | import torch 52 | from torch.optim import Optimizer 53 | 54 | 55 | class Lamb(Optimizer): 56 | """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB 57 | reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py 58 | LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 59 | Arguments: 60 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups. 61 | lr (float, optional): learning rate. (default: 1e-3) 62 | betas (Tuple[float, float], optional): coefficients used for computing 63 | running averages of gradient and its norm. (default: (0.9, 0.999)) 64 | eps (float, optional): term added to the denominator to improve 65 | numerical stability. 
(default: 1e-6)
66 |         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
67 |         grad_averaging (bool, optional): whether apply (1-beta2) to grad when
68 |             calculating running averages of gradient. (default: True)
69 |         max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
70 |         trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
71 |         always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
72 |             weight decay parameter (default: False)
73 |     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
74 |         https://arxiv.org/abs/1904.00962
75 |     .. _On the Convergence of Adam and Beyond:
76 |         https://openreview.net/forum?id=ryQu7f-RZ
77 |     """
78 | 
79 |     def __init__(
80 |             self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
81 |             weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
82 |         defaults = dict(
83 |             lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
84 |             grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
85 |             trust_clip=trust_clip, always_adapt=always_adapt)
86 |         super().__init__(params, defaults)
87 | 
88 |     @torch.no_grad()
89 |     def step(self, closure=None):
90 |         """Performs a single optimization step.
91 |         Arguments:
92 |             closure (callable, optional): A closure that reevaluates the model
93 |                 and returns the loss.
94 |         """
95 |         loss = None
96 |         if closure is not None:
97 |             with torch.enable_grad():
98 |                 loss = closure()
99 | 
100 |         device = self.param_groups[0]['params'][0].device
101 |         one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
102 |         global_grad_norm = torch.zeros(1, device=device)
103 |         for group in self.param_groups:
104 |             for p in group['params']:
105 |                 if p.grad is None:
106 |                     continue
107 |                 grad = p.grad
108 |                 if grad.is_sparse:
109 |                     raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
110 |                 global_grad_norm.add_(grad.pow(2).sum())
111 | 
112 |         global_grad_norm = torch.sqrt(global_grad_norm)
113 |         # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
114 |         # scalar types properly https://github.com/pytorch/pytorch/issues/9190
115 |         max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
116 |         clip_global_grad_norm = torch.where(
117 |             global_grad_norm > max_grad_norm,
118 |             global_grad_norm / max_grad_norm,
119 |             one_tensor)
120 | 
121 |         for group in self.param_groups:
122 |             bias_correction = 1 if group['bias_correction'] else 0
123 |             beta1, beta2 = group['betas']
124 |             grad_averaging = 1 if group['grad_averaging'] else 0
125 |             beta3 = 1 - beta1 if grad_averaging else 1.0
126 | 
127 |             # assume same step across group now to simplify things
128 |             # per parameter step can be easily supported by making it tensor, or pass list into kernel
129 |             if 'step' in group:
130 |                 group['step'] += 1
131 |             else:
132 |                 group['step'] = 1
133 | 
134 |             if bias_correction:
135 |                 bias_correction1 = 1 - beta1 ** group['step']
136 |                 bias_correction2 = 1 - beta2 ** group['step']
137 |             else:
138 |                 bias_correction1, bias_correction2 = 1.0, 1.0
139 | 
140 |             for p in group['params']:
141 |                 if p.grad is None:
142 |                     continue
143 |                 grad = p.grad.div_(clip_global_grad_norm)
144 |                 state = self.state[p]
145 | 
146 |                 # State initialization
147 |                 if len(state) == 0:
148 |                     # Exponential moving average of gradient values
149 |                     state['exp_avg'] = torch.zeros_like(p)
150 |
# Exponential moving average of squared gradient values 151 | state['exp_avg_sq'] = torch.zeros_like(p) 152 | 153 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 154 | 155 | # Decay the first and second moment running average coefficient 156 | exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t 157 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t 158 | 159 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 160 | update = (exp_avg / bias_correction1).div_(denom) 161 | 162 | weight_decay = group['weight_decay'] 163 | if weight_decay != 0: 164 | update.add_(p, alpha=weight_decay) 165 | 166 | if weight_decay != 0 or group['always_adapt']: 167 | # Layer-wise LR adaptation. By default, skip adaptation on parameters that are 168 | # excluded from weight decay, unless always_adapt == True, then always enabled. 169 | w_norm = p.norm(2.0) 170 | g_norm = update.norm(2.0) 171 | # FIXME nested where required since logical and/or not working in PT XLA 172 | trust_ratio = torch.where( 173 | w_norm > 0, 174 | torch.where(g_norm > 0, w_norm / g_norm, one_tensor), 175 | one_tensor, 176 | ) 177 | if group['trust_clip']: 178 | # LAMBC trust clipping, upper bound fixed at one 179 | trust_ratio = torch.minimum(trust_ratio, one_tensor) 180 | update.mul_(trust_ratio) 181 | 182 | p.add_(update, alpha=-group['lr']) 183 | 184 | return loss 185 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MASTER] 9 | 10 | # Files or directories to be skipped. 
They should be base names, not paths. 11 | ignore=third_party 12 | 13 | # Files or directories matching the regex patterns are skipped. The regex 14 | # matches against base names, not paths. 15 | ignore-patterns= 16 | 17 | # Pickle collected data for later comparisons. 18 | persistent=no 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | # Use multiple processes to speed up Pylint. 25 | jobs=4 26 | 27 | # Allow loading of arbitrary C extensions. Extensions are imported into the 28 | # active Python interpreter and may run arbitrary code. 29 | unsafe-load-any-extension=no 30 | 31 | 32 | [MESSAGES CONTROL] 33 | 34 | # Only show warnings with the listed confidence levels. Leave empty to show 35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 36 | confidence= 37 | 38 | # Enable the message, report, category or checker with the given id(s). You can 39 | # either give multiple identifier separated by comma (,) or put this option 40 | # multiple time (only on the command line, not in the configuration file where 41 | # it should appear only once). See also the "--disable" option for examples. 42 | #enable= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once).You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=abstract-method, 54 | apply-builtin, 55 | arguments-differ, 56 | attribute-defined-outside-init, 57 | backtick, 58 | bad-option-value, 59 | basestring-builtin, 60 | buffer-builtin, 61 | c-extension-no-member, 62 | consider-using-enumerate, 63 | consider-using-f-string, 64 | cmp-builtin, 65 | cmp-method, 66 | coerce-builtin, 67 | coerce-method, 68 | delslice-method, 69 | div-method, 70 | duplicate-code, 71 | eq-without-hash, 72 | execfile-builtin, 73 | file-builtin, 74 | filter-builtin-not-iterating, 75 | fixme, 76 | getslice-method, 77 | global-statement, 78 | hex-method, 79 | idiv-method, 80 | implicit-str-concat-in-sequence, 81 | import-error, 82 | import-self, 83 | import-star-module-level, 84 | inconsistent-return-statements, 85 | input-builtin, 86 | intern-builtin, 87 | invalid-str-codec, 88 | locally-disabled, 89 | logging-format-interpolation, 90 | logging-fstring-interpolation, 91 | long-builtin, 92 | long-suffix, 93 | map-builtin-not-iterating, 94 | misplaced-comparison-constant, 95 | missing-function-docstring, 96 | missing-module-docstring, 97 | metaclass-assignment, 98 | next-method-called, 99 | next-method-defined, 100 | no-absolute-import, 101 | no-else-break, 102 | no-else-continue, 103 | no-else-raise, 104 | no-else-return, 105 | no-init, # added 106 | no-member, 107 | no-name-in-module, 108 | no-self-use, 109 | nonzero-method, 110 | oct-method, 111 | old-division, 112 | old-ne-operator, 113 | old-octal-literal, 114 | old-raise-syntax, 115 | parameter-unpacking, 116 | print-statement, 117 | protected-access, 118 | raising-string, 119 | range-builtin-not-iterating, 120 
| raw_input-builtin, 121 | rdiv-method, 122 | reduce-builtin, 123 | relative-import, 124 | reload-builtin, 125 | round-builtin, 126 | setslice-method, 127 | signature-differs, 128 | standarderror-builtin, 129 | suppressed-message, 130 | sys-max-int, 131 | too-few-public-methods, 132 | too-many-ancestors, 133 | too-many-arguments, 134 | too-many-boolean-expressions, 135 | too-many-branches, 136 | too-many-instance-attributes, 137 | too-many-locals, 138 | too-many-nested-blocks, 139 | too-many-public-methods, 140 | too-many-return-statements, 141 | too-many-statements, 142 | trailing-newlines, 143 | unichr-builtin, 144 | unicode-builtin, 145 | unnecessary-pass, 146 | unpacking-in-except, 147 | unspecified-encoding, 148 | useless-else-on-loop, 149 | useless-object-inheritance, 150 | useless-suppression, 151 | using-cmp-argument, 152 | wrong-import-order, 153 | xrange-builtin, 154 | zip-builtin-not-iterating, 155 | 156 | 157 | [REPORTS] 158 | 159 | # Set the output format. Available formats are text, parseable, colorized, msvs 160 | # (visual studio) and html. You can also give a reporter class, eg 161 | # mypackage.mymodule.MyReporterClass. 162 | output-format=text 163 | 164 | # Tells whether to display a full report or only the messages 165 | reports=no 166 | 167 | # Python expression which should return a note less than 10 (10 is the highest 168 | # note). You have access to the variables errors warning, statement which 169 | # respectively contain the number of errors / warnings messages and the total 170 | # number of statements analyzed. This is used by the global evaluation report 171 | # (RP0004). 172 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 173 | 174 | # Template used to display messages. This is a python new-style format string 175 | # used to format the message information. See doc for all details 176 | #msg-template= 177 | 178 | 179 | [BASIC] 180 | 181 | # Good variable names which should always be accepted, separated by a comma 182 | good-names=main,_ 183 | 184 | # Bad variable names which should always be refused, separated by a comma 185 | bad-names= 186 | 187 | # Colon-delimited sets of names that determine each other's naming style when 188 | # the name regexes allow several styles. 189 | name-group= 190 | 191 | # Include a hint for the correct naming format with invalid-name 192 | include-naming-hint=no 193 | 194 | # List of decorators that produce properties, such as abc.abstractproperty. Add 195 | # to this list to register other decorators that produce valid properties. 
196 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
197 | 
198 | # Regular expression matching correct function names
199 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
200 | 
201 | # Regular expression matching correct variable names
202 | variable-rgx=^[a-z][a-z0-9_]*$
203 | 
204 | # Regular expression matching correct constant names
205 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
206 | 
207 | # Regular expression matching correct attribute names
208 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
209 | 
210 | # Regular expression matching correct argument names
211 | argument-rgx=^[a-z][a-z0-9_]*$
212 | 
213 | # Regular expression matching correct class attribute names
214 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
215 | 
216 | # Regular expression matching correct inline iteration names
217 | inlinevar-rgx=^[a-z][a-z0-9_]*$
218 | 
219 | # Regular expression matching correct class names
220 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$
221 | 
222 | # Regular expression matching correct module names
223 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
224 | 
225 | # Regular expression matching correct method names
226 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
227 | 
228 | # Regular expression which should only match function or class names that do
229 | # not require a docstring.
230 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
231 | 
232 | # Minimum line length for functions/classes that require docstrings, shorter
233 | # ones are exempt.
234 | docstring-min-length=10
235 | 
236 | 
237 | [TYPECHECK]
238 | 
239 | # List of decorators that produce context managers, such as
240 | # contextlib.contextmanager. Add to this list to register other decorators that
241 | # produce valid context managers.
242 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
243 | 
244 | # Tells whether missing members accessed in mixin class should be ignored. A
245 | # mixin class is detected if its name ends with "mixin" (case insensitive).
246 | ignore-mixin-members=yes
247 | 
248 | # List of module names for which member attributes should not be checked
249 | # (useful for modules/projects where namespaces are manipulated during runtime
250 | # and thus existing member attributes cannot be deduced by static analysis. It
251 | # supports qualified module names, as well as Unix pattern matching.
252 | ignored-modules=
253 | 
254 | # List of class names for which member attributes should not be checked (useful
255 | # for classes with dynamically set attributes). This supports the use of
256 | # qualified names.
257 | ignored-classes=optparse.Values,thread._local,_thread._local
258 | 
259 | # List of members which are set dynamically and missed by pylint inference
260 | # system, and so shouldn't trigger E1101 when accessed. Python regular
261 | # expressions are accepted.
262 | generated-members=
263 | 
264 | 
265 | [FORMAT]
266 | 
267 | # Maximum number of characters on a single line.
268 | max-line-length=120
269 | 
270 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
271 | # lines made too long by directives to pytype.
272 | 
273 | # Regexp for a line that is allowed to be longer than the limit.
274 | ignore-long-lines=(?x)(
275 |   ^\s*(\#\ )?<?https?://\S+>?$|
276 |   ^\s*(from\s+\S+\s+)?import\s+.+$)
277 | 
278 | # Allow the body of an if to be on the same line as the test if there is no
279 | # else.
280 | single-line-if-stmt=yes
281 | 
282 | # Maximum number of lines in a module
283 | max-module-lines=99999
284 | 
285 | # String used as indentation unit. The internal Google style guide mandates 2
286 | # spaces. Google's externally-published style guide says 4, consistent with
287 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
288 | # projects (like TensorFlow).
289 | indent-string=' '
290 | 
291 | # Number of spaces of indent required inside a hanging or continued line.
292 | indent-after-paren=4
293 | 
294 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
295 | expected-line-ending-format=
296 | 
297 | 
298 | [MISCELLANEOUS]
299 | 
300 | # List of note tags to take in consideration, separated by a comma.
301 | notes=TODO
302 | 
303 | 
304 | [STRING]
305 | 
306 | # This flag controls whether inconsistent-quotes generates a warning when the
307 | # character used as a quote delimiter is used inconsistently within a module.
308 | check-quote-consistency=yes
309 | 
310 | 
311 | [VARIABLES]
312 | 
313 | # Tells whether we should check for unused import in __init__ files.
314 | init-import=no
315 | 
316 | # A regular expression matching the name of dummy variables (i.e. expectedly
317 | # not used).
318 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
319 | 
320 | # List of additional names supposed to be defined in builtins. Remember that
321 | # you should avoid to define new builtins when possible.
322 | additional-builtins=
323 | 
324 | # List of strings which can identify a callback function by name. A callback
325 | # name must start or end with one of those strings.
326 | callbacks=cb_,_cb
327 | 
328 | # List of qualified module names which can have objects that can redefine
329 | # builtins.
330 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
331 | 
332 | 
333 | [LOGGING]
334 | 
335 | # Logging modules to check that the string format arguments are in logging
336 | # function parameter format
337 | logging-modules=logging,absl.logging,tensorflow.io.logging
338 | 
339 | 
340 | [SIMILARITIES]
341 | 
342 | # Minimum lines number of a similarity.
343 | min-similarity-lines=4
344 | 
345 | # Ignore comments when computing similarities.
346 | ignore-comments=yes
347 | 
348 | # Ignore docstrings when computing similarities.
349 | ignore-docstrings=yes
350 | 
351 | # Ignore imports when computing similarities.
352 | ignore-imports=no
353 | 
354 | 
355 | [SPELLING]
356 | 
357 | # Spelling dictionary name. Available dictionaries: none. To make it working
358 | # install python-enchant package.
359 | spelling-dict=
360 | 
361 | # List of comma separated words that should not be checked.
362 | spelling-ignore-words=
363 | 
364 | # A path to a file that contains private dictionary; one word per line.
365 | spelling-private-dict-file=
366 | 
367 | # Tells whether to store unknown words to indicated private dictionary in
368 | # --spelling-private-dict-file option instead of raising a message.
369 | spelling-store-unknown-words=no
370 | 
371 | 
372 | [IMPORTS]
373 | 
374 | # Deprecated modules which should not be used, separated by a comma
375 | deprecated-modules=regsub,
376 |                    TERMIOS,
377 |                    Bastion,
378 |                    rexec,
379 |                    sets
380 | 
381 | # Create a graph of every (i.e.
internal and external) dependencies in the 382 | # given file (report RP0402 must not be disabled) 383 | import-graph= 384 | 385 | # Create a graph of external dependencies in the given file (report RP0402 must 386 | # not be disabled) 387 | ext-import-graph= 388 | 389 | # Create a graph of internal dependencies in the given file (report RP0402 must 390 | # not be disabled) 391 | int-import-graph= 392 | 393 | # Force import order to recognize a module as part of the standard 394 | # compatibility libraries. 395 | known-standard-library= 396 | 397 | # Force import order to recognize a module as part of a third party library. 398 | known-third-party=enchant, absl 399 | 400 | # Analyse import fallback blocks. This can be used to support both Python 2 and 401 | # 3 compatible code, which means that the block might have code that exists 402 | # only in one or another interpreter, leading to false positives when analysed. 403 | analyse-fallback-blocks=no 404 | 405 | 406 | [CLASSES] 407 | 408 | # List of method names used to declare (i.e. assign) instance attributes. 409 | defining-attr-methods=__init__, 410 | __new__, 411 | setUp 412 | 413 | # List of member names, which should be excluded from the protected access 414 | # warning. 415 | exclude-protected=_asdict, 416 | _fields, 417 | _replace, 418 | _source, 419 | _make 420 | 421 | # List of valid names for the first argument in a class method. 422 | valid-classmethod-first-arg=cls, 423 | class_ 424 | 425 | # List of valid names for the first argument in a metaclass class method. 426 | valid-metaclass-classmethod-first-arg=mcs 427 | 428 | 429 | [EXCEPTIONS] 430 | 431 | # Exceptions that will emit a warning when being caught. Defaults to 432 | # "Exception" 433 | overgeneral-exceptions=StandardError, 434 | Exception, 435 | BaseException 436 | -------------------------------------------------------------------------------- /easytorch/core/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | from abc import ABCMeta, abstractmethod 5 | from typing import Tuple, Union, Optional 6 | 7 | from tqdm import tqdm 8 | import torch 9 | from torch import nn 10 | from torch.utils.data import Dataset, DataLoader 11 | from torch.utils.data.distributed import DistributedSampler 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from .meter_pool import MeterPool 16 | from .checkpoint import load_ckpt, save_ckpt, backup_last_ckpt, clear_ckpt 17 | from .data_loader import build_data_loader, build_data_loader_ddp 18 | from .optimizer_builder import build_optim, build_lr_scheduler 19 | from ..config import Config, get_ckpt_save_dir 20 | from ..utils import TimePredictor, get_logger, get_local_rank, is_master, master_only, set_env 21 | from ..utils.data_prefetcher import DevicePrefetcher 22 | from ..device import to_device 23 | 24 | 25 | class Runner(metaclass=ABCMeta): 26 | """Base EasyTorch Runner 27 | """ 28 | 29 | def __init__(self, cfg: Config): 30 | # default logger 31 | self.logger = get_logger('easytorch') 32 | 33 | # set env 34 | set_env(cfg.get('ENV', {})) 35 | 36 | # param 37 | self.model_name = cfg['MODEL.NAME'] 38 | self.ckpt_save_dir = get_ckpt_save_dir(cfg) 39 | self.logger.info('Set ckpt save dir: \'{}\''.format(self.ckpt_save_dir)) 40 | self.ckpt_save_strategy = None 41 | self.num_epochs = None 42 | self.start_epoch = None 43 | 44 | self.val_interval = 1 45 | 46 | # create 
checkpoint save dir
47 |         if not os.path.isdir(self.ckpt_save_dir):
48 |             os.makedirs(self.ckpt_save_dir)
49 | 
50 |         # create model
51 |         self.model = self.build_model(cfg)
52 | 
53 |         # declare optimizer and lr_scheduler
54 |         self.optim = None
55 |         self.scheduler = None
56 | 
57 |         # declare data loader
58 |         self.train_data_loader = None
59 |         self.val_data_loader = None
60 | 
61 |         # declare meter pool
62 |         self.meter_pool = None
63 | 
64 |         # declare tensorboard_writer
65 |         self.tensorboard_writer = None
66 | 
67 |     def init_logger(self, logger: logging.Logger = None, logger_name: str = None,
68 |                     log_file_name: str = None, log_level: int = logging.INFO):
69 |         """Initialize logger.
70 | 
71 |         Args:
72 |             logger (logging.Logger, optional): specified logger.
73 |             logger_name (str, optional): specified name of logger.
74 |             log_file_name (str, optional): logger file name.
75 |             log_level (int, optional): log level, default is INFO.
76 |         """
77 | 
78 |         if logger is not None:
79 |             self.logger = logger
80 |         elif logger_name is not None:
81 |             if log_file_name is not None:
82 |                 log_file_name = '{}_{}.log'.format(log_file_name, time.strftime('%Y%m%d%H%M%S', time.localtime()))
83 |                 log_file_path = os.path.join(self.ckpt_save_dir, log_file_name)
84 |             else:
85 |                 log_file_path = None
86 |             self.logger = get_logger(logger_name, log_file_path, log_level)
87 |         else:
88 |             raise TypeError('At least one of logger and logger_name must not be None')
89 | 
90 |     @staticmethod
91 |     @abstractmethod
92 |     def define_model(cfg: Config) -> nn.Module:
93 |         """It must be implemented to define the model for training or inference.
94 | 
95 |         Users can select different models by param in cfg.
96 | 
97 |         Args:
98 |             cfg (Dict): config
99 | 
100 |         Returns:
101 |             model (nn.Module)
102 |         """
103 | 
104 |         pass
105 | 
106 |     @staticmethod
107 |     @abstractmethod
108 |     def build_train_dataset(cfg: Config) -> Dataset:
109 |         """It must be implemented to build the dataset for training.
110 | 
111 |         Args:
112 |             cfg (Dict): config
113 | 
114 |         Returns:
115 |             train dataset (Dataset)
116 |         """
117 | 
118 |         pass
119 | 
120 |     @staticmethod
121 |     def build_val_dataset(cfg: Config):
122 |         """It can be implemented to build the dataset for validation (optional).
123 | 
124 |         Args:
125 |             cfg (Dict): config
126 | 
127 |         Returns:
128 |             val dataset (Dataset)
129 |         """
130 | 
131 |         raise NotImplementedError()
132 | 
133 |     def build_train_data_loader(self, cfg: Config) -> DataLoader:
134 |         """Build train dataset and dataloader.
135 |         Build dataset by calling ```self.build_train_dataset```,
136 |         build dataloader by calling ```build_data_loader``` or
137 |         ```build_data_loader_ddp``` when DDP is initialized
138 | 
139 |         Args:
140 |             cfg (Dict): config
141 | 
142 |         Returns:
143 |             train data loader (DataLoader)
144 |         """
145 | 
146 |         self.logger.info('Building training data loader.')
147 |         dataset = self.build_train_dataset(cfg)
148 |         if torch.distributed.is_initialized():
149 |             return build_data_loader_ddp(dataset, cfg['TRAIN.DATA'])
150 |         else:
151 |             return build_data_loader(dataset, cfg['TRAIN.DATA'])
152 | 
153 |     def build_val_data_loader(self, cfg: Config) -> DataLoader:
154 |         """Build val dataset and dataloader.
155 |         Build dataset by calling ```self.build_val_dataset```,
156 |         build dataloader by calling ```build_data_loader```.
157 | 
158 |         Args:
159 |             cfg (Dict): config
160 | 
161 |         Returns:
162 |             val data loader (DataLoader)
163 |         """
164 | 
165 |         self.logger.info('Building val data loader.')
166 |         dataset = self.build_val_dataset(cfg)
167 |         return build_data_loader(dataset, cfg['VAL.DATA'])
168 | 
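For reference, a sketch of the config fields this class reads (the key names follow the `cfg[...]` accesses in this file; the empty `OPTIM`/`DATA` sub-dicts are placeholders whose exact schema belongs to `build_optim` and `build_data_loader`):

```python
CFG = Config({
    'MODEL': {'NAME': 'demo'},
    'TRAIN': {
        'NUM_EPOCHS': 10,
        'CKPT_SAVE_DIR': 'checkpoints',
        'OPTIM': {},  # consumed by build_optim
        'DATA': {},   # consumed by build_data_loader / build_data_loader_ddp
    },
})
```
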
169 |     def build_model(self, cfg: Config) -> nn.Module:
170 |         """Build model.
171 | 
172 |         Initialize model by calling ```self.define_model```,
173 |         Moves model to the GPU.
174 | 
175 |         If DDP is initialized, initialize the DDP wrapper.
176 | 
177 |         Args:
178 |             cfg (Dict): config
179 | 
180 |         Returns:
181 |             model (nn.Module)
182 |         """
183 | 
184 |         self.logger.info('Building model.')
185 |         model = self.define_model(cfg)
186 |         model = to_device(model)
187 |         if torch.distributed.is_initialized():
188 |             model = DDP(
189 |                 model,
190 |                 device_ids=[get_local_rank()],
191 |                 find_unused_parameters=cfg.get('MODEL.DDP_FIND_UNUSED_PARAMETERS', False)
192 |             )
193 |         return model
194 | 
195 |     def get_ckpt_path(self, epoch: int) -> str:
196 |         """Get checkpoint path.
197 | 
198 |         The format is "{ckpt_save_dir}/{model_name}_{epoch}.pt"
199 | 
200 |         Args:
201 |             epoch (int): current epoch.
202 | 
203 |         Returns:
204 |             checkpoint path (str)
205 |         """
206 | 
207 |         epoch_str = str(epoch).zfill(len(str(self.num_epochs)))
208 |         ckpt_name = '{}_{}.pt'.format(self.model_name, epoch_str)
209 |         return os.path.join(self.ckpt_save_dir, ckpt_name)
210 | 
211 |     @master_only
212 |     def save_model(self, epoch: int):
213 |         """Save checkpoint every epoch.
214 | 
215 |         checkpoint format is {
216 |             'epoch': current epoch ([1, num_epochs]),
217 |             'model_state_dict': state_dict of model,
218 |             'optim_state_dict': state_dict of optimizer, 'best_metrics': best metrics so far
219 |         }
220 | 
221 |         Decide whether to delete the last checkpoint by the checkpoint save strategy.
222 | 
223 |         Args:
224 |             epoch (int): current epoch.
225 |         """
226 | 
227 |         model = self.model.module if isinstance(self.model, DDP) else self.model
228 |         ckpt_dict = {
229 |             'epoch': epoch,
230 |             'model_state_dict': model.state_dict(),
231 |             'optim_state_dict': self.optim.state_dict(),
232 |             'best_metrics': self.best_metrics
233 |         }
234 | 
235 |         # backup last epoch
236 |         last_ckpt_path = self.get_ckpt_path(epoch - 1)
237 |         backup_last_ckpt(last_ckpt_path, epoch, self.ckpt_save_strategy)
238 | 
239 |         # save ckpt
240 |         ckpt_path = self.get_ckpt_path(epoch)
241 |         save_ckpt(ckpt_dict, ckpt_path, self.logger)
242 | 
243 |         # clear ckpt every 10 epochs or at the end
244 |         if epoch % 10 == 0 or epoch == self.num_epochs:
245 |             clear_ckpt(self.ckpt_save_dir)
246 | 
247 |     def load_model_resume(self, strict: bool = True):
248 |         """Load last checkpoint in checkpoint save dir to resume training.
249 | 
250 |         Load model state dict.
251 |         Load optimizer state dict.
252 |         Load start epoch and set it to lr_scheduler.
253 | 
254 |         Args:
255 |             strict (bool, optional): whether to strictly enforce that the keys
256 |                 in :attr:`state_dict` match the keys returned by this module's
257 |                 :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
258 |         """
259 | 
260 |         try:
261 |             checkpoint_dict = load_ckpt(self.ckpt_save_dir, logger=self.logger)
262 |             if isinstance(self.model, DDP):
263 |                 self.model.module.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
264 |             else:
265 |                 self.model.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
266 |             self.optim.load_state_dict(checkpoint_dict['optim_state_dict'])
267 |             self.start_epoch = checkpoint_dict['epoch']
268 |             if checkpoint_dict.get('best_metrics') is not None:
269 |                 self.best_metrics = checkpoint_dict['best_metrics']
270 |             if self.scheduler is not None:
271 |                 self.scheduler.last_epoch = checkpoint_dict['epoch']
272 |             self.logger.info('Resume training')
273 |         except (IndexError, OSError, KeyError):
274 |             pass
275 | 
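A sketch of the resulting checkpoint naming (epoch numbers are zero-padded to the width of `num_epochs`; `runner` and the names are illustrative):

```python
# With model_name='demo' and num_epochs=100:
runner.get_ckpt_path(7)  # '<ckpt_save_dir>/demo_007.pt'
```
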
    def load_model_resume(self, strict: bool = True):
        """Load the last checkpoint in the checkpoint save dir to resume training.

        Loads the model state dict, loads the optimizer state dict, and
        loads the start epoch, which is also set on the lr_scheduler.

        Args:
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
        """

        try:
            checkpoint_dict = load_ckpt(self.ckpt_save_dir, logger=self.logger)
            if isinstance(self.model, DDP):
                self.model.module.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
            else:
                self.model.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
            self.optim.load_state_dict(checkpoint_dict['optim_state_dict'])
            self.start_epoch = checkpoint_dict['epoch']
            if checkpoint_dict.get('best_metrics') is not None:
                self.best_metrics = checkpoint_dict['best_metrics']
            if self.scheduler is not None:
                self.scheduler.last_epoch = checkpoint_dict['epoch']
            self.logger.info('Resume training')
        except (IndexError, OSError, KeyError):
            # no usable checkpoint found; start training from scratch
            pass

    def load_model(self, ckpt_path: str = None, strict: bool = True):
        """Load a model state dict.

        If ``ckpt_path`` is None, load the last checkpoint in
        ``self.ckpt_save_dir``; otherwise load the checkpoint from ``ckpt_path``.

        Args:
            ckpt_path (str, optional): checkpoint path. Default: None
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
        """

        try:
            checkpoint_dict = load_ckpt(self.ckpt_save_dir, ckpt_path=ckpt_path, logger=self.logger)
            if isinstance(self.model, DDP):
                self.model.module.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
            else:
                self.model.load_state_dict(checkpoint_dict['model_state_dict'], strict=strict)
        except (IndexError, OSError) as e:
            raise OSError('Ckpt file does not exist') from e
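    # Example (illustrative): loading trained weights for evaluation, e.g. in
    # a standalone validation script (the checkpoint path is hypothetical):
    #
    #     runner.load_model(ckpt_path='checkpoints/resnet50_100.pt', strict=True)
    #     runner.model.eval()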
    def train(self, cfg: Config):
        """Train the model.

        Training process:
            [init_training]
            for each epoch:
                [on_epoch_start]
                for each training iter:
                    [train_iters]
                [on_epoch_end] ------> epoch validation: validate every n epochs
                                           [on_validating_start]
                                           for each val iter:
                                               val iter
                                           [on_validating_end]
            [on_training_end]

        Args:
            cfg (Config): config
        """

        self.init_training(cfg)

        # train time predictor
        train_time_predictor = TimePredictor(self.start_epoch, self.num_epochs)

        # training loop
        for epoch_index in range(self.start_epoch, self.num_epochs):
            epoch = epoch_index + 1
            self.on_epoch_start(epoch)
            epoch_start_time = time.time()
            # start training
            self.model.train()

            # optional device prefetching and tqdm progress bar
            if cfg.get('TRAIN.DATA.DEVICE_PREFETCH', False):
                data_loader = DevicePrefetcher(self.train_data_loader)
            else:
                data_loader = self.train_data_loader
            data_loader = tqdm(data_loader) if get_local_rank() == 0 else data_loader

            # data loop
            for iter_index, data in enumerate(data_loader):
                loss = self.train_iters(epoch, iter_index, data)
                if loss is not None:
                    self.backward(loss)
            # update lr_scheduler
            if self.scheduler is not None:
                self.scheduler.step()

            epoch_end_time = time.time()
            # epoch time
            self.update_epoch_meter('train_time', epoch_end_time - epoch_start_time)
            self.on_epoch_end(epoch)

            expected_end_time = train_time_predictor.get_expected_end_time(epoch)

            # estimate training finish time
            if epoch < self.num_epochs:
                self.logger.info('The estimated training finish time is {}'.format(
                    time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(expected_end_time))))

        # log training finish time
        self.logger.info('The training finished at {}'.format(
            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
        ))

        self.on_training_end()

    def init_lr_scheduler(self, cfg: Config):
        """Initialize the lr_scheduler.

        Args:
            cfg (Config): config
        """

        # create lr_scheduler
        if cfg.has('TRAIN.LR_SCHEDULER'):
            self.scheduler = build_lr_scheduler(cfg['TRAIN.LR_SCHEDULER'], self.optim)
            self.logger.info('Set lr_scheduler: {}'.format(self.scheduler))
            self.register_epoch_meter('lr', 'train', '{:.2e}')
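    # Example (illustrative, unverified keys): a config fragment that would
    # drive ``build_optim`` and ``build_lr_scheduler`` in ``init_training``
    # below, assuming a TYPE/PARAM convention that maps onto the corresponding
    # ``torch.optim`` classes; all values are hypothetical:
    #
    #     CFG.TRAIN.OPTIM.TYPE = 'SGD'
    #     CFG.TRAIN.OPTIM.PARAM = {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 1e-4}
    #     CFG.TRAIN.LR_SCHEDULER.TYPE = 'MultiStepLR'
    #     CFG.TRAIN.LR_SCHEDULER.PARAM = {'milestones': [30, 60], 'gamma': 0.1}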
    def init_training(self, cfg: Config):
        """Initialize training.

        Args:
            cfg (Config): config
        """

        self.logger.info('Initializing training.')

        # init training params
        self.num_epochs = cfg['TRAIN.NUM_EPOCHS']
        self.start_epoch = 0
        self.ckpt_save_strategy = cfg.get('TRAIN.CKPT_SAVE_STRATEGY')
        self.best_metrics = {}
        self.clip_grad_param = cfg.get('TRAIN.CLIP_GRAD_PARAM')
        if self.clip_grad_param is not None:
            self.logger.info('Set clip grad, param: {}'.format(self.clip_grad_param))

        # train data loader
        self.train_data_loader = self.build_train_data_loader(cfg)
        self.register_epoch_meter('train_time', 'train', '{:.2f} (s)', plt=False)

        # create optim
        self.optim = build_optim(cfg['TRAIN.OPTIM'], self.model)
        self.logger.info('Set optim: {}'.format(self.optim))

        # create lr_scheduler
        self.init_lr_scheduler(cfg)

        # fine-tune
        if cfg.has('TRAIN.FINETUNE_FROM'):
            self.load_model(cfg['TRAIN.FINETUNE_FROM'], cfg.get('TRAIN.FINETUNE_STRICT_LOAD', True))
            self.logger.info('Start fine tuning')

        # resume
        self.load_model_resume()

        # init tensorboard (after resume)
        if is_master():
            self.tensorboard_writer = SummaryWriter(
                os.path.join(self.ckpt_save_dir, 'tensorboard'),
                purge_step=(self.start_epoch + 1) if self.start_epoch != 0 else None
            )

        # init validation
        if cfg.has('VAL'):
            self.init_validation(cfg)

    def on_epoch_start(self, epoch: int):
        """Callback at the start of an epoch.

        Args:
            epoch (int): current epoch
        """

        # print epoch number
        self.logger.info('Epoch {:d} / {:d}'.format(epoch, self.num_epochs))
        # update lr meter
        if self.scheduler is not None:
            self.update_epoch_meter('lr', self.scheduler.get_last_lr()[0])

        # set epoch for the sampler in distributed mode
        # see https://pytorch.org/docs/stable/data.html
        sampler = self.train_data_loader.sampler
        if torch.distributed.is_initialized() and isinstance(sampler, DistributedSampler) and sampler.shuffle:
            sampler.set_epoch(epoch)

    def on_epoch_end(self, epoch: int):
        """Callback at the end of an epoch.

        Args:
            epoch (int): current epoch.
        """

        # print train meters
        self.print_epoch_meters('train')
        # plot meters to tensorboard
        self.plt_epoch_meters('train', epoch)
        # validate
        if self.val_data_loader is not None and epoch % self.val_interval == 0:
            self.validate(train_epoch=epoch)
        # save model
        self.save_model(epoch)
        # reset meters
        self.reset_epoch_meters()

    def on_training_end(self):
        """Callback at the end of training."""

        if is_master():
            # close the tensorboard writer
            self.tensorboard_writer.close()

    @abstractmethod
    def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:
        """Must be implemented to define the training details.

        If it returns a ``loss``, ``self.backward`` will be called with it.

        Args:
            epoch (int): current epoch.
            iter_index (int): current iter.
            data (torch.Tensor or tuple): data provided by the DataLoader.

        Returns:
            loss (torch.Tensor)
        """

        pass

    def backward(self, loss: torch.Tensor):
        """Backward and update params.

        Args:
            loss (torch.Tensor): loss
        """

        self.optim.zero_grad()
        loss.backward()
        if self.clip_grad_param is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), **self.clip_grad_param)
        self.optim.step()
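    # Example (illustrative): a minimal ``train_iters`` a subclass might
    # define, assuming ``data`` is an ``(input, target)`` pair, the subclass
    # set up ``self.loss_fn``, and a 'train/loss' meter was registered (all
    # three are assumptions, not part of this class):
    #
    #     def train_iters(self, epoch, iter_index, data):
    #         x, y = to_device(data[0]), to_device(data[1])
    #         output = self.model(x)
    #         loss = self.loss_fn(output, y)
    #         self.update_epoch_meter('train/loss', loss.item())
    #         return loss  # returning the loss triggers ``self.backward``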
    @torch.no_grad()
    @master_only
    def validate(self, cfg: Config = None, train_epoch: Optional[int] = None):
        """Validate the model.

        Args:
            cfg (Config, optional): config; required when called outside the training process.
            train_epoch (int, optional): current epoch if called during training.
        """

        # init validation if not in the training process
        if train_epoch is None:
            self.init_validation(cfg)

        self.logger.info('Start validation.')

        self.on_validating_start(train_epoch)

        val_start_time = time.time()
        self.model.eval()

        # tqdm progress bar
        data_iter = tqdm(self.val_data_loader)

        # val loop
        for iter_index, data in enumerate(data_iter):
            self.val_iters(iter_index, data)

        val_end_time = time.time()
        self.update_epoch_meter('val_time', val_end_time - val_start_time)
        # print val meters
        self.print_epoch_meters('val')
        if train_epoch is not None:
            # plot meters to tensorboard
            self.plt_epoch_meters('val', train_epoch // self.val_interval)

        self.on_validating_end(train_epoch)

    @master_only
    def init_validation(self, cfg: Config):
        """Initialize validation.

        Args:
            cfg (Config): config
        """

        self.logger.info('Initializing validation.')
        self.val_interval = cfg.get('VAL.INTERVAL', 1)
        self.val_data_loader = self.build_val_data_loader(cfg)
        self.register_epoch_meter('val_time', 'val', '{:.2f} (s)', plt=False)

    @master_only
    def on_validating_start(self, train_epoch: Optional[int]):
        """Callback at the start of validation.

        Args:
            train_epoch (Optional[int]): current epoch if in the training process.
        """

        pass

    @master_only
    def on_validating_end(self, train_epoch: Optional[int]):
        """Callback at the end of validation.

        Args:
            train_epoch (Optional[int]): current epoch if in the training process.
        """

        pass

    def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]):
        """May be implemented to define the validation details (optional).

        Args:
            iter_index (int): current iter.
            data (Union[torch.Tensor, Tuple]): data provided by the DataLoader.
        """

        raise NotImplementedError()
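    # Example (illustrative): a matching ``val_iters`` plus best-model saving,
    # assuming an (input, target) data pair and a 'val/acc' meter registered
    # by the subclass (both are assumptions):
    #
    #     def val_iters(self, iter_index, data):
    #         x, y = to_device(data[0]), to_device(data[1])
    #         acc = (self.model(x).argmax(dim=1) == y).float().mean()
    #         self.update_epoch_meter('val/acc', acc.item())
    #
    #     def on_validating_end(self, train_epoch):
    #         if train_epoch is not None:
    #             self.save_best_model(train_epoch, 'val/acc', greater_best=True)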
    @master_only
    def save_best_model(self, epoch: int, metric_name: str, greater_best: bool = True):
        """Save the best model while training.

        Examples:
            >>> def on_validating_end(self, train_epoch: Optional[int]):
            >>>     if train_epoch is not None:
            >>>         self.save_best_model(train_epoch, 'val/loss', greater_best=False)

        Args:
            epoch (int): current epoch.
            metric_name (str): metric name used to measure the model; must be registered in the epoch meters.
            greater_best (bool, optional): ``True`` means a greater value is better, e.g. ``acc``;
                ``False`` means a lower value is better, e.g. ``loss``. Defaults to True.
        """

        metric = self.meter_pool.get_avg(metric_name)
        best_metric = self.best_metrics.get(metric_name)
        if best_metric is None or (metric > best_metric if greater_best else metric < best_metric):
            self.best_metrics[metric_name] = metric
            model = self.model.module if isinstance(self.model, DDP) else self.model
            ckpt_dict = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optim_state_dict': self.optim.state_dict(),
                'best_metrics': self.best_metrics
            }
            ckpt_path = os.path.join(
                self.ckpt_save_dir,
                '{}_best_{}.pt'.format(self.model_name, metric_name.replace('/', '_'))
            )
            save_ckpt(ckpt_dict, ckpt_path, self.logger)

    @master_only
    def register_epoch_meter(self, name, meter_type, fmt='{:f}', plt=True):
        """Register an epoch meter, creating the meter pool lazily."""
        if self.meter_pool is None:
            self.meter_pool = MeterPool()
        self.meter_pool.register(name, meter_type, fmt, plt)

    @master_only
    def update_epoch_meter(self, name, value, n=1):
        """Update a registered epoch meter with a new value."""
        self.meter_pool.update(name, value, n)

    @master_only
    def print_epoch_meters(self, meter_type):
        """Log all meters of the given type ('train' or 'val')."""
        self.meter_pool.print_meters(meter_type, self.logger)

    @master_only
    def plt_epoch_meters(self, meter_type, step):
        """Plot all meters of the given type to tensorboard at the given step."""
        self.meter_pool.plt_meters(meter_type, step, self.tensorboard_writer)

    @master_only
    def reset_epoch_meters(self):
        """Reset all registered epoch meters."""
        self.meter_pool.reset()
--------------------------------------------------------------------------------