├── .gitignore ├── README.md ├── __init__.py ├── test_logger.py └── trainer ├── __init__.py ├── plugins ├── __init__.py ├── accuracy.py ├── constant.py ├── logger.py ├── loss.py ├── monitor.py ├── plugin.py ├── progress.py ├── saver.py ├── time.py └── visdom_logger.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | trainer/plugins/visdom_sample_logger.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Add a TF-slim-like framework to PyTorch, to enable rapid research. 2 | ### And bring the TensorBoard-like power of Visdom to PyTorch! 3 | 4 | ## This code is now integrated into [TNT](https://github.com/pytorch/tnt) 5 | TNT is the 'official' framework for PyTorch and is expected to be merged into PyTorch itself. I'll push updates with new features directly to TNT. 6 | 7 | --- 8 | 9 | ## Sample 10 | Visdom, from Facebook, is a powerful and flexible platform for visualizing data. It fulfills much the same role as TensorBoard, and is simple to use. 11 | ![Example](https://user-images.githubusercontent.com/5157485/28799619-2bebef8c-75fe-11e7-898d-202a6c6d3239.png) 12 | 13 | 14 | This repo contains a collection of tools for easily logging to Visdom and for reusing these tools across different projects. It also includes tools for training models. The visualization above is logged and saved periodically (the save frequency is easily adjustable with a parameter) using the code below; the `Trainer` object `train` that the plugins are registered with is set up as shown in the "Running the sample" section below: 15 | ``` 16 | # Plugins produce statistics 17 | progress_plug = ProgressMonitor() 18 | random_plug = RandomMonitor(10000) 19 | image_plug = ConstantMonitor(data.coffee().swapaxes(0,2).swapaxes(1,2), "image") 20 | 21 | # Loggers are a special type of plugin which, surprise, log the stats 22 | logger = Logger(["progress"], [(2, 'iteration')]) 23 | text_logger = VisdomTextLogger(["progress"], [(2, 'iteration')], update_type='APPEND', 24 | env=env, opts=dict(title='Example logging')) 25 | scatter_logger = VisdomPlotLogger('scatter', ["progress.samples_used", "progress.percent"], [(1, 'iteration')], 26 | env=env, opts=dict(title='Percent Done vs Samples Used')) 27 | hist_logger = VisdomLogger('histogram', ["random.data"], [(2, 'iteration')], 28 | env=env, opts=dict(title='Random!', numbins=20)) 29 | image_logger = VisdomLogger('image', ["image.data"], [(2, 'iteration')], env=env) 30 | 31 | 32 | # Create a saver 33 | saver = VisdomSaver(envs=[env]) 34 | 35 | # Register the plugins with the trainer 36 | train.register_plugin(progress_plug) 37 | train.register_plugin(random_plug) 38 | train.register_plugin(image_plug) 39 | 40 | train.register_plugin(logger) 41 | train.register_plugin(text_logger) 42 | train.register_plugin(scatter_logger) 43 | train.register_plugin(hist_logger) 44 | train.register_plugin(image_logger) 45 | 46 | train.register_plugin(saver) 47 | ``` 48 | 49 | 50 | --- 51 | ## References 52 | The trainer and plugin framework is taken, with slight modifications, from the main PyTorch branch. Ideally, the functionality from this repo can be pulled back into PyTorch so that it is more easily available and can be used alongside great existing libraries like 53 | - [TNT](http://github.com/PyTorch/tnt) 54 | - [TorchSample](http://github.com/ncullen93/torchsample) 55 | 56 | Also consider [Inferno](https://github.com/nasimrahaman/inferno), which is new and under heavy active development.
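## Running the sample
The logging snippet above assumes a `Trainer` instance named `train` already exists and that a Visdom server has been started with `python -m visdom.server`. A minimal sketch of that setup, mirroring `test_logger.py` in this repo (the toy `ShallowMLP` model and `SimpleDataset` are defined there):
```
import torch.nn as nn
import torch.optim as optim
from trainer import Trainer

env = 'samples'
model = ShallowMLP((1, 5, 1), force_no_cuda=True)  # toy model from test_logger.py
dataset = SimpleDataset(5, force_no_cuda=True)     # toy dataset from test_logger.py

optimizer = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.L1Loss()
train = Trainer(model, criterion=criterion, optimizer=optimizer, dataset=dataset)

# ... create and register the plugins and loggers shown above ...

train.run(epochs=1)
```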
57 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexsax/pytorch-visdom/6e46601d71ea417ec2a5b39316d2a0ea8ac921cf/__init__.py -------------------------------------------------------------------------------- /test_logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.optim as optim 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from trainer import Trainer 10 | from trainer.plugins.logger import Logger 11 | from trainer.plugins.visdom_logger import * 12 | from trainer.plugins.progress import ProgressMonitor 13 | from trainer.plugins.random import RandomMonitor 14 | from trainer.plugins.constant import ConstantMonitor 15 | from skimage import data 16 | 17 | class ShallowMLP(nn.Module): 18 | def __init__(self, shape, force_no_cuda=False): 19 | super(ShallowMLP, self).__init__() 20 | self.in_shape = shape[0] 21 | self.hidden_shape = shape[1] 22 | self.out_shape = shape[2] 23 | self.fc1 = nn.Linear(self.in_shape, self.hidden_shape) 24 | self.relu = F.relu 25 | self.fc2 = nn.Linear(self.hidden_shape, self.out_shape) 26 | 27 | self.use_cuda = torch.cuda.is_available() and not force_no_cuda 28 | if self.use_cuda: 29 | self = self.cuda() 30 | 31 | def forward(self, x): 32 | x = self.fc1(x) 33 | x = self.relu(x) 34 | y = self.fc2(x) 35 | return y 36 | 37 | class SimpleDataset(object): 38 | def __init__(self, n, force_no_cuda=False): 39 | super(SimpleDataset, self) 40 | self.n = n 41 | self.i = 0 42 | self.use_cuda = torch.cuda.is_available() and not force_no_cuda 43 | 44 | def __iter__(self): 45 | return self 46 | 47 | def __next__(self): 48 | return self.next() 49 | 50 | def next(self): 51 | if self.i >= self.n: 52 | raise StopIteration() 53 | cur = self.i 54 | self.i += 1 55 | if self.use_cuda: 56 | return torch.cuda.FloatTensor([[cur]]), torch.cuda.FloatTensor([[cur]]) 57 | else: 58 | return torch.FloatTensor([[cur]]), torch.FloatTensor([[cur]]) 59 | 60 | def __len__(self): 61 | return self.n 62 | 63 | if __name__=="__main__": 64 | env = 'samples' 65 | force_no_cuda = True 66 | model = ShallowMLP((1,5,1), force_no_cuda=force_no_cuda) 67 | dataset = SimpleDataset(5, force_no_cuda) 68 | 69 | optimizer = optim.SGD(model.parameters(), 0.001) 70 | criterion = nn.L1Loss() 71 | train = Trainer(model, 72 | criterion=criterion, 73 | optimizer=optimizer, 74 | dataset=dataset) 75 | 76 | # Plugins produce statistics 77 | progress_plug = ProgressMonitor() 78 | random_plug = RandomMonitor(10000) 79 | image_plug = ConstantMonitor(data.coffee().swapaxes(0,2).swapaxes(1,2), "image") 80 | 81 | # Loggers are a special type of plugin which, surprise, logs the stats 82 | logger = Logger(["progress"], [(2, 'iteration')]) 83 | text_logger = VisdomTextLogger(["progress"], [(2, 'iteration')], update_type='APPEND', 84 | env=env, opts=dict(title='Example logging')) 85 | scatter_logger = VisdomPlotLogger('scatter', ["progress.samples_used", "progress.percent"], [(1, 'iteration')], 86 | env=env, opts=dict(title='Percent Done vs Samples Used')) 87 | hist_logger = VisdomLogger('histogram', ["random.data"], [(2, 'iteration')], 88 | env=env, opts=dict(title='Random!', numbins=20)) 89 | image_logger = VisdomLogger('image', ["image.data"], [(2, 
'iteration')], env=env) 90 | 91 | 92 | # Create a saver 93 | saver = VisdomSaver(envs=[env]) 94 | 95 | # Register the plugins with the trainer 96 | train.register_plugin(progress_plug) 97 | train.register_plugin(random_plug) 98 | train.register_plugin(image_plug) 99 | 100 | train.register_plugin(logger) 101 | train.register_plugin(text_logger) 102 | train.register_plugin(scatter_logger) 103 | train.register_plugin(hist_logger) 104 | train.register_plugin(image_logger) 105 | 106 | train.register_plugin(saver) 107 | 108 | train.run() 109 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .trainer import Trainer 3 | -------------------------------------------------------------------------------- /trainer/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from .progress import ProgressMonitor 2 | from .accuracy import AccuracyMonitor 3 | from .time import TimeMonitor 4 | from .loss import LossMonitor 5 | from .logger import Logger 6 | -------------------------------------------------------------------------------- /trainer/plugins/accuracy.py: -------------------------------------------------------------------------------- 1 | from .monitor import Monitor 2 | 3 | 4 | class AccuracyMonitor(Monitor): 5 | stat_name = 'accuracy' 6 | 7 | def __init__(self, *args, **kwargs): 8 | kwargs.setdefault('unit', '%') 9 | kwargs.setdefault('precision', 2) 10 | super(AccuracyMonitor, self).__init__(*args, **kwargs) 11 | 12 | def _get_value(self, iteration, input, target, output, loss): 13 | batch_size = input.size(0) 14 | predictions = output.max(1)[1].type_as(target) 15 | correct = predictions.eq(target) 16 | if not hasattr(correct, 'sum'): 17 | correct = correct.cpu() 18 | correct = correct.sum() 19 | return 100. * correct / batch_size 20 | -------------------------------------------------------------------------------- /trainer/plugins/constant.py: -------------------------------------------------------------------------------- 1 | """ Trivial logger which repeatedly logs the same value """ 2 | import numpy as np 3 | from .plugin import Plugin 4 | 5 | 6 | class ConstantMonitor(Plugin): 7 | def __init__(self, data, stat_name='constant'): 8 | super(ConstantMonitor, self).__init__([(1, 'iteration'), (1, 'epoch')]) 9 | self.stat_name = stat_name 10 | self.data = data 11 | 12 | def register(self, trainer): 13 | self.trainer = trainer 14 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 15 | stats['data'] = self.data 16 | 17 | def iteration(self, iteration, input, *args): 18 | pass 19 | 20 | def epoch(self, *args): 21 | pass 22 | 23 | 24 | -------------------------------------------------------------------------------- /trainer/plugins/logger.py: -------------------------------------------------------------------------------- 1 | """ Base logging class""" 2 | from collections import defaultdict 3 | from six import string_types 4 | from .plugin import Plugin 5 | 6 | def is_sequence(arg): 7 | return (not hasattr(arg, "strip") and 8 | (hasattr(arg, "__getitem__") or 9 | hasattr(arg, "__iter__"))) 10 | 11 | class Logger(Plugin): 12 | """Logger plugin for Trainer""" 13 | alignment = 4 14 | separator = '#' * 80 15 | 16 | def __init__(self, fields, interval=[(1, 'iteration'), (1, 'epoch')]): 17 | """ 18 | Args: 19 | fields: The fields to log. May either be the name of some stat 20 | (e.g. 
ProgressMonitor) will have `stat_name='progress'`, 21 | in which case all of the fields under `log_HOOK_fields` 22 | will be logged. Finer-grained control can be specified by 23 | using individual fields such as `progress.percent`. 24 | interval: A List of 2-tuples where each tuple contains 25 | (k, HOOK). 26 | k (int): The logger will be called every 'k' HOOK 27 | HOOK (string): The logger will be called at the given hook 28 | 29 | Examples: 30 | >>> progress_m = ProgressMonitor() 31 | >>> logger = Logger(["progress"], [(2, 'iteration')]) 32 | """ 33 | if not is_sequence(fields): 34 | raise ValueError("'fields' must be a sequence of strings, not {}".format(type(fields))) 35 | 36 | for i, val in enumerate(fields): 37 | if not isinstance(val, string_types): 38 | raise ValueError("Element {} of 'fields' ({}) must be a string.".format( 39 | i, val)) 40 | 41 | super(Logger, self).__init__(interval) 42 | self.field_widths = defaultdict(lambda: defaultdict(int)) 43 | self.fields = list(map(lambda f: f.split('.'), fields)) 44 | 45 | def _join_results(self, results): 46 | joined_out = map(lambda i: (i[0], ' '.join(i[1])), results) 47 | joined_fields = map(lambda i: '{}: {}'.format(i[0], i[1]), joined_out) 48 | return '\t'.join(joined_fields) 49 | 50 | def log(self, msg): 51 | print(msg) 52 | 53 | def register(self, trainer): 54 | self.trainer = trainer 55 | 56 | def gather_stats(self): 57 | result = {} 58 | return result 59 | 60 | def _align_output(self, field_idx, output): 61 | for output_idx, o in enumerate(output): 62 | if len(o) < self.field_widths[field_idx][output_idx]: 63 | num_spaces = self.field_widths[field_idx][output_idx] - len(o) 64 | output[output_idx] += ' ' * num_spaces 65 | else: 66 | self.field_widths[field_idx][output_idx] = len(o) 67 | 68 | def _gather_outputs(self, field, log_fields, stat_parent, stat, require_dict=False): 69 | output = [] 70 | name = '' 71 | if isinstance(stat, dict): 72 | log_fields = stat.get(log_fields, []) 73 | name = stat.get('log_name', '.'.join(field)) 74 | for f in log_fields: 75 | output.append(f.format(**stat)) 76 | elif not require_dict: 77 | name = '.'.join(field) 78 | number_format = stat_parent.get('log_format', '') 79 | unit = stat_parent.get('log_unit', '') 80 | fmt = '{' + number_format + '}' + unit 81 | output.append(fmt.format(stat)) 82 | return name, output 83 | 84 | def _log_all(self, log_fields, prefix=None, suffix=None, require_dict=False): 85 | results = [] 86 | for field_idx, field in enumerate(self.fields): 87 | parent, stat = None, self.trainer.stats 88 | for f in field: 89 | parent, stat = stat, stat[f] 90 | name, output = self._gather_outputs(field, log_fields, 91 | parent, stat, require_dict) 92 | if not output: 93 | continue 94 | self._align_output(field_idx, output) 95 | results.append((name, output)) 96 | if not results: 97 | return 98 | output = self._join_results(results) 99 | if prefix is not None: 100 | self.log(prefix) 101 | self.log(output) 102 | if suffix is not None: 103 | self.log(suffix) 104 | 105 | def iteration(self, *args): 106 | self._log_all('log_iter_fields') 107 | 108 | def epoch(self, epoch_idx): 109 | self._log_all('log_epoch_fields', 110 | prefix=self.separator + '\nEpoch summary:', 111 | suffix=self.separator, 112 | require_dict=True) 113 | -------------------------------------------------------------------------------- /trainer/plugins/loss.py: -------------------------------------------------------------------------------- 1 | from .monitor import Monitor 2 | 3 | 4 | class LossMonitor(Monitor): 5 | 
stat_name = 'loss' 6 | 7 | def _get_value(self, iteration, input, target, output, loss): 8 | return loss[0] 9 | -------------------------------------------------------------------------------- /trainer/plugins/monitor.py: -------------------------------------------------------------------------------- 1 | from .plugin import Plugin 2 | 3 | 4 | class Monitor(Plugin): 5 | 6 | def __init__(self, running_average=True, epoch_average=True, smoothing=0.7, 7 | precision=None, number_format=None, unit=''): 8 | if precision is None: 9 | precision = 4 10 | if number_format is None: 11 | number_format = '.{}f'.format(precision) 12 | number_format = ':' + number_format 13 | super(Monitor, self).__init__([(1, 'iteration'), (1, 'epoch')]) 14 | 15 | self.smoothing = smoothing 16 | self.with_running_average = running_average 17 | self.with_epoch_average = epoch_average 18 | 19 | self.log_format = number_format 20 | self.log_unit = unit 21 | self.log_epoch_fields = None 22 | self.log_iter_fields = ['{last' + number_format + '}' + unit] 23 | if self.with_running_average: 24 | self.log_iter_fields += [' ({running_avg' + number_format + '}' + unit + ')'] 25 | if self.with_epoch_average: 26 | self.log_epoch_fields = ['{epoch_mean' + number_format + '}' + unit] 27 | 28 | def register(self, trainer): 29 | self.trainer = trainer 30 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 31 | stats['log_format'] = self.log_format 32 | stats['log_unit'] = self.log_unit 33 | stats['log_iter_fields'] = self.log_iter_fields 34 | if self.with_epoch_average: 35 | stats['log_epoch_fields'] = self.log_epoch_fields 36 | if self.with_epoch_average: 37 | stats['epoch_stats'] = (0, 0) 38 | 39 | def iteration(self, *args): 40 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 41 | stats['last'] = self._get_value(*args) 42 | 43 | if self.with_epoch_average: 44 | stats['epoch_stats'] = tuple(sum(t) for t in 45 | zip(stats['epoch_stats'], (stats['last'], 1))) 46 | 47 | if self.with_running_average: 48 | previous_avg = stats.get('running_avg', 0) 49 | stats['running_avg'] = previous_avg * self.smoothing + \ 50 | stats['last'] * (1 - self.smoothing) 51 | 52 | def epoch(self, idx): 53 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 54 | if self.with_epoch_average: 55 | epoch_stats = stats['epoch_stats'] 56 | stats['epoch_mean'] = epoch_stats[0] / epoch_stats[1] 57 | stats['epoch_stats'] = (0, 0) 58 | -------------------------------------------------------------------------------- /trainer/plugins/plugin.py: -------------------------------------------------------------------------------- 1 | 2 | class Plugin(object): 3 | 4 | def __init__(self, interval=None): 5 | """ 6 | Args: 7 | interval: A list, e.g. [(10, 'iteration'), (1, 'epoch')] which 8 | specifies that the plugin should be called every 10 9 | iterations and also every epoch. 
10 | """ 11 | if interval is None: 12 | interval = [] 13 | self.trigger_interval = interval 14 | 15 | def register(self, trainer): 16 | raise NotImplementedError 17 | 18 | 19 | class PluginFactory(Plugin): 20 | 21 | def __init__(self, fn, register_fn=None, interval=None): 22 | """Creates a Plugin which applies fn at each hook in 'interval'""" 23 | super(PluginFactory, self).__init__(interval) 24 | for _, name in interval: 25 | setattr(self, name, fn) 26 | 27 | if not register_fn: 28 | def register(self, trainer): 29 | self.trainer = trainer 30 | self.register = register 31 | -------------------------------------------------------------------------------- /trainer/plugins/progress.py: -------------------------------------------------------------------------------- 1 | from .plugin import Plugin 2 | 3 | 4 | class ProgressMonitor(Plugin): 5 | stat_name = 'progress' 6 | 7 | def __init__(self): 8 | super(ProgressMonitor, self).__init__([(1, 'iteration'), (1, 'epoch')]) 9 | 10 | def register(self, trainer): 11 | self.trainer = trainer 12 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 13 | stats['samples_used'] = 0 14 | stats['epoch_size'] = len(trainer.dataset) 15 | stats['log_iter_fields'] = [ 16 | '{samples_used}/{epoch_size}', 17 | '({percent:.2f}%)' 18 | ] 19 | 20 | def iteration(self, iteration, input, *args): 21 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 22 | stats['samples_used'] += 1 23 | stats['percent'] = 100. * stats['samples_used'] / stats['epoch_size'] 24 | 25 | def epoch(self, *args): 26 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 27 | stats['samples_used'] = 0 28 | stats['percent'] = 0 29 | -------------------------------------------------------------------------------- /trainer/plugins/saver.py: -------------------------------------------------------------------------------- 1 | from .plugin import Plugin 2 | import torch 3 | 4 | class Saver(Plugin): 5 | stat_name = 'progress' 6 | 7 | def __init__(self, filename, interval=(1, 'epoch'), should_save=lambda x: True): 8 | ''' 9 | Args: 10 | filename: The filename to save the model under. This can contain format 11 | variables, and filename.format(**save_dict) will be called. 12 | should_save: A function which takes in 'save_dict' and returns either False, 13 | if it should not save, or the either True (in which case it will save 14 | under 'filename'), or alternatively the filename can be returned. 15 | 16 | should_save will have access to the 'trainer', which also gives it access 17 | to all of the saved stats, which can be used to make the decision whether 18 | to save. 
19 | ''' 20 | super(Saver, self).__init__([interval]) 21 | self.filename = filename 22 | self.should_save = should_save 23 | 24 | def register(self, trainer): 25 | self.trainer = trainer 26 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 27 | stats['samples_used'] = 0 28 | stats['epoch_size'] = len(trainer.dataset) 29 | stats['log_iter_fields'] = [ 30 | '{samples_used}/{epoch_size}', 31 | '({percent:.2f}%)' 32 | ] 33 | 34 | def make_param_dict(self): 35 | param_dict = { 36 | 'state_dict': self.trainer.model.state_dict(), 37 | 'optimizer' : self.trainer.optimizer.state_dict(), 38 | 'saved_stats': self.trainer.stats 39 | } 40 | # should_save may return False (skip), True (save to self.filename), or a filename 41 | target = self.should_save(param_dict) 42 | if not target: 43 | return 44 | filename = target if isinstance(target, str) else self.filename 45 | torch.save(param_dict, filename.format(**param_dict)) 46 | 47 | def iteration(self, iteration, input, *args): 48 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 49 | stats['samples_used'] += 1 50 | stats['percent'] = 100. * stats['samples_used'] / stats['epoch_size'] 51 | 52 | def epoch(self, epoch_num): 53 | self.make_param_dict() 54 | stats = self.trainer.stats.setdefault(self.stat_name, {}) 55 | stats['samples_used'] = 0 56 | stats['percent'] = 0 57 | -------------------------------------------------------------------------------- /trainer/plugins/time.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import time 3 | 4 | from .monitor import Monitor 5 | 6 | 7 | class TimeMonitor(Monitor): 8 | stat_name = 'time' 9 | 10 | def __init__(self, *args, **kwargs): 11 | kwargs.setdefault('unit', 'ms') 12 | kwargs.setdefault('precision', 0) 13 | super(TimeMonitor, self).__init__(*args, **kwargs) 14 | self.last_time = None 15 | 16 | def _get_value(self, *args): 17 | if self.last_time: 18 | now = time.time() 19 | duration = now - self.last_time 20 | self.last_time = now 21 | return duration * 1000 22 | else: 23 | self.last_time = time.time() 24 | return 0 25 | -------------------------------------------------------------------------------- /trainer/plugins/visdom_logger.py: -------------------------------------------------------------------------------- 1 | """ Logging to Visdom server """ 2 | from collections import defaultdict 3 | import numpy as np 4 | import visdom 5 | 6 | from .plugin import Plugin 7 | from .logger import Logger 8 | 9 | 10 | class BaseVisdomLogger(Logger): 11 | ''' 12 | The base class for logging output to Visdom. 13 | 14 | ***THIS CLASS IS ABSTRACT AND MUST BE SUBCLASSED*** 15 | 16 | Note that Visdom is built around a client-server architecture, and therefore 17 | a Visdom server must be running at all times. The server can 18 | be started with 19 | $ python -m visdom.server 20 | and you probably want to run it from screen or tmux. 21 | ''' 22 | _viz = visdom.Visdom() 23 | 24 | @property 25 | def viz(self): 26 | return type(self)._viz 27 | 28 | def __init__(self, fields, interval=None, win=None, env=None, opts={}): 29 | super(BaseVisdomLogger, self).__init__(fields, interval) 30 | self.win = win 31 | self.env = env 32 | self.opts = opts 33 | 34 | def log(self, *args, **kwargs): 35 | raise NotImplementedError("log not implemented for BaseVisdomLogger, which is an abstract class.") 36 | 37 | def _viz_prototype(self, vis_fn): 38 | ''' Outputs a function which will log the arguments to Visdom in an appropriate way. 
39 | 40 | Args: 41 | vis_fn: A function, such as self.viz.image 42 | ''' 43 | def _viz_logger(*args, **kwargs): 44 | self.win = vis_fn(*args, 45 | win=self.win, 46 | env=self.env, 47 | opts=self.opts, 48 | **kwargs) 49 | return _viz_logger 50 | 51 | def _log_all(self, log_fields, prefix=None, suffix=None, require_dict=False): 52 | ''' Gathers the stats from self.trainer.stats and passes them into self.log, as a list ''' 53 | results = [] 54 | for field_idx, field in enumerate(self.fields): 55 | parent, stat = None, self.trainer.stats 56 | for f in field: 57 | parent, stat = stat, stat[f] 58 | results.append(stat) 59 | self.log(*results) 60 | 61 | def epoch(self, epoch_idx): 62 | super(BaseVisdomLogger, self).epoch(epoch_idx) 63 | self.viz.save() 64 | 65 | class VisdomSaver(Plugin): 66 | ''' Serialize the state of the Visdom server to disk. 67 | Unless you have a fancy schedule, where different environments are saved with different frequencies, 68 | you probably only need one of these. 69 | ''' 70 | 71 | def __init__(self, envs=None, interval=[(1, 'epoch')]): 72 | super(VisdomSaver, self).__init__(interval) 73 | self.envs = envs 74 | self.viz = visdom.Visdom() 75 | for _, name in interval: 76 | setattr(self, name, self.save) 77 | 78 | def register(self, trainer): 79 | self.trainer = trainer 80 | 81 | def save(self, *args, **kwargs): 82 | self.viz.save(self.envs) 83 | 84 | 85 | class VisdomLogger(BaseVisdomLogger): 86 | ''' 87 | A generic Visdom class that works with the majority of Visdom plot types. 88 | ''' 89 | 90 | def __init__(self, plot_type, fields, interval=None, win=None, env=None, opts={}): 91 | ''' 92 | Args: 93 | plot_type: The name of the plot type, in Visdom 94 | fields: The fields to log. This may be the name of a stat (e.g. a ProgressMonitor 95 | has `stat_name='progress'`), in which case all of the fields under 96 | `log_HOOK_fields` will be logged. Finer-grained control can be specified 97 | by using individual fields such as `progress.percent`. 98 | interval: A List of 2-tuples where each tuple contains (k, HOOK_TIME). 
99 | k (int): The logger will be called every 'k' HOOK_TIMES 100 | HOOK_TIME (string): The logger will be called at the given hook 101 | 102 | Examples: 103 | >>> # Image example 104 | >>> img_to_use = skimage.data.coffee().swapaxes(0,2).swapaxes(1,2) 105 | >>> image_plug = ConstantMonitor(img_to_use, "image") 106 | >>> image_logger = VisdomLogger('image', ["image.data"], [(2, 'iteration')]) 107 | 108 | >>> # Histogram example 109 | >>> hist_plug = ConstantMonitor(np.random.rand(10000), "random") 110 | >>> hist_logger = VisdomLogger('histogram', ["random.data"], [(2, 'iteration')], opts=dict(title='Random!', numbins=20)) 111 | ''' 112 | super(VisdomLogger, self).__init__(fields, interval, win, env, opts) 113 | self.plot_type = plot_type 114 | self.chart = getattr(self.viz, plot_type) 115 | self.viz_logger = self._viz_prototype(self.chart) 116 | 117 | def log(self, *args, **kwargs): 118 | self.viz_logger(*args, **kwargs) 119 | 120 | 121 | class VisdomPlotLogger(BaseVisdomLogger): 122 | 123 | def __init__(self, plot_type, fields, interval=None, win=None, env=None, opts={}): 124 | ''' 125 | Args: 126 | plot_type: {scatter, line} 127 | 128 | Examples: 129 | >>> train = Trainer(model, criterion, optimizer, dataset) 130 | >>> progress_m = ProgressMonitor() 131 | >>> scatter_logger = VisdomPlotLogger('scatter', ["progress.samples_used", "progress.percent"], [(2, 'iteration')]) 132 | >>> train.register_plugin(progress_m) 133 | >>> train.register_plugin(scatter_logger) 134 | ''' 135 | super(VisdomPlotLogger, self).__init__(fields, interval, win, env, opts) 136 | valid_plot_types = { 137 | "scatter": self.viz.scatter, 138 | "line": self.viz.line } 139 | 140 | # Set chart type 141 | if plot_type in valid_plot_types: 142 | self.chart = valid_plot_types[plot_type] 143 | else: 144 | raise ValueError("plot_type \'{}\' not found. Must be one of {}".format( 145 | plot_type, valid_plot_types.keys())) 146 | 147 | 148 | 149 | def log(self, *args, **kwargs): 150 | if self.win is not None: 151 | if len(args) != 2: 152 | raise ValueError("When logging to {}, must pass in x and y values (and optionally z).".format( 153 | type(self))) 154 | x, y = args 155 | self.viz.updateTrace( 156 | X=np.array([x]), 157 | Y=np.array([y]), 158 | win=self.win, 159 | env=self.env, 160 | opts=self.opts) 161 | else: 162 | self.win = self.chart( 163 | X=np.array([args]), 164 | win=self.win, 165 | env=self.env, 166 | opts=self.opts) 167 | 168 | 169 | class VisdomTextLogger(BaseVisdomLogger): 170 | ''' 171 | Creates a text window in Visdom and logs output to it. 172 | The output can be formatted with fancy HTML, and new output can 173 | be set to 'append' or 'replace' mode. 174 | ''' 175 | valid_update_types = ['REPLACE', 'APPEND'] 176 | 177 | def __init__(self, fields, interval=None, win=None, env=None, opts={}, update_type=valid_update_types[0]): 178 | ''' 179 | Args: 180 | fields: The fields to log. This may be the name of a stat (e.g. a ProgressMonitor 181 | has `stat_name='progress'`), in which case all of the fields under 182 | `log_HOOK_fields` will be logged. Finer-grained control can be specified 183 | by using individual fields such as `progress.percent`. 184 | interval: A List of 2-tuples where each tuple contains (k, HOOK_TIME). 185 | k (int): The logger will be called every 'k' HOOK_TIMES 186 | HOOK_TIME (string): The logger will be called at the given hook 187 | update_type: One of {'REPLACE', 'APPEND'}. Default 'REPLACE'. 
188 | 189 | Examples: 190 | >>> progress_m = ProgressMonitor() 191 | >>> logger = VisdomTextLogger(["progress"], [(2, 'iteration')]) 192 | >>> train.register_plugin(progress_m) 193 | >>> train.register_plugin(logger) 194 | ''' 195 | super(VisdomTextLogger, self).__init__(fields, interval, win, env, opts) 196 | self.text = '' 197 | 198 | if update_type not in self.valid_update_types: 199 | raise ValueError("update type '{}' not found. Must be one of {}".format(update_type, self.valid_update_types)) 200 | self.update_type = update_type 201 | 202 | self.viz_logger = self._viz_prototype(self.viz.text) 203 | 204 | 205 | def log(self, msg, *args, **kwargs): 206 | text = msg 207 | if self.update_type == 'APPEND' and self.text: 208 | self.text = "<br>
".join([self.text, text]) 209 | else: 210 | self.text = text 211 | self.viz_logger([self.text]) 212 | 213 | def _log_all(self, log_fields, prefix=None, suffix=None, require_dict=False): 214 | results = [] 215 | for field_idx, field in enumerate(self.fields): 216 | parent, stat = None, self.trainer.stats 217 | for f in field: 218 | parent, stat = stat, stat[f] 219 | name, output = self._gather_outputs(field, log_fields, 220 | parent, stat, require_dict) 221 | if not output: 222 | continue 223 | self._align_output(field_idx, output) 224 | results.append((name, output)) 225 | if not results: 226 | return 227 | output = self._join_results(results) 228 | if prefix is not None: 229 | self.log(prefix) 230 | self.log(output) 231 | if suffix is not None: 232 | self.log(suffix) 233 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | from torch.autograd import Variable 3 | 4 | 5 | class Trainer(object): 6 | ''' 7 | The Trainer is a lightweight wrapper to encapsulate a typical training 8 | loop. 9 | It contains 4 callbacks in (accesssible from trainer.get_plugin_queues()). 10 | Plugins (and subclasses thereof, or anything that provides the required 11 | methods) can be registered to trigger at these predefined times. 12 | ''' 13 | 14 | def __init__(self, model=None, criterion=None, optimizer=None, dataset=None): 15 | self.model = model 16 | self.criterion = criterion 17 | self.optimizer = optimizer 18 | self.dataset = dataset 19 | self.iterations = 0 20 | self.stats = {} 21 | self.plugin_queues = { 22 | 'iteration': [], 23 | 'epoch': [], 24 | 'batch': [], 25 | 'update': [], 26 | } 27 | 28 | def get_plugin_queues(self): 29 | return self.plugin_queues 30 | 31 | def register_plugin(self, plugin): 32 | plugin.register(self) 33 | 34 | intervals = plugin.trigger_interval 35 | if not isinstance(intervals, list): 36 | intervals = [intervals] 37 | for duration, unit in intervals: 38 | queue = self.plugin_queues[unit] 39 | queue.append((duration, len(queue), plugin)) 40 | 41 | def call_plugins(self, queue_name, time, *args): 42 | args = (time,) + args 43 | queue = self.plugin_queues[queue_name] 44 | if len(queue) == 0: 45 | return 46 | while queue[0][0] <= time: 47 | plugin = queue[0][2] 48 | getattr(plugin, queue_name)(*args) 49 | for trigger in plugin.trigger_interval: 50 | if trigger[1] == queue_name: 51 | interval = trigger[0] 52 | new_item = (time + interval, queue[0][1], plugin) 53 | heapq.heappushpop(queue, new_item) 54 | 55 | def run(self, epochs=1): 56 | for q in self.plugin_queues.values(): 57 | heapq.heapify(q) 58 | 59 | for i in range(1, epochs + 1): 60 | self.train() 61 | self.call_plugins('epoch', i) 62 | 63 | def train(self): 64 | for i, data in enumerate(self.dataset, self.iterations + 1): 65 | batch_input, batch_target = data 66 | self.call_plugins('batch', i, batch_input, batch_target) 67 | input_var = Variable(batch_input) 68 | target_var = Variable(batch_target) 69 | 70 | plugin_data = [None, None] 71 | 72 | def closure(): 73 | batch_output = self.model(input_var) 74 | loss = self.criterion(batch_output, target_var) 75 | loss.backward() 76 | if plugin_data[0] is None: 77 | plugin_data[0] = batch_output.data 78 | plugin_data[1] = loss.data 79 | return loss 80 | 81 | self.optimizer.zero_grad() 82 | self.optimizer.step(closure) 83 | self.call_plugins('iteration', i, batch_input, batch_target, 84 | *plugin_data) 85 | self.call_plugins('update', 
i, self.model) 86 | 87 | self.iterations += i 88 | --------------------------------------------------------------------------------
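A note on extending the `Trainer` above: plugins register themselves with the `Trainer` and are called back at the hooks named in their `trigger_interval`. A minimal custom plugin, sketched against the `Plugin` and `Trainer` APIs in this repo (the `PrintLoss` name and its printed format are illustrative, not part of the codebase):
```
# Minimal custom plugin sketch: prints the loss captured by Trainer.train()
# every 10 iterations.
from trainer.plugins.plugin import Plugin

class PrintLoss(Plugin):
    def __init__(self):
        # Trigger this plugin every 10 'iteration' hooks
        super(PrintLoss, self).__init__([(10, 'iteration')])

    def register(self, trainer):
        self.trainer = trainer

    def iteration(self, iteration, batch_input, batch_target, batch_output, loss):
        # Arguments match Trainer.call_plugins('iteration', i, input, target, output, loss)
        print('iteration {}: loss {:.4f}'.format(iteration, loss[0]))

# train.register_plugin(PrintLoss())
```
Because `call_plugins` keeps each queue as a heap keyed on the next trigger time, the `(10, 'iteration')` interval means the plugin fires at iterations 10, 20, 30, and so on.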