├── requirements.txt ├── torchpack ├── version.py ├── __init__.py ├── runner │ ├── hooks │ │ ├── closure.py │ │ ├── timer.py │ │ ├── __init__.py │ │ ├── optimizer_stepper.py │ │ ├── checkpoint_saver.py │ │ ├── text_logger.py │ │ ├── tensorboard_logger.py │ │ ├── logger.py │ │ ├── hook.py │ │ ├── pavi_logger.py │ │ └── lr_updater.py │ ├── __init__.py │ ├── log_buffer.py │ ├── utils.py │ └── runner.py ├── parallel.py ├── config.py └── io.py ├── travis └── install_pytorch.sh ├── examples ├── config.py └── train_imagenet.py ├── .travis.yml ├── LICENSE ├── setup.py ├── .gitignore ├── tests ├── test_runner.py └── test_io.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | requests>=2.9.1 2 | six>=1.10.0 -------------------------------------------------------------------------------- /torchpack/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.2' 2 | -------------------------------------------------------------------------------- /torchpack/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .io import * 3 | from .parallel import * 4 | from .runner import * 5 | from .version import __version__ 6 | -------------------------------------------------------------------------------- /torchpack/runner/hooks/closure.py: -------------------------------------------------------------------------------- 1 | from .hook import Hook 2 | 3 | 4 | class ClosureHook(Hook): 5 | 6 | def __init__(self, fn_name, fn): 7 | assert hasattr(self, fn_name) 8 | assert callable(fn) 9 | setattr(self, fn_name, fn) 10 | -------------------------------------------------------------------------------- /torchpack/runner/__init__.py: -------------------------------------------------------------------------------- 1 | from .runner import * 2 | from .hooks import * 3 | from .utils import * 4 | 5 | __all__ = [ 6 | 'Runner', 'Hook', 'CheckpointSaverHook', 'ClosureHook', 'LrUpdaterHook', 7 | 'OptimizerStepperHook', 'TimerHook', 'LoggerHook', 'TextLoggerHook', 8 | 'TensorboardLoggerHook', 'PaviLogger', 'PaviLoggerHook', 'get_host_info', 9 | 'get_dist_info', 'master_only', 'AverageMeter' 10 | ] 11 | -------------------------------------------------------------------------------- /torchpack/runner/hooks/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from .hook import Hook 4 | 5 | 6 | class TimerHook(Hook): 7 | 8 | def before_epoch(self, runner): 9 | self.t = time.time() 10 | 11 | def before_iter(self, runner): 12 | runner.log_buffer.update({'data_time': time.time() - self.t}) 13 | 14 | def after_iter(self, runner): 15 | runner.log_buffer.update({'time': time.time() - self.t}) 16 | self.t = time.time() 17 | -------------------------------------------------------------------------------- /travis/install_pytorch.sh: -------------------------------------------------------------------------------- 1 | if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 2 | pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl 3 | elif [[ "$TRAVIS_PYTHON_VERSION" == "3.5" ]]; then 4 | pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl 5 | elif [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then 6 | pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl 7 | fi 8 | pip install torchvision 
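Note: the `ClosureHook` and `TimerHook` above are attached to a `Runner` like any other hook. The following is only a minimal sketch of that wiring, assuming a user-supplied `model`, `train_loader`, and `batch_processor` (none of these are provided by the repository):

```python
# Sketch only: registering the hooks defined above on a Runner.
# `model`, `train_loader` and `batch_processor` must be supplied by the user.
from torchpack import Runner
from torchpack.runner.hooks import ClosureHook, OptimizerStepperHook, TimerHook


def print_lr(runner):
    # ClosureHook binds this plain function to the named hook point.
    runner.logger.info('current lr: %s', runner.current_lr())


runner = Runner(model,
                dict(algorithm='SGD', args=dict(lr=0.01, momentum=0.9)),
                batch_processor, work_dir='./work_dir')
runner.register_hook(OptimizerStepperHook())  # zero_grad + backward + step
runner.register_hook(TimerHook())             # fills 'data_time' and 'time'
runner.register_hook(ClosureHook('before_train_epoch', print_lr))
runner.run([train_loader], [('train', 1)], max_epoch=1)
```

Without `OptimizerStepperHook` the loop would run but never call `optimizer.step()`, which is why `register_default_hooks` in `runner/runner.py` always registers it.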
-------------------------------------------------------------------------------- /torchpack/runner/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .hook import Hook 2 | from .checkpoint_saver import CheckpointSaverHook 3 | from .closure import ClosureHook 4 | from .lr_updater import LrUpdaterHook 5 | from .optimizer_stepper import OptimizerStepperHook 6 | from .timer import TimerHook 7 | from .logger import LoggerHook 8 | from .text_logger import TextLoggerHook 9 | from .tensorboard_logger import TensorboardLoggerHook 10 | from .pavi_logger import PaviLogger, PaviLoggerHook 11 | 12 | __all__ = [ 13 | 'Hook', 'CheckpointSaverHook', 'ClosureHook', 'LrUpdaterHook', 14 | 'OptimizerStepperHook', 'TimerHook', 'LoggerHook', 'TextLoggerHook', 15 | 'TensorboardLoggerHook', 'PaviLogger', 'PaviLoggerHook' 16 | ] 17 | -------------------------------------------------------------------------------- /torchpack/runner/hooks/optimizer_stepper.py: -------------------------------------------------------------------------------- 1 | from torch.nn.utils import clip_grad 2 | 3 | from .hook import Hook 4 | 5 | 6 | class OptimizerStepperHook(Hook): 7 | 8 | def __init__(self, grad_clip=False, max_norm=35, norm_type=2): 9 | self.grad_clip = grad_clip 10 | self.max_norm = max_norm 11 | self.norm_type = norm_type 12 | 13 | def after_train_iter(self, runner): 14 | runner.optimizer.zero_grad() 15 | runner.outputs['loss'].backward() 16 | if self.grad_clip: 17 | clip_grad.clip_grad_norm_( 18 | filter(lambda p: p.requires_grad, runner.model.parameters()), 19 | max_norm=self.max_norm, 20 | norm_type=self.norm_type) 21 | runner.optimizer.step() 22 | -------------------------------------------------------------------------------- /examples/config.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = 'resnet18' 3 | # dataset settings 4 | data_root = '/mnt/SSD/dataset/ILSVRC/Data/CLS-LOC' 5 | mean = [0.485, 0.456, 0.406] 6 | std = [0.229, 0.224, 0.225] 7 | batch_size = 256 8 | 9 | # optimizer and learning rate 10 | optimizer = dict( 11 | algorithm='SGD', args=dict(lr=0.1, momentum=0.9, weight_decay=5e-4)) 12 | lr_policy = dict(policy='step', step=30) 13 | 14 | # logging settings 15 | log_level = 'INFO' 16 | log_cfg = dict( 17 | # log at every 50 iterations 18 | interval=50, 19 | hooks=[ 20 | ('TextLoggerHook', {}), 21 | # ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')) 22 | ]) 23 | 24 | # runtime settings 25 | work_dir = './demo' 26 | gpus = range(8) 27 | data_workers = len(gpus) * 2 28 | checkpoint_cfg = dict(interval=5) # save checkpoint at every epoch 29 | workflow = [('train', 5), ('val', 1)] 30 | max_epoch = 90 31 | -------------------------------------------------------------------------------- /torchpack/runner/hooks/checkpoint_saver.py: -------------------------------------------------------------------------------- 1 | from torchpack.io import save_checkpoint 2 | from .hook import Hook 3 | from ..utils import master_only 4 | 5 | 6 | class CheckpointSaverHook(Hook): 7 | 8 | def __init__(self, 9 | interval=-1, 10 | save_optimizer=True, 11 | out_dir=None, 12 | **kwargs): 13 | self.interval = interval 14 | self.save_optimizer = save_optimizer 15 | self.out_dir = out_dir 16 | self.args = kwargs 17 | 18 | @master_only 19 | def after_train_epoch(self, runner): 20 | if not self.every_n_epochs(runner, self.interval): 21 | return 22 | if not self.out_dir: 23 | self.out_dir = runner.work_dir 24 
| optimizer = runner.optimizer if self.save_optimizer else None 25 | save_checkpoint( 26 | model=runner.model, 27 | epoch=runner.epoch + 1, 28 | num_iters=runner.num_iters, 29 | out_dir=self.out_dir, 30 | optimizer=optimizer, 31 | **self.args) 32 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - '2.7' 5 | - '3.5' 6 | - '3.6' 7 | 8 | install: 9 | - bash travis/install_pytorch.sh 10 | - pip install . 11 | 12 | script: pytest 13 | 14 | deploy: 15 | provider: pypi 16 | user: kchen 17 | password: 18 | secure: ZkQJqC75Jfe338ql26VYUM9RjfQhwfqOSOyg5ixo4IQlRXZXHOdB1f0QCbMWo1ZTRZa6fHsvMCGYS/nUUsb+YASdXS+OWOqfjdHo8HZANURrJmcxvSaz9J7fxB6Wir0g5a33KydjUeW8/X8+KirMl2WrYRouWrKiibcdFaBO9xFpdr8xDlteIV/8spqLboaQ6Zi9bkq85UGvE5fxBybONHMhGcnoj13fUPAtAYhrmUNy/KKBrTMjRMXgT3MFB2KhhYIn/l8kLO7pMPYYAbz7P7yBxqEyunbBY8hEzjd8occq0tfltluVpWTod889HenMf5fdjM/uxWtbyawEd3Dl0jAB1immycOxSG3eJpcg4nFvRUAPDk/Z6Vir/k5LtFZy1hPOZjJuLan9t0qtIaY1ap64ninfenRsG0UTmZaZ30/LeVgjqsP3xBTabdtu2JXCrJPWtsRSkq3A4A6Xs6WuH40pmTLpSAsDm+rED7JxNPYygAoC/zp3Z9gJ8i93KTjOpAmF0oQpHd/3vL4h2HQQLBndCYJAfd7He33jng/Uvx8M1RCtbFqwkd1NqLGKbxtYd0GvUNeJ2Num2rvox7aju09B1axIgBZupiTGVJoBrYz5UwoDx/VVXBYu5NTn/om9kyyD0CevgVp6frJrJc6c2qoQWIYluN0OA7A7m7OKT8o= 19 | on: 20 | branch: master 21 | tags: true 22 | python: '3.5' 23 | distributions: sdist bdist_wheel 24 | skip_cleanup: true 25 | skip_upload_docs: true 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Kai Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /torchpack/runner/hooks/text_logger.py: -------------------------------------------------------------------------------- 1 | from .logger import LoggerHook 2 | 3 | 4 | class TextLoggerHook(LoggerHook): 5 | 6 | def log(self, runner): 7 | if runner.mode == 'train': 8 | lr_str = ', '.join( 9 | ['{:.5f}'.format(lr) for lr in runner.current_lr()]) 10 | log_str = 'Epoch [{}][{}/{}]\tlr: {}, '.format( 11 | runner.epoch + 1, runner.num_epoch_iters + 1, 12 | len(runner.data_loader), lr_str) 13 | else: 14 | log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch, 15 | runner.num_epoch_iters + 1) 16 | if 'time' in runner.log_buffer.output: 17 | log_str += ( 18 | 'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '. 19 | format(log=runner.log_buffer.output)) 20 | log_items = [] 21 | for name, val in runner.log_buffer.output.items(): 22 | if name in ['time', 'data_time']: 23 | continue 24 | log_items.append('{}: {:.4f}'.format(name, val)) 25 | log_str += ', '.join(log_items) 26 | runner.logger.info(log_str) 27 | -------------------------------------------------------------------------------- /torchpack/runner/log_buffer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | 4 | 5 | class LogBuffer(object): 6 | 7 | def __init__(self): 8 | self.val_history = OrderedDict() 9 | self.n_history = OrderedDict() 10 | self.output = OrderedDict() 11 | self.ready = False 12 | 13 | def clear(self): 14 | self.val_history.clear() 15 | self.n_history.clear() 16 | self.clear_output() 17 | 18 | def clear_output(self): 19 | self.output.clear() 20 | self.ready = False 21 | 22 | def update(self, vars, count=1): 23 | assert isinstance(vars, dict) 24 | for key, var in vars.items(): 25 | if key not in self.val_history: 26 | self.val_history[key] = [] 27 | self.n_history[key] = [] 28 | self.val_history[key].append(var) 29 | self.n_history[key].append(count) 30 | 31 | def average(self, n=0): 32 | """Average latest n values or all values""" 33 | assert n >= 0 34 | for key in self.val_history: 35 | values = np.array(self.val_history[key][-n:]) 36 | nums = np.array(self.n_history[key][-n:]) 37 | avg = np.sum(values * nums) / np.sum(nums) 38 | self.output[key] = avg 39 | self.ready = True 40 | -------------------------------------------------------------------------------- /torchpack/runner/hooks/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | from .logger import LoggerHook 2 | from ..utils import master_only 3 | 4 | 5 | class TensorboardLoggerHook(LoggerHook): 6 | 7 | def __init__(self, 8 | log_dir, 9 | interval=10, 10 | reset_meter=True, 11 | ignore_last=True): 12 | super(TensorboardLoggerHook, self).__init__(interval, reset_meter, 13 | ignore_last) 14 | self.log_dir = log_dir 15 | 16 | @master_only 17 | def before_run(self, runner): 18 | try: 19 | from tensorboardX import SummaryWriter 20 | except ImportError: 21 | raise ImportError('Please install tensorflow and tensorboardX ' 22 | 'to use TensorboardLoggerHook.') 23 | else: 24 | self.writer = SummaryWriter(self.log_dir) 25 | 26 | @master_only 27 | def log(self, runner): 28 | for var in runner.log_buffer.output: 29 | if var in ['time', 'data_time']: 30 | continue 31 | tag = '{}/{}'.format(var, runner.mode) 32 | self.writer.add_scalar(tag, runner.log_buffer.output[var], 33 | runner.num_iters) 34 | 35 | @master_only 36 | def 
after_run(self, runner): 37 | self.writer.close() 38 | -------------------------------------------------------------------------------- /torchpack/runner/hooks/logger.py: -------------------------------------------------------------------------------- 1 | from .hook import Hook 2 | 3 | 4 | class LoggerHook(Hook): 5 | """Base class for logger hooks.""" 6 | 7 | def __init__(self, interval=10, ignore_last=True, reset_flag=False): 8 | self.interval = interval 9 | self.ignore_last = ignore_last 10 | self.reset_flag = reset_flag 11 | 12 | def before_run(self, runner): 13 | for hook in runner.hooks[::-1]: 14 | if isinstance(hook, LoggerHook): 15 | hook.reset_flag = True 16 | break 17 | 18 | def before_epoch(self, runner): 19 | runner.log_buffer.clear() # clear logs of last epoch 20 | 21 | def after_train_iter(self, runner): 22 | if self.every_n_inner_iters(runner, self.interval): 23 | runner.log_buffer.average(self.interval) 24 | elif self.end_of_epoch(runner) and not self.ignore_last: 25 | # not precise but more stable 26 | runner.log_buffer.average(self.interval) 27 | 28 | if runner.log_buffer.ready: 29 | self.log(runner) 30 | if self.reset_flag: 31 | runner.log_buffer.clear_output() 32 | 33 | def after_train_epoch(self, runner): 34 | if runner.log_buffer.ready: 35 | self.log(runner) 36 | 37 | def after_val_epoch(self, runner): 38 | runner.log_buffer.average() 39 | self.log(runner) 40 | if self.reset_flag: 41 | runner.log_buffer.clear_output() 42 | 43 | def log(self, runner): 44 | pass 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open('requirements.txt', 'r') as f: 4 | install_requires = [line for line in f] 5 | 6 | 7 | def readme(): 8 | with open('README.md') as f: 9 | content = f.read() 10 | return content 11 | 12 | 13 | def get_version(): 14 | version_file = 'torchpack/version.py' 15 | with open(version_file, 'r') as f: 16 | exec(compile(f.read(), version_file, 'exec')) 17 | return locals()['__version__'] 18 | 19 | 20 | setup( 21 | name='torchpack', 22 | version=get_version(), 23 | description='A set of interfaces to simplify the usage of PyTorch', 24 | long_description=readme(), 25 | keywords='computer vision', 26 | packages=find_packages(), 27 | classifiers=[ 28 | 'Development Status :: 3 - Alpha', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Operating System :: OS Independent', 31 | 'Programming Language :: Python :: 2', 32 | 'Programming Language :: Python :: 2.7', 33 | 'Programming Language :: Python :: 3', 34 | 'Programming Language :: Python :: 3.4', 35 | 'Programming Language :: Python :: 3.5', 36 | 'Programming Language :: Python :: 3.6', 37 | 'Topic :: Utilities', 38 | ], 39 | url='https://github.com/hellock/torchpack', 40 | author='Kai Chen', 41 | author_email='chenkaidev@gmail.com', 42 | license='MIT', 43 | setup_requires=['pytest-runner'], 44 | tests_require=['pytest'], 45 | install_requires=install_requires, 46 | zip_safe=False 47 | ) # yapf: disable -------------------------------------------------------------------------------- /torchpack/runner/hooks/hook.py: -------------------------------------------------------------------------------- 1 | class Hook(object): 2 | 3 | def before_run(self, runner): 4 | pass 5 | 6 | def after_run(self, runner): 7 | pass 8 | 9 | def before_epoch(self, runner): 10 | pass 11 | 12 | def after_epoch(self, runner): 13 | pass 14 | 15 | def 
before_iter(self, runner): 16 | pass 17 | 18 | def after_iter(self, runner): 19 | pass 20 | 21 | def before_train_epoch(self, runner): 22 | self.before_epoch(runner) 23 | 24 | def before_val_epoch(self, runner): 25 | self.before_epoch(runner) 26 | 27 | def after_train_epoch(self, runner): 28 | self.after_epoch(runner) 29 | 30 | def after_val_epoch(self, runner): 31 | self.after_epoch(runner) 32 | 33 | def before_train_iter(self, runner): 34 | self.before_iter(runner) 35 | 36 | def before_val_iter(self, runner): 37 | self.before_iter(runner) 38 | 39 | def after_train_iter(self, runner): 40 | self.after_iter(runner) 41 | 42 | def after_val_iter(self, runner): 43 | self.after_iter(runner) 44 | 45 | def every_n_epochs(self, runner, n): 46 | return (runner.epoch + 1) % n == 0 if n > 0 else False 47 | 48 | def every_n_inner_iters(self, runner, n): 49 | return (runner.num_epoch_iters + 1) % n == 0 if n > 0 else False 50 | 51 | def every_n_iters(self, runner, n): 52 | return (runner.num_iters + 1) % n == 0 if n > 0 else False 53 | 54 | def end_of_epoch(self, runner): 55 | return runner.num_epoch_iters + 1 == len(runner.data_loader) 56 | -------------------------------------------------------------------------------- /torchpack/runner/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from collections import defaultdict 3 | from getpass import getuser 4 | from socket import gethostname 5 | 6 | import torch.distributed as dist 7 | 8 | 9 | def get_host_info(): 10 | return '{}@{}'.format(getuser(), gethostname()) 11 | 12 | 13 | def get_dist_info(): 14 | if dist._initialized: 15 | rank = dist.get_rank() 16 | world_size = dist.get_world_size() 17 | else: 18 | rank = 0 19 | world_size = 1 20 | return rank, world_size 21 | 22 | 23 | def master_only(func): 24 | 25 | @functools.wraps(func) 26 | def wrapper(*args, **kwargs): 27 | rank, _ = get_dist_info() 28 | if rank == 0: 29 | return func(*args, **kwargs) 30 | 31 | return wrapper 32 | 33 | 34 | class AverageMeter(object): 35 | """Computes and stores the average and current value""" 36 | 37 | def __init__(self): 38 | self.val = defaultdict(float) 39 | self.sum = defaultdict(float) 40 | self.avg = defaultdict(float) 41 | self.count = defaultdict(int) 42 | self.reset() 43 | 44 | def reset(self, keys=None): 45 | if isinstance(keys, str): 46 | keys = [keys] 47 | elif keys is None: 48 | keys = self.val.keys() 49 | for k in keys: 50 | self.val[k] = 0 51 | self.sum[k] = 0 52 | self.avg[k] = 0 53 | self.count[k] = 0 54 | 55 | def update(self, pairs, n=1): 56 | for k, v in pairs.items(): 57 | self.val[k] = v 58 | self.sum[k] += v * n 59 | self.count[k] += n 60 | self.avg[k] = self.sum[k] / self.count[k] 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # vscode 104 | .vscode/ 105 | 106 | .pytest_cache/ 107 | -------------------------------------------------------------------------------- /torchpack/parallel.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | import torch 4 | 5 | from .io import load_checkpoint 6 | 7 | 8 | def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func, 9 | gpu_id, idx_queue, result_queue): 10 | model = model_cls(**model_kwargs) 11 | load_checkpoint(model, checkpoint, map_location='cpu') 12 | torch.cuda.set_device(gpu_id) 13 | model.cuda() 14 | model.eval() 15 | with torch.no_grad(): 16 | while True: 17 | idx = idx_queue.get() 18 | data = dataset[idx] 19 | result = model(**data_func(data, gpu_id)) 20 | result_queue.put((idx, result)) 21 | 22 | 23 | def parallel_test(model_cls, 24 | model_kwargs, 25 | checkpoint, 26 | dataset, 27 | data_func, 28 | gpus, 29 | worker_per_gpu=1): 30 | ctx = multiprocessing.get_context('spawn') 31 | idx_queue = ctx.Queue() 32 | result_queue = ctx.Queue() 33 | num_workers = len(gpus) * worker_per_gpu 34 | workers = [ 35 | ctx.Process( 36 | target=worker_func, 37 | args=(model_cls, model_kwargs, checkpoint, dataset, data_func, 38 | gpus[i % len(gpus)], idx_queue, result_queue)) 39 | for i in range(num_workers) 40 | ] 41 | for w in workers: 42 | w.daemon = True 43 | w.start() 44 | 45 | for i in range(len(dataset)): 46 | idx_queue.put(i) 47 | 48 | results = [None for _ in range(len(dataset))] 49 | import cvbase as cvb 50 | prog_bar = cvb.ProgressBar(task_num=len(dataset)) 51 | for _ in range(len(dataset)): 52 | idx, res = result_queue.get() 53 | results[idx] = res 54 | prog_bar.update() 55 | print('\n') 56 | for worker in workers: 57 | worker.terminate() 58 | 59 | return results 60 | -------------------------------------------------------------------------------- /tests/test_runner.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import TensorDataset 7 | from torchpack import Runner 8 | 9 | 10 | class Model(nn.Module): 11 | 12 | def __init__(self): 13 | super(Model, self).__init__() 14 | self.conv = nn.Conv2d(2, 3, kernel_size=3, padding=1) 15 | self.relu = nn.ReLU(inplace=True) 16 | self.pool = nn.AvgPool2d(5) 17 | self.fc = nn.Linear(3, 1) 18 | 19 | def forward(self, x): 20 | 
x = self.conv(x) 21 | x = self.relu(x) 22 | x = self.pool(x) 23 | x = x.view(x.size(0), -1) 24 | x = self.fc(x) 25 | return x 26 | 27 | 28 | def batch_processor(model, data, train_mode): 29 | p_img, = data 30 | res = model(p_img) 31 | loss = res.mean() 32 | log_vars = OrderedDict() 33 | log_vars['loss'] = loss.item() 34 | outputs = dict(loss=loss, log_vars=log_vars, num_samples=p_img.size(0)) 35 | return outputs 36 | 37 | 38 | class TestRunner(object): 39 | 40 | @classmethod 41 | def setup_class(cls): 42 | cls.model = Model() 43 | cls.train_dataset = TensorDataset(torch.rand(10, 2, 5, 5)) 44 | cls.val_dataset = TensorDataset(torch.rand(3, 2, 5, 5)) 45 | 46 | def test_init(self): 47 | optimizer = dict( 48 | algorithm='SGD', 49 | args=dict(lr=0.001, momentum=0.9, weight_decay=5e-4)) 50 | work_dir = tempfile.mkdtemp() 51 | runner = Runner(self.model, optimizer, batch_processor, work_dir) 52 | train_loader = torch.utils.data.DataLoader( 53 | self.train_dataset, batch_size=5, shuffle=True) 54 | val_loader = torch.utils.data.DataLoader( 55 | self.train_dataset, batch_size=3, shuffle=False) 56 | runner.run( 57 | [train_loader, val_loader], [('train', 1), ('val', 1)], 58 | max_epoch=2) 59 | -------------------------------------------------------------------------------- /tests/test_io.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import shutil 3 | import tempfile 4 | 5 | import pytest 6 | import torch 7 | from torch import nn 8 | 9 | from torchpack import load_checkpoint, save_checkpoint 10 | 11 | 12 | class Model(nn.Module): 13 | 14 | def __init__(self): 15 | super(Model, self).__init__() 16 | self.conv = nn.Conv2d(2, 3, kernel_size=3, padding=1) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.pool = nn.AvgPool2d(5) 19 | self.fc = nn.Linear(3, 1) 20 | 21 | def forward(self, x): 22 | x = self.conv(x) 23 | x = self.relu(x) 24 | x = self.pool(x) 25 | x = self.fc(x.view(-1)) 26 | return x 27 | 28 | def verify_params(self, state_dict): 29 | from collections import OrderedDict 30 | assert isinstance(state_dict, OrderedDict) 31 | assert list(state_dict.keys()) == [ 32 | 'conv.weight', 'conv.bias', 'fc.weight', 'fc.bias' 33 | ] 34 | 35 | 36 | def test_save_checkpoint(): 37 | tmp_dir = tempfile.mkdtemp() 38 | model = Model() 39 | epoch = 1 40 | num_iters = 100 41 | optimizer = torch.optim.SGD(model.parameters(), 0.01) 42 | save_checkpoint(model, epoch, num_iters, tmp_dir) 43 | assert osp.isfile(tmp_dir + '/epoch_1.pth') 44 | chkp = torch.load(tmp_dir + '/epoch_1.pth') 45 | assert isinstance(chkp, dict) 46 | assert chkp['epoch'] == epoch 47 | assert chkp['num_iters'] == num_iters 48 | model.verify_params(chkp['state_dict']) 49 | save_checkpoint( 50 | model, 51 | epoch, 52 | num_iters, 53 | tmp_dir, 54 | filename_tmpl='test_{}.pth', 55 | optimizer=optimizer) 56 | assert osp.isfile(tmp_dir + '/test_1.pth') 57 | chkp = torch.load(tmp_dir + '/test_1.pth') 58 | assert isinstance(chkp, dict) 59 | assert chkp['epoch'] == epoch 60 | assert chkp['num_iters'] == num_iters 61 | model.verify_params(chkp['state_dict']) 62 | shutil.rmtree(tmp_dir) 63 | 64 | 65 | def test_load_checkpoint(): 66 | model = Model() 67 | with pytest.raises(IOError): 68 | load_checkpoint(model, 'non_exist_file') 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchpack (Deprecated! 
Please use [mmcv](https://github.com/open-mmlab/mmcv) instead.)
2 |
3 | [![PyPI Version](https://img.shields.io/pypi/v/torchpack.svg)](https://pypi.python.org/pypi/torchpack)
4 |
5 | Torchpack is a set of interfaces to simplify the usage of PyTorch.
6 |
7 | Documentation is in progress.
8 |
9 |
10 | ## Installation
11 |
12 | - Install with pip.
13 | ```
14 | pip install torchpack
15 | ```
16 | - Install from source.
17 | ```
18 | git clone https://github.com/hellock/torchpack.git
19 | cd torchpack
20 | python setup.py install
21 | ```
22 |
23 | **Note**: If you want to use tensorboard to visualize the training process, you need to
24 | install tensorflow ([installation guide](https://www.tensorflow.org/install/install_linux)) and tensorboardX (`pip install tensorboardX`).
25 |
26 | ## What can torchpack do
27 |
28 | Torchpack aims to help users start training with less code, while staying
29 | flexible and configurable. It provides a `Runner` with lots of `Hooks`.
30 |
31 | ## Example
32 |
33 | ```python
34 | ######################## file1: config.py #######################
35 | work_dir = './demo'  # dir to save log files and checkpoints
36 | optimizer = dict(
37 |     algorithm='SGD', args=dict(lr=0.001, momentum=0.9, weight_decay=5e-4))
38 | workflow = [('train', 2), ('val', 1)]  # train for 2 epochs, then validate for 1 epoch, iteratively
39 | max_epoch = 16
40 | lr_policy = dict(policy='step', step=12)  # decrease the learning rate by a factor of 10 every 12 epochs
41 | checkpoint_cfg = dict(interval=1)  # save a checkpoint every epoch
42 | log_cfg = dict(
43 |     # log at every 50 iterations
44 |     interval=50,
45 |     # two logging hooks, one for printing in the terminal and one for tensorboard visualization
46 |     hooks=[
47 |         ('TextLoggerHook', {}),
48 |         ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log'))
49 |     ])
50 |
51 | ######################### file2: main.py ########################
52 | import torch.nn.functional as F
53 | from torchpack import Config, Runner
54 | from collections import OrderedDict
55 |
56 | # define how to process a batch and return a dict
57 | def batch_processor(model, data, train_mode):
58 |     img, label = data
59 |     label = label.cuda(non_blocking=True)
60 |     pred = model(img)
61 |     loss = F.cross_entropy(pred, label)
62 |     accuracy = get_accuracy(pred, label)
63 |     log_vars = OrderedDict()
64 |     log_vars['loss'] = loss.item()
65 |     log_vars['accuracy'] = accuracy.item()
66 |     outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
67 |     return outputs
68 |
69 | cfg = Config.from_file('config.py')  # or config.yaml/config.json
70 | model = resnet18()
71 | runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
72 | runner.register_default_hooks(lr_config=cfg.lr_policy,
73 |                               checkpoint_config=cfg.checkpoint_cfg,
74 |                               log_config=cfg.log_cfg)
75 |
76 | runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
77 | ```
78 |
79 | For a full example of training on ImageNet, please see `examples/train_imagenet.py`.
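Besides the built-in hooks, custom behaviour can be added by subclassing `Hook` and registering an instance on the runner created in `main.py` above. The snippet below is only a sketch: `EvalHook` is not part of torchpack, and the body of `after_train_epoch` is a placeholder for whatever evaluation or bookkeeping you need.

```python
from torchpack import Hook

class EvalHook(Hook):
    """Hypothetical hook that fires once every `interval` training epochs."""

    def __init__(self, interval=1):
        self.interval = interval

    def after_train_epoch(self, runner):
        if not self.every_n_epochs(runner, self.interval):
            return
        # put custom evaluation / bookkeeping here
        runner.logger.info('finished epoch %d', runner.epoch + 1)

runner.register_hook(EvalHook(interval=1))
```

To launch the full ImageNet example: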
80 | 81 | ```shell 82 | python examples/train_imagenet.py examples/config.py 83 | ``` 84 | -------------------------------------------------------------------------------- /examples/train_imagenet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from argparse import ArgumentParser 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torchpack import Config, Runner 8 | from torchvision import datasets, models, transforms 9 | 10 | 11 | def accuracy(output, target, topk=(1, )): 12 | """Computes the precision@k for the specified values of k""" 13 | with torch.no_grad(): 14 | maxk = max(topk) 15 | batch_size = target.size(0) 16 | 17 | _, pred = output.topk(maxk, 1, True, True) 18 | pred = pred.t() 19 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 20 | 21 | res = [] 22 | for k in topk: 23 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 24 | res.append(correct_k.mul_(100.0 / batch_size)) 25 | return res 26 | 27 | 28 | def batch_processor(model, data, train_mode): 29 | img, label = data 30 | label = label.cuda(non_blocking=True) 31 | pred = model(img) 32 | loss = F.cross_entropy(pred, label) 33 | acc_top1, acc_top5 = accuracy(pred, label, topk=(1, 5)) 34 | log_vars = OrderedDict() 35 | log_vars['loss'] = loss.item() 36 | log_vars['acc_top1'] = acc_top1.item() 37 | log_vars['acc_top5'] = acc_top5.item() 38 | outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0)) 39 | return outputs 40 | 41 | 42 | def parse_args(): 43 | parser = ArgumentParser(description='Train Faster R-CNN') 44 | parser.add_argument('config', help='train config file path') 45 | return parser.parse_args() 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | cfg = Config.from_file(args.config) # or config.yaml/config.json 51 | model = getattr(models, cfg.model)() 52 | # from mobilenet_v2 import MobileNetV2 53 | # model = MobileNetV2() 54 | model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda() 55 | 56 | normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std) 57 | train_dataset = datasets.ImageFolder( 58 | osp.join(cfg.data_root, 'train'), 59 | transforms.Compose([ 60 | transforms.RandomResizedCrop(224), 61 | transforms.RandomHorizontalFlip(), 62 | transforms.ToTensor(), 63 | normalize, 64 | ])) 65 | val_dataset = datasets.ImageFolder( 66 | osp.join(cfg.data_root, 'val'), 67 | transforms.Compose([ 68 | transforms.Resize(256), 69 | transforms.CenterCrop(224), 70 | transforms.ToTensor(), 71 | normalize, 72 | ])) 73 | train_loader = torch.utils.data.DataLoader( 74 | train_dataset, 75 | batch_size=cfg.batch_size, 76 | shuffle=True, 77 | num_workers=cfg.data_workers, 78 | pin_memory=True) 79 | val_loader = torch.utils.data.DataLoader( 80 | val_dataset, 81 | batch_size=cfg.batch_size, 82 | shuffle=False, 83 | num_workers=cfg.data_workers, 84 | pin_memory=True) 85 | 86 | runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir) 87 | runner.register_default_hooks( 88 | lr_config=cfg.lr_policy, 89 | checkpoint_config=cfg.checkpoint_cfg, 90 | log_config=cfg.log_cfg) 91 | 92 | runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /torchpack/config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | from argparse import 
ArgumentParser
4 | from collections import Iterable
5 | from importlib import import_module
6 |
7 |
8 | def add_args(parser, cfg, prefix=''):
9 |     for k, v in cfg.items():
10 |         if isinstance(v, str):
11 |             parser.add_argument('--' + prefix + k)
12 |         elif isinstance(v, bool):
13 |             parser.add_argument('--' + prefix + k, action='store_true')
14 |         elif isinstance(v, int):
15 |             parser.add_argument('--' + prefix + k, type=int)
16 |         elif isinstance(v, float):
17 |             parser.add_argument('--' + prefix + k, type=float)
18 |         elif isinstance(v, dict):
19 |             add_args(parser, v, k + '.')
20 |         elif isinstance(v, Iterable):
21 |             parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
22 |         else:
23 |             print('cannot parse key {} of type {}'.format(prefix + k, type(v)))
24 |     return parser
25 |
26 |
27 | class Config(object):
28 |
29 |     @staticmethod
30 |     def from_file(filename):
31 |         if filename.endswith('.py'):
32 |             sys.path.append(osp.dirname(filename))
33 |             module_name = osp.basename(filename)[:-3]
34 |             cfg = import_module(module_name)
35 |             config_dict = {
36 |                 name: value
37 |                 for name, value in cfg.__dict__.items()
38 |                 if not name.startswith(('__', '_'))
39 |             }
40 |         elif filename.endswith(('.yaml', '.json')):
41 |             import cvbase as cvb
42 |             config_dict = cvb.load(filename)
43 |         else:
44 |             raise IOError(
45 |                 'only py/yaml/json type are supported as config files')
46 |         return Config(config_dict, filename=filename)
47 |
48 |     @staticmethod
49 |     def auto_argparser(description=None):
50 |         partial_parser = ArgumentParser(description=description)
51 |         partial_parser.add_argument('config', help='config file path')
52 |         cfg_file = partial_parser.parse_known_args()[0].config
53 |         cfg = Config.from_file(cfg_file)
54 |         parser = ArgumentParser(description=description)
55 |         parser.add_argument('config', help='config file path')
56 |         add_args(parser, cfg)
57 |         return parser, cfg
58 |
59 |     def __init__(self, config_dict, filename=None):
60 |         assert isinstance(config_dict, dict)
61 |         self._config_dict = config_dict
62 |         self._default_dict = {}
63 |         self.filename = filename
64 |         if filename:
65 |             with open(filename, 'r') as f:
66 |                 self._text = f.read()
67 |
68 |     def __getattr__(self, key):
69 |         try:
70 |             val = self._config_dict[key]
71 |         except KeyError:
72 |             if key in self._default_dict:
73 |                 val = self._default_dict[key]
74 |             else:
75 |                 raise
76 |         return val
77 |
78 |     def __getitem__(self, key):
79 |         return self.__getattr__(key)
80 |
81 |     def __iter__(self):
82 |         return self.keys()
83 |
84 |     def __contains__(self, key):
85 |         if key in self._config_dict or key in self._default_dict:
86 |             return True
87 |         else:
88 |             return False
89 |
90 |     @property
91 |     def text(self):
92 |         return self._text
93 |
94 |     def get(self, key, default=None):
95 |         if key in self:
96 |             return self.__getattr__(key)
97 |         else:
98 |             return default
99 |
100 |     def keys(self):
101 |         for key in self._config_dict:
102 |             yield key
103 |         for key in self._default_dict:
104 |             if key not in self._config_dict:
105 |                 yield key
106 |
107 |     def values(self):
108 |         for key in self.keys():
109 |             yield self.__getattr__(key)
110 |
111 |     def items(self):
112 |         for key in self.keys():
113 |             yield key, self.__getattr__(key)
114 |
115 |     def set_default(self, default_dict):
116 |         assert isinstance(default_dict, dict)
117 |         self._default_dict.update(default_dict)
118 |
119 |     def update_with_args(self, args):
120 |         for k, v in vars(args).items():
121 |             if v is not None:
122 |                 if '.'
not in k:
123 |                     self._config_dict[k] = v
124 |                 else:
125 |                     tree = k.split('.')
126 |                     tmp = self._config_dict
127 |                     for key in tree[:-1]:
128 |                         tmp = tmp[key]
129 |                     tmp[tree[-1]] = v
130 |
-------------------------------------------------------------------------------- /torchpack/io.py: --------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 |
4 | import torch
5 | from torch.nn.parallel import DataParallel, DistributedDataParallel
6 | from torch.utils import model_zoo
7 | from torchvision.models.resnet import model_urls
8 |
9 |
10 | def load_state_dict(module, state_dict, strict=False, logger=None):
11 |     unexpected_keys = []
12 |     own_state = module.state_dict()
13 |     for name, param in state_dict.items():
14 |         if name not in own_state:
15 |             unexpected_keys.append(name)
16 |             continue
17 |         if isinstance(param, torch.nn.Parameter):
18 |             # backwards compatibility for serialized parameters
19 |             param = param.data
20 |
21 |         try:
22 |             own_state[name].copy_(param)
23 |         except Exception:
24 |             raise RuntimeError('While copying the parameter named {}, '
25 |                                'whose dimensions in the model are {} and '
26 |                                'whose dimensions in the checkpoint are {}.'
27 |                                .format(name, own_state[name].size(),
28 |                                        param.size()))
29 |     missing_keys = set(own_state.keys()) - set(state_dict.keys())
30 |
31 |     err_msgs = []
32 |     if unexpected_keys:
33 |         err_msgs.append('unexpected key in source state_dict: {}\n'.format(
34 |             ', '.join(unexpected_keys)))
35 |     if missing_keys:
36 |         err_msgs.append('missing keys in source state_dict: {}\n'.format(
37 |             ', '.join(missing_keys)))
38 |     msg = '\n'.join(err_msgs)
39 |     if msg:
40 |         if strict:
41 |             raise RuntimeError(msg)
42 |         elif logger is not None:
43 |             logger.warn(msg)
44 |         else:
45 |             print(msg)
46 |
47 |
48 | def load_checkpoint(model,
49 |                     filename,
50 |                     map_location=None,
51 |                     strict=False,
52 |                     logger=None):
53 |     # load checkpoint from modelzoo or file or url
54 |     if filename.startswith('modelzoo://'):
55 |         model_name = filename[11:]
56 |         checkpoint = model_zoo.load_url(model_urls[model_name])
57 |     elif filename.startswith(('http://', 'https://')):
58 |         checkpoint = model_zoo.load_url(filename)
59 |     else:
60 |         if not os.path.isfile(filename):
61 |             raise IOError('{} is not a checkpoint file'.format(filename))
62 |         checkpoint = torch.load(filename, map_location=map_location)
63 |     # get state_dict from checkpoint
64 |     if isinstance(checkpoint, OrderedDict):
65 |         state_dict = checkpoint
66 |     elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
67 |         state_dict = checkpoint['state_dict']
68 |     else:
69 |         raise RuntimeError(
70 |             'No state_dict found in checkpoint file {}'.format(filename))
71 |     # strip the 'module.' prefix added by (Distributed)DataParallel, if any
72 |     if list(state_dict.keys())[0].startswith('module.'):
73 |         state_dict = {k[7:]: v for k, v in state_dict.items()}
74 |     # load state_dict
75 |     if isinstance(model, (DataParallel, DistributedDataParallel)):
76 |         load_state_dict(model.module, state_dict, strict, logger)
77 |     else:
78 |         load_state_dict(model, state_dict, strict, logger)
79 |     return checkpoint
80 |
81 |
82 | def make_link(filename, link):
83 |     if os.path.islink(link):
84 |         os.remove(link)
85 |     os.symlink(filename, link)
86 |
87 |
88 | def model_weights_to_cpu(state_dict):
89 |     state_dict_cpu = OrderedDict()
90 |     for key, val in state_dict.items():
91 |         state_dict_cpu[key] = val.cpu()
92 |     return state_dict_cpu
93 |
94 |
95 | def save_checkpoint(model,
96 |                     epoch,
97 |                     num_iters,
98 |                     out_dir,
99 | filename_tmpl='epoch_{}.pth', 100 | optimizer=None, 101 | is_best=False): 102 | if not os.path.isdir(out_dir): 103 | os.makedirs(out_dir) 104 | if isinstance(model, (DataParallel, DistributedDataParallel)): 105 | model = model.module 106 | filename = os.path.join(out_dir, filename_tmpl.format(epoch)) 107 | checkpoint = { 108 | 'epoch': epoch, 109 | 'num_iters': num_iters, 110 | 'state_dict': model_weights_to_cpu(model.state_dict()) 111 | } 112 | if optimizer is not None: 113 | checkpoint['optimizer'] = optimizer.state_dict() 114 | torch.save(checkpoint, filename) 115 | latest_link = os.path.join(out_dir, 'latest.pth') 116 | make_link(filename, latest_link) 117 | if is_best: 118 | best_link = os.path.join(out_dir, 'best.pth') 119 | make_link(filename, best_link) 120 | -------------------------------------------------------------------------------- /torchpack/runner/hooks/pavi_logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import time 5 | from datetime import datetime 6 | from threading import Thread 7 | 8 | import requests 9 | from six.moves.queue import Empty, Queue 10 | 11 | from .logger import LoggerHook 12 | from ..utils import master_only, get_host_info 13 | 14 | 15 | class PaviLogger(object): 16 | 17 | def __init__(self, url, username=None, password=None, instance_id=None): 18 | self.url = url 19 | self.username = self._get_env_var(username, 'PAVI_USERNAME') 20 | self.password = self._get_env_var(password, 'PAVI_PASSWORD') 21 | self.instance_id = instance_id 22 | self.log_queue = None 23 | 24 | def _get_env_var(self, var, env_var): 25 | if var is not None: 26 | return str(var) 27 | 28 | var = os.getenv(env_var) 29 | if not var: 30 | raise ValueError( 31 | '"{}" is neither specified nor defined as env variables'. 
32 | format(env_var)) 33 | return var 34 | 35 | def connect(self, 36 | model_name, 37 | work_dir=None, 38 | info=dict(), 39 | timeout=5, 40 | logger=None): 41 | if logger: 42 | log_info = logger.info 43 | log_error = logger.error 44 | else: 45 | log_info = log_error = print 46 | log_info('connecting pavi service {}...'.format(self.url)) 47 | post_data = dict( 48 | time=str(datetime.now()), 49 | username=self.username, 50 | password=self.password, 51 | instance_id=self.instance_id, 52 | model=model_name, 53 | work_dir=os.path.abspath(work_dir) if work_dir else '', 54 | session_file=info.get('session_file', ''), 55 | session_text=info.get('session_text', ''), 56 | model_text=info.get('model_text', ''), 57 | device=get_host_info()) 58 | try: 59 | response = requests.post(self.url, json=post_data, timeout=timeout) 60 | except Exception as ex: 61 | log_error('fail to connect to pavi service: {}'.format(ex)) 62 | else: 63 | if response.status_code == 200: 64 | self.instance_id = response.text 65 | log_info('pavi service connected, instance_id: {}'.format( 66 | self.instance_id)) 67 | self.log_queue = Queue() 68 | self.log_thread = Thread(target=self.post_worker_fn) 69 | self.log_thread.daemon = True 70 | self.log_thread.start() 71 | return True 72 | else: 73 | log_error('fail to connect to pavi service, status code: ' 74 | '{}, err message: {}'.format(response.status_code, 75 | response.reason)) 76 | return False 77 | 78 | def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3): 79 | while True: 80 | try: 81 | log = self.log_queue.get(timeout=queue_timeout) 82 | except Empty: 83 | time.sleep(1) 84 | except Exception as ex: 85 | print('fail to get logs from queue: {}'.format(ex)) 86 | else: 87 | retry = 0 88 | while retry < max_retry: 89 | try: 90 | response = requests.post( 91 | self.url, json=log, timeout=req_timeout) 92 | except Exception as ex: 93 | retry += 1 94 | print('error when posting logs to pavi: {}'.format(ex)) 95 | else: 96 | status_code = response.status_code 97 | if status_code == 200: 98 | break 99 | else: 100 | print('unexpected status code: %d, err msg: %s', 101 | status_code, response.reason) 102 | retry += 1 103 | if retry == max_retry: 104 | print('fail to send logs of iteration %d', log['iter_num']) 105 | 106 | def log(self, phase, num_iters, outputs): 107 | if self.log_queue is not None: 108 | logs = { 109 | 'time': str(datetime.now()), 110 | 'instance_id': self.instance_id, 111 | 'flow_id': phase, 112 | 'iter_num': num_iters, 113 | 'outputs': outputs, 114 | 'msg': '' 115 | } 116 | self.log_queue.put(logs) 117 | 118 | 119 | class PaviLoggerHook(LoggerHook): 120 | 121 | def __init__(self, 122 | url, 123 | username=None, 124 | password=None, 125 | instance_id=None, 126 | interval=10, 127 | reset_meter=True, 128 | ignore_last=True): 129 | self.pavi_logger = PaviLogger(url, username, password, instance_id) 130 | super(PaviLoggerHook, self).__init__(interval, reset_meter, 131 | ignore_last) 132 | 133 | @master_only 134 | def connect(self, 135 | model_name, 136 | work_dir=None, 137 | info=dict(), 138 | timeout=5, 139 | logger=None): 140 | return self.pavi_logger.connect(model_name, work_dir, info, timeout, 141 | logger) 142 | 143 | @master_only 144 | def log(self, runner): 145 | log_outs = runner.log_buffer.output.copy() 146 | log_outs.pop('time', None) 147 | log_outs.pop('data_time', None) 148 | self.pavi_logger.log(runner.mode, runner.num_iters, log_outs) 149 | -------------------------------------------------------------------------------- 
/torchpack/runner/hooks/lr_updater.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from .hook import Hook 4 | 5 | 6 | class LrUpdaterHook(Hook): 7 | 8 | def __init__(self, 9 | by_epoch=True, 10 | warm_up=None, 11 | warm_up_iters=0, 12 | warm_up_ratio=0.1, 13 | **kwargs): 14 | # validate the "warm_up" argument 15 | if warm_up is not None: 16 | if warm_up not in ['constant', 'linear', 'exp']: 17 | raise ValueError( 18 | '"{}" is not a supported type for warming up, valid types' 19 | ' are "constant" and "linear"'.format(warm_up)) 20 | if warm_up is not None: 21 | assert warm_up_iters > 0, \ 22 | '"warm_up_iters" must be a positive integer' 23 | assert 0 < warm_up_ratio <= 1.0, \ 24 | '"warm_up_ratio" must be in range (0,1]' 25 | 26 | self.by_epoch = by_epoch 27 | self.warm_up = warm_up 28 | self.warm_up_iters = warm_up_iters 29 | self.warm_up_ratio = warm_up_ratio 30 | 31 | self.base_lr = [] # initial lr for all param groups 32 | self.regular_lr = [] # expected lr if no warming up is performed 33 | 34 | def _set_lr(self, runner, lr_groups): 35 | for param_group, lr in zip(runner.optimizer.param_groups, lr_groups): 36 | param_group['lr'] = lr 37 | 38 | def get_lr(self, runner, base_lr): 39 | raise NotImplementedError 40 | 41 | def get_regular_lr(self, runner): 42 | return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr] 43 | 44 | def get_warmup_lr(self, cur_iters): 45 | if self.warm_up == 'constant': 46 | warmup_lr = [_lr * self.warm_up_ratio for _lr in self.regular_lr] 47 | elif self.warm_up == 'linear': 48 | k = (1 - cur_iters / self.warm_up_iters) * (1 - self.warm_up_ratio) 49 | warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr] 50 | elif self.warm_up == 'exp': 51 | k = self.warm_up_ratio**(1 - cur_iters / self.warm_up_iters) 52 | warmup_lr = [_lr * k for _lr in self.regular_lr] 53 | return warmup_lr 54 | 55 | def before_run(self, runner): 56 | # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved, 57 | # it will be set according to the optimizer params 58 | for group in runner.optimizer.param_groups: 59 | group.setdefault('initial_lr', group['lr']) 60 | self.base_lr = [ 61 | group['initial_lr'] for group in runner.optimizer.param_groups 62 | ] 63 | 64 | def before_train_epoch(self, runner): 65 | if not self.by_epoch: 66 | return 67 | self.regular_lr = self.get_regular_lr(runner) 68 | self._set_lr(runner, self.regular_lr) 69 | 70 | def before_train_iter(self, runner): 71 | cur_iters = runner.num_iters 72 | if not self.by_epoch: 73 | self.regular_lr = self.get_regular_lr(runner) 74 | if self.warm_up is None or cur_iters >= self.warm_up_iters: 75 | self._set_lr(runner, self.regular_lr) 76 | else: 77 | warmup_lr = self.get_warmup_lr(cur_iters) 78 | self._set_lr(runner, warmup_lr) 79 | elif self.by_epoch: 80 | if self.warm_up is None or cur_iters > self.warm_up_iters: 81 | return 82 | elif cur_iters == self.warm_up_iters: 83 | self._set_lr(runner, self.regular_lr) 84 | else: 85 | warmup_lr = self.get_warmup_lr(cur_iters) 86 | self._set_lr(runner, warmup_lr) 87 | 88 | 89 | class FixedLrUpdaterHook(LrUpdaterHook): 90 | 91 | def __init__(self, **kwargs): 92 | super(FixedLrUpdaterHook, self).__init__(**kwargs) 93 | 94 | def get_lr(self, runner, base_lr): 95 | return base_lr 96 | 97 | 98 | class StepLrUpdaterHook(LrUpdaterHook): 99 | 100 | def __init__(self, step, gamma=0.1, **kwargs): 101 | assert isinstance(step, (list, int)) 102 | if isinstance(step, list): 103 | for s in step: 104 | 
assert isinstance(s, int) and s > 0 105 | elif isinstance(step, int): 106 | assert step > 0 107 | else: 108 | raise TypeError('"step" must be a list or integer') 109 | self.step = step 110 | self.gamma = gamma 111 | super(StepLrUpdaterHook, self).__init__(**kwargs) 112 | 113 | def get_lr(self, runner, base_lr): 114 | progress = runner.epoch if self.by_epoch else runner.num_iters 115 | 116 | if isinstance(self.step, int): 117 | return base_lr * (self.gamma**(progress // self.step)) 118 | 119 | exp = len(self.step) 120 | for i, s in enumerate(self.step): 121 | if progress < s: 122 | exp = i 123 | break 124 | return base_lr * self.gamma**exp 125 | 126 | 127 | class ExpLrUpdaterHook(LrUpdaterHook): 128 | 129 | def __init__(self, gamma, **kwargs): 130 | self.gamma = gamma 131 | super(ExpLrUpdaterHook, self).__init__(**kwargs) 132 | 133 | def get_lr(self, runner, base_lr): 134 | progress = runner.epoch if self.by_epoch else runner.num_iters 135 | return base_lr * self.gamma**progress 136 | 137 | 138 | class PolyLrUpdaterHook(LrUpdaterHook): 139 | 140 | def __init__(self, power=1., **kwargs): 141 | self.power = power 142 | super(PolyLrUpdaterHook, self).__init__(**kwargs) 143 | 144 | def get_lr(self, runner, base_lr): 145 | if self.by_epoch: 146 | progress = runner.epoch 147 | max_progress = runner.max_epoch 148 | else: 149 | progress = runner.num_iters 150 | max_progress = runner.max_iter 151 | return base_lr * (1 - progress / max_progress)**self.power 152 | 153 | 154 | class InvLrUpdaterHook(LrUpdaterHook): 155 | 156 | def __init__(self, gamma, power=1., **kwargs): 157 | self.gamma = gamma 158 | self.power = power 159 | super(InvLrUpdaterHook, self).__init__(**kwargs) 160 | 161 | def get_lr(self, runner, base_lr): 162 | progress = runner.epoch if self.by_epoch else runner.num_iters 163 | return base_lr * (1 + self.gamma * progress)**(-self.power) 164 | -------------------------------------------------------------------------------- /torchpack/runner/runner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | 5 | import torch 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | 8 | from torchpack.io import load_checkpoint, save_checkpoint 9 | from torchpack.runner.hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, 10 | TimerHook, OptimizerStepperHook) 11 | from torchpack.runner.log_buffer import LogBuffer 12 | from torchpack.runner.utils import get_dist_info, get_host_info 13 | 14 | 15 | class Runner(object): 16 | 17 | def __init__(self, 18 | model, 19 | optimizer, 20 | batch_processor, 21 | work_dir=None, 22 | log_level=logging.INFO): 23 | self.model = model 24 | self.optimizer = self.set_optimizer(optimizer) 25 | assert callable(batch_processor) 26 | self.batch_processor = batch_processor 27 | 28 | self.rank, self.world_size = get_dist_info() 29 | 30 | if isinstance(work_dir, str): 31 | self.work_dir = os.path.abspath(work_dir) 32 | if not os.path.isdir(self.work_dir): 33 | os.makedirs(self.work_dir) 34 | elif work_dir is None: 35 | self.work_dir = work_dir 36 | else: 37 | raise TypeError('"work_dir" must be a str or None') 38 | 39 | self.logger = self.init_logger(work_dir, log_level) 40 | 41 | if isinstance(self.model, (DataParallel, DistributedDataParallel)): 42 | self._model_name = self.model.module.__class__.__name__ 43 | else: 44 | self._model_name = self.model.__class__.__name__ 45 | 46 | self.log_buffer = LogBuffer() 47 | self.hooks = [] 48 | self.max_epoch = 0 49 | 
self.max_iter = 0 50 | self.epoch = 0 51 | self.num_iters = 0 52 | self.num_epoch_iters = 0 53 | self.mode = None 54 | 55 | @property 56 | def model_name(self): 57 | return self._model_name 58 | 59 | def set_optimizer(self, optimizer): 60 | if isinstance(optimizer, dict): 61 | optim_cls = getattr(torch.optim, optimizer['algorithm']) 62 | optimizer = optim_cls(self.model.parameters(), **optimizer['args']) 63 | elif not isinstance(optimizer, torch.optim.Optimizer): 64 | raise TypeError( 65 | '"optimizer" must be either an Optimizer object or a dict') 66 | return optimizer 67 | 68 | def init_logger(self, log_dir=None, level=logging.INFO): 69 | logging.basicConfig( 70 | format='%(asctime)s - %(levelname)s - %(message)s', level=level) 71 | logger = logging.getLogger(__name__) 72 | if log_dir: 73 | filename = '{}_{}.log'.format( 74 | time.strftime('%Y%m%d_%H%M%S', time.localtime()), self.rank) 75 | log_file = os.path.join(log_dir, filename) 76 | file_handler = logging.FileHandler(log_file, 'w') 77 | file_handler.setFormatter( 78 | logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) 79 | logger.addHandler(file_handler) 80 | return logger 81 | 82 | def current_lr(self): 83 | return [group['lr'] for group in self.optimizer.param_groups] 84 | 85 | def register_hook(self, hook, hook_type=None, priority=50): 86 | assert isinstance(priority, int) and priority >= 0 and priority <= 100 87 | if hook_type is None: 88 | assert isinstance(hook, Hook) 89 | else: 90 | if isinstance(hook, dict): 91 | hook = hook_type(**hook) 92 | elif not isinstance(hook, hook_type): 93 | raise TypeError('hook must be a {} object or a dict'.format( 94 | hook_type.__name__)) 95 | if hasattr(hook, 'priority'): 96 | raise ValueError('"priority" is a reserved attribute for hooks') 97 | hook.priority = priority 98 | # insert the hook to a sorted list 99 | inserted = False 100 | for i in range(len(self.hooks) - 1, -1, -1): 101 | if priority >= self.hooks[i].priority: 102 | self.hooks.insert(i + 1, hook) 103 | inserted = True 104 | break 105 | if not inserted: 106 | self.hooks.insert(0, hook) 107 | 108 | def call_hook(self, fn_name): 109 | for hook in self.hooks: 110 | getattr(hook, fn_name)(self) 111 | 112 | def load_checkpoint(self, filename, map_location='cpu', strict=False): 113 | self.logger.info('load checkpoint from %s', filename) 114 | return load_checkpoint(self.model, filename, map_location, strict, 115 | self.logger) 116 | 117 | def save_checkpoint(self, out_dir, filename_tmpl='epoch_{}.pth'): 118 | save_checkpoint( 119 | self.model, 120 | self.epoch + 1, 121 | self.num_iters, 122 | out_dir=out_dir, 123 | filename_tmpl=filename_tmpl) 124 | 125 | def train(self, data_loader, **kwargs): 126 | self.model.train() 127 | self.mode = 'train' 128 | self.data_loader = data_loader 129 | self.max_iter = self.max_epoch * len(data_loader) 130 | self.call_hook('before_train_epoch') 131 | for i, data_batch in enumerate(data_loader): 132 | self.num_epoch_iters = i 133 | self.call_hook('before_train_iter') 134 | outputs = self.batch_processor( 135 | self.model, data_batch, train_mode=True, **kwargs) 136 | if not isinstance(outputs, dict): 137 | raise TypeError('batch_processor() must return a dict') 138 | if 'log_vars' in outputs: 139 | self.log_buffer.update(outputs['log_vars'], 140 | outputs['num_samples']) 141 | self.outputs = outputs 142 | self.call_hook('after_train_iter') 143 | self.num_iters += 1 144 | self.call_hook('after_train_epoch') 145 | self.epoch += 1 146 | 147 | def val(self, data_loader, **kwargs): 148 | 
self.model.eval() 149 | self.mode = 'val' 150 | self.data_loader = data_loader 151 | self.call_hook('before_val_epoch') 152 | for i, data_batch in enumerate(data_loader): 153 | self.num_epoch_iters = i 154 | self.call_hook('before_val_iter') 155 | outputs = self.batch_processor( 156 | self.model, data_batch, train_mode=False, **kwargs) 157 | if not isinstance(outputs, dict): 158 | raise TypeError('batch_processor() must return a dict') 159 | if 'log_vars' in outputs: 160 | self.log_buffer.update(outputs['log_vars'], 161 | outputs['num_samples']) 162 | self.outputs = outputs 163 | self.call_hook('after_val_iter') 164 | self.call_hook('after_val_epoch') 165 | 166 | def resume(self, checkpoint, resume_optimizer=True, 167 | map_location='default'): 168 | if map_location == 'default': 169 | device_id = torch.cuda.current_device() 170 | checkpoint = self.load_checkpoint( 171 | checkpoint, 172 | map_location=lambda storage, loc: storage.cuda(device_id)) 173 | else: 174 | checkpoint = self.load_checkpoint( 175 | checkpoint, map_location=map_location) 176 | self.epoch = checkpoint['epoch'] 177 | self.num_iters = checkpoint['num_iters'] 178 | if 'optimizer' in checkpoint and resume_optimizer: 179 | self.optimizer.load_state_dict(checkpoint['optimizer']) 180 | self.logger.info('resumed epoch %d, iter %d', self.epoch, 181 | self.num_iters) 182 | 183 | def run(self, data_loaders, workflow, max_epoch, **kwargs): 184 | assert isinstance(data_loaders, list) 185 | self.max_epoch = max_epoch 186 | work_dir = self.work_dir if self.work_dir is not None else 'NONE' 187 | self.logger.info('Start running, host: %s, work_dir: %s', 188 | get_host_info(), work_dir) 189 | self.logger.info('workflow: %s, max: %d epochs', workflow, max_epoch) 190 | self.call_hook('before_run') 191 | while self.epoch < max_epoch: 192 | for i, flow in enumerate(workflow): 193 | mode, epochs = flow 194 | if isinstance(mode, str): 195 | if not hasattr(self, mode): 196 | raise ValueError( 197 | 'runner has no method named "{}" to run an epoch'. 198 | format(mode)) 199 | epoch_runner = getattr(self, mode) 200 | elif callable(mode): 201 | epoch_runner = mode 202 | else: 203 | raise TypeError('mode in workflow must be a str or ' 204 | 'callable function, not {}'.format( 205 | type(mode))) 206 | for _ in range(epochs): 207 | if mode == 'train' and self.epoch >= max_epoch: 208 | return 209 | epoch_runner(data_loaders[i], **kwargs) 210 | time.sleep(1) # wait for some hooks like loggers to finish 211 | self.call_hook('after_run') 212 | 213 | def register_lr_hooks(self, lr_config): 214 | if isinstance(lr_config, LrUpdaterHook): 215 | self.register_hook(lr_config) 216 | elif isinstance(lr_config, dict): 217 | assert 'policy' in lr_config 218 | from .hooks import lr_updater 219 | hook_name = lr_config['policy'].title() + 'LrUpdaterHook' 220 | if not hasattr(lr_updater, hook_name): 221 | raise ValueError('"{}" does not exist'.format(hook_name)) 222 | hook_cls = getattr(lr_updater, hook_name) 223 | self.register_hook(hook_cls(**lr_config)) 224 | else: 225 | raise TypeError('"lr_config" must be either a LrUpdaterHook object' 226 | ' or dict, not {}'.format(type(lr_config))) 227 | 228 | def register_logger_hooks(self, log_config): 229 | self.register_hook(TimerHook()) 230 | log_interval = log_config['interval'] 231 | from . 
import hooks
232 |         for logger_name, args in log_config['hooks']:
233 |             if isinstance(logger_name, str):
234 |                 logger_cls = getattr(hooks, logger_name)
235 |             elif isinstance(logger_name, type):
236 |                 logger_cls = logger_name
237 |             else:
238 |                 raise TypeError(
239 |                     'logger name must be a string or a hook type, not {}'.format(
240 |                         logger_name))
241 |             kwargs = args.copy()
242 |             if 'interval' not in kwargs:
243 |                 kwargs['interval'] = log_interval
244 |             self.register_hook(logger_cls(**kwargs), priority=60)
245 |
246 |     def register_default_hooks(self,
247 |                                lr_config,
248 |                                grad_clip_config=None,
249 |                                checkpoint_config=None,
250 |                                log_config=None):
251 |         """Register several default hooks"""
252 |         if grad_clip_config is None:
253 |             grad_clip_config = {}
254 |         if checkpoint_config is None:
255 |             checkpoint_config = {}
256 |         self.register_lr_hooks(lr_config)
257 |         self.register_hook(grad_clip_config, OptimizerStepperHook)
258 |         self.register_hook(checkpoint_config, CheckpointSaverHook)
259 |         if log_config is not None:
260 |             self.register_logger_hooks(log_config)
261 |
--------------------------------------------------------------------------------
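Note: `register_default_hooks` above accepts either a dict (resolved by its `policy` name against `lr_updater`) or a ready-made `LrUpdaterHook` instance, so a custom schedule only needs to override `get_lr`. Below is a sketch of that extension point; `CosineLrUpdaterHook` is not shipped with torchpack, and `model`, `batch_processor`, `train_loader`, and `val_loader` are assumed to be defined elsewhere.

```python
# Sketch only: a user-defined LR schedule plugged into register_default_hooks.
# CosineLrUpdaterHook is NOT part of torchpack; it just illustrates the
# get_lr(runner, base_lr) extension point of LrUpdaterHook.
import math

from torchpack import Runner
from torchpack.runner.hooks import LrUpdaterHook


class CosineLrUpdaterHook(LrUpdaterHook):

    def __init__(self, min_lr=0., **kwargs):
        self.min_lr = min_lr
        super(CosineLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        # anneal from base_lr down to min_lr over the whole run
        progress = runner.epoch if self.by_epoch else runner.num_iters
        max_progress = runner.max_epoch if self.by_epoch else runner.max_iter
        factor = 0.5 * (1 + math.cos(math.pi * progress / max_progress))
        return self.min_lr + (base_lr - self.min_lr) * factor


runner = Runner(model,
                dict(algorithm='SGD', args=dict(lr=0.1, momentum=0.9)),
                batch_processor, work_dir='./demo')
runner.register_default_hooks(
    lr_config=CosineLrUpdaterHook(warm_up='linear', warm_up_iters=500),
    checkpoint_config=dict(interval=1),
    log_config=dict(interval=50, hooks=[('TextLoggerHook', {})]))
runner.run([train_loader, val_loader], [('train', 1), ('val', 1)], max_epoch=90)
```

As the comment in `lr_updater.py` warns, when resuming from a checkpoint the `initial_lr` of each param group falls back to the optimizer's current value if it was not saved.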