├── .gitignore
├── README.md
├── aggmo.py
├── src
│   ├── aggmo.py
│   ├── config.py
│   ├── configs
│   │   ├── ae.json
│   │   ├── cifar-10.json
│   │   ├── cifar-100.json
│   │   └── templates
│   │       └── optim
│   │           ├── adam.json
│   │           ├── aggmo.json
│   │           ├── exp_aggmo.json
│   │           ├── nesterov.json
│   │           └── sgd.json
│   ├── engine.py
│   ├── logger.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── ae.py
│   │   ├── base.py
│   │   ├── nnet.py
│   │   └── resnet.py
│   └── utils.py
└── tensorflow
    ├── AggMo-Test.ipynb
    └── aggmo.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | **/*.pyc
3 |
4 | **/.ipynb_checkpoints
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Aggregated Momentum
2 |
3 | This repository contains code to reproduce the experiments from ["Aggregated Momentum: Stability Through Passive Damping"](https://arxiv.org/abs/1804.00325).
4 |
5 | Both PyTorch and TensorFlow implementations of the AggMo optimizer are included.
6 |
7 | ## AggMo Optimizer
8 |
9 | ### PyTorch
10 |
11 | The `aggmo.py` file provides a PyTorch implementation of AggMo. The optimizer can be constructed as follows:
12 |
13 | ```python
14 | optimizer = aggmo.AggMo(model.parameters(), lr, betas=[0, 0.9, 0.99])
15 | ```
16 |
17 | The AggMo class also has an "exponential form" constructor. In this case the damping vector is specified by two hyperparameters: `K`, the number of beta values, and `a`, the exponential scale factor. For i = 0...K-1, each beta_i = 1 - a^i.
18 | The following is equivalent to using the beta values [0, 0.9, 0.99]:
19 |
20 | ```python
21 | optimizer = aggmo.AggMo.from_exp_form(model.parameters(), lr, a=0.1, k=3)
22 | ```
23 |
24 | ### TensorFlow
25 |
26 | There is also a TensorFlow implementation within the `tensorflow` folder. **This version has not been carefully tested**.
27 |
28 | ```python
29 | optimizer = aggmo.AggMo(lr, betas=[0, 0.9, 0.99])
30 | ```
31 |
32 | Or using the exponential form:
33 |
34 | ```python
35 | optimizer = aggmo.AggMo.from_exp_form(lr, a=0.1, k=3)
36 | ```
37 |
38 | ## Running Experiments
39 |
40 | Code to run experiments can be found in the `src` directory. Each task and optimizer has its own config file, which can easily be overridden from the command line.
41 |
42 | The first argument points to the task configuration. The optimizer is specified with `--optim <optimizer>`, where `<optimizer>` names one of the templates in `configs/templates/optim` (e.g. `aggmo`). Additional config overrides can be given after `-o`, e.g. `-o optim.lr_schedule.lr_decay=0.5`.
43 |
44 | _The optimizer configs do not provide optimal hyperparameters for every task._
45 |
46 |
47 | ### Autoencoders
48 |
49 | From the `src` directory:
50 |
51 | ```
52 | python main.py configs/ae.json --optim aggmo
53 | ```
54 |
55 | ### Classification
56 |
57 | From the `src` directory:
58 |
59 | ```
60 | python main.py configs/cifar-10.json --optim aggmo
61 | ```
62 |
63 | Or for CIFAR-100:
64 | ```
65 | python main.py configs/cifar-100.json --optim aggmo
66 | ```
67 |
68 | ### LSTMs
69 |
70 | The LSTM code is not directly included here. We made direct use of the [official code](https://github.com/salesforce/awd-lstm-lm) from ["Regularizing and Optimizing LSTM Language Models"](https://arxiv.org/abs/1708.02182). You can run these experiments by using the AggMo optimizer within this repository. The model hyperparameters used are detailed in the appendix.
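### Example: end-to-end training step (PyTorch)

The sketch below is not part of this repository; it only illustrates how the PyTorch `AggMo` optimizer slots into an ordinary training loop. The model, data, and hyperparameter values are placeholders, and it assumes a PyTorch version contemporary with this code (the optimizer relies on the older in-place `add_`/`sub_` call signatures).

```python
import torch
import torch.nn as nn

import aggmo  # the aggmo.py at the repository root

# Placeholder model and data -- substitute your own.
model = nn.Linear(10, 1)
optimizer = aggmo.AggMo(model.parameters(), lr=0.1, betas=[0.0, 0.9, 0.99])

inputs, targets = torch.randn(64, 10), torch.randn(64, 1)
for step in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()  # applies the average of the K momentum updates
```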
71 | -------------------------------------------------------------------------------- /aggmo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | 5 | class AggMo(Optimizer): 6 | r"""Implements Aggregated Momentum Gradient Descent 7 | """ 8 | 9 | def __init__(self, params, lr=required, betas=[0.0, 0.9, 0.99], weight_decay=0): 10 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 11 | super(AggMo, self).__init__(params, defaults) 12 | 13 | @classmethod 14 | def from_exp_form(cls, params, lr=required, a=0.1, k=3, weight_decay=0): 15 | betas = [1- a**i for i in range(k)] 16 | return cls(params, lr, betas, weight_decay) 17 | 18 | def __setstate__(self, state): 19 | super(AggMo, self).__setstate__(state) 20 | 21 | def step(self, closure=None): 22 | """Performs a single optimization step. 23 | Arguments: 24 | closure (callable, optional): A closure that reevaluates the model 25 | and returns the loss. 26 | """ 27 | loss = None 28 | if closure is not None: 29 | loss = closure() 30 | 31 | for group in self.param_groups: 32 | weight_decay = group['weight_decay'] 33 | betas = group['betas'] 34 | total_mom = float(len(betas)) 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | d_p = p.grad.data 40 | if weight_decay != 0: 41 | d_p.add_(weight_decay, p.data) 42 | param_state = self.state[p] 43 | if 'momentum_buffer' not in param_state: 44 | param_state['momentum_buffer'] = {} 45 | for beta in betas: 46 | param_state['momentum_buffer'][beta] = torch.zeros_like(p.data) 47 | for beta in betas: 48 | buf = param_state['momentum_buffer'][beta] 49 | # import pdb; pdb.set_trace() 50 | buf.mul_(beta).add_(d_p) 51 | p.data.sub_(group['lr'] / total_mom , buf) 52 | return loss 53 | 54 | def zero_momentum_buffers(self): 55 | for group in self.param_groups: 56 | betas = group['betas'] 57 | for p in group['params']: 58 | param_state = self.state[p] 59 | param_state['momentum_buffer'] = {} 60 | for beta in betas: 61 | param_state['momentum_buffer'][beta] = torch.zeros_like(p.data) 62 | 63 | def update_hparam(self, name, value): 64 | for param_group in self.param_groups: 65 | param_group[name] = value -------------------------------------------------------------------------------- /src/aggmo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | 5 | class AggMo(Optimizer): 6 | r"""Implements Aggregated Momentum Gradient Descent 7 | """ 8 | 9 | def __init__(self, params, lr=required, betas=[0.0, 0.9, 0.99], weight_decay=0): 10 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 11 | super(AggMo, self).__init__(params, defaults) 12 | 13 | @classmethod 14 | def from_exp_form(cls, params, lr=required, a=0.1, k=3, weight_decay=0): 15 | betas = [1- a**i for i in range(k)] 16 | return cls(params, lr, betas, weight_decay) 17 | 18 | def __setstate__(self, state): 19 | super(AggMo, self).__setstate__(state) 20 | 21 | def step(self, closure=None): 22 | """Performs a single optimization step. 23 | Arguments: 24 | closure (callable, optional): A closure that reevaluates the model 25 | and returns the loss. 
26 | """ 27 | loss = None 28 | if closure is not None: 29 | loss = closure() 30 | 31 | for group in self.param_groups: 32 | weight_decay = group['weight_decay'] 33 | betas = group['betas'] 34 | total_mom = float(len(betas)) 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | d_p = p.grad.data 40 | if weight_decay != 0: 41 | d_p.add_(weight_decay, p.data) 42 | param_state = self.state[p] 43 | if 'momentum_buffer' not in param_state: 44 | param_state['momentum_buffer'] = {} 45 | for beta in betas: 46 | param_state['momentum_buffer'][beta] = torch.zeros_like(p.data) 47 | for beta in betas: 48 | buf = param_state['momentum_buffer'][beta] 49 | # import pdb; pdb.set_trace() 50 | buf.mul_(beta).add_(d_p) 51 | p.data.sub_(group['lr'] / total_mom , buf) 52 | return loss 53 | 54 | def zero_momentum_buffers(self): 55 | for group in self.param_groups: 56 | betas = group['betas'] 57 | for p in group['params']: 58 | param_state = self.state[p] 59 | param_state['momentum_buffer'] = {} 60 | for beta in betas: 61 | param_state['momentum_buffer'][beta] = torch.zeros_like(p.data) 62 | 63 | def update_hparam(self, name, value): 64 | for param_group in self.param_groups: 65 | param_group[name] = value -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import collections 5 | 6 | from jinja2 import Environment, FileSystemLoader, StrictUndefined 7 | 8 | def update(d, u): 9 | for k, v in u.items(): 10 | if isinstance(v, collections.Mapping): 11 | d[k] = update(d.get(k, {}), v) 12 | else: 13 | if '+' in v: 14 | v = [float(x) for x in v.split('+')] 15 | try: 16 | d[k] = type(d[k])(v) 17 | except (TypeError, ValueError) as e: 18 | raise TypeError(e) # types not compatible 19 | except KeyError as e: 20 | d[k] = v # No matching key in dict 21 | return d 22 | 23 | 24 | class ConfigParse(argparse.Action): 25 | def __call__(self, parser, namespace, values, option_string=None): 26 | options_dict = {} 27 | for overrides in values.split(','): 28 | k, v = overrides.split('=') 29 | k_parts = k.split('.') 30 | dic = options_dict 31 | for key in k_parts[:-1]: 32 | dic = dic.setdefault(key, {}) 33 | dic[k_parts[-1]] = v 34 | setattr(namespace, self.dest, options_dict) 35 | 36 | 37 | def get_config_overrides(): 38 | parser = argparse.ArgumentParser(description='Experiments for aggregated momentum') 39 | parser.add_argument('config', help='Base config file') 40 | parser.add_argument('-o', action=ConfigParse, help='Config option overrides. Comma separated, e.g. 
optim.lr_init=1.0,optim.lr_decay=0.1') 41 | args, template_args = parser.parse_known_args() 42 | template_dict = dict(zip(template_args[:-1:2], template_args[1::2])) 43 | template_dict = { k.lstrip('-'): v for k,v in template_dict.items() } 44 | return args,template_dict 45 | 46 | def process_config(verbose=True): 47 | args, template_args = get_config_overrides() 48 | 49 | with open(args.config, 'r') as f: 50 | template = f.read() 51 | 52 | env = Environment(loader=FileSystemLoader('configs/templates/'), 53 | undefined=StrictUndefined) 54 | 55 | config = json.loads(env.from_string(template).render(**template_args)) 56 | 57 | if args.o is not None: 58 | print(args.o) 59 | config = update(config, args.o) 60 | 61 | if verbose: 62 | import pprint 63 | pp = pprint.PrettyPrinter() 64 | print('-------- Config --------') 65 | pp.pprint(config) 66 | print('------------------------') 67 | return config 68 | -------------------------------------------------------------------------------- /src/configs/ae.json: -------------------------------------------------------------------------------- 1 | { 2 | "optim": { 3 | "optimizer": { {% include 'optim/%s.json' % optim %} }, 4 | "lr_schedule": { 5 | "name": "step", 6 | "lr_decay": 0.1, 7 | "milestones": [200, 400, 800], 8 | "last_epoch": -1 9 | }, 10 | "epochs": 1000, 11 | "batch_size": 200, 12 | "wdecay": 0.0, 13 | "criterion": { 14 | "tag": "mse", 15 | "minmax": "min" 16 | }, 17 | "finetune": { 18 | "epochs": 0, 19 | "final_mom": 0.9, 20 | "warm": false 21 | }, 22 | "patience": 250 23 | }, 24 | "model": { 25 | "name": "ce_fc_ae", 26 | "layers": [1000, 500, 250, 30], 27 | "activation": "relu" 28 | }, 29 | "data": { 30 | "name": "mnist", 31 | "root": "data", 32 | "transform": false, 33 | "validation": true, 34 | "train_size": 0.9, 35 | "input_dim": 784, 36 | "num_workers": 0, 37 | "class_count": 10 38 | }, 39 | "cuda": true, 40 | "task": "encoder", 41 | "output_root": "out/mnist_ae", 42 | "exp_name": "ae" 43 | } -------------------------------------------------------------------------------- /src/configs/cifar-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "optim": { 3 | "optimizer": { {% include 'optim/%s.json' % optim %} }, 4 | "lr_schedule": { 5 | "name": "step", 6 | "lr_decay": 0.1, 7 | "milestones": [150, 250], 8 | "last_epoch": -1 9 | }, 10 | "epochs": 400, 11 | "batch_size": 128, 12 | "wdecay": 5e-4, 13 | "criterion": { 14 | "tag": "acc", 15 | "minmax": "max" 16 | }, 17 | "finetune": { 18 | "epochs": 0, 19 | "final_mom": 0.9, 20 | "warm": false 21 | }, 22 | "patience": 250 23 | }, 24 | "model": { 25 | "name": "resnet34" 26 | }, 27 | "data": { 28 | "name": "cifar10", 29 | "root": "data", 30 | "transform": { 31 | "name": "cifar", 32 | "norm_mean": [0.4914, 0.4822, 0.4465], 33 | "norm_std": [0.2023, 0.1994, 0.2010] 34 | }, 35 | "validation": true, 36 | "train_size": 0.9, 37 | "input_dim": 3072, 38 | "num_workers": 0, 39 | "class_count": 10 40 | }, 41 | "cuda": true, 42 | "task": "classify", 43 | "output_root": "out/cifar-10", 44 | "exp_name": "classify" 45 | } -------------------------------------------------------------------------------- /src/configs/cifar-100.json: -------------------------------------------------------------------------------- 1 | { 2 | "optim": { 3 | "optimizer": { {% include 'optim/%s.json' % optim %} }, 4 | "lr_schedule": { 5 | "name": "step", 6 | "lr_decay": 0.1, 7 | "milestones": [150, 250], 8 | "last_epoch": -1 9 | }, 10 | "epochs": 400, 11 | "batch_size": 128, 12 | 
"wdecay": 5e-4, 13 | "criterion": { 14 | "tag": "acc", 15 | "minmax": "max" 16 | }, 17 | "finetune": { 18 | "epochs": 0, 19 | "final_mom": 0.9, 20 | "warm": false 21 | }, 22 | "patience": 250 23 | }, 24 | "model": { 25 | "name": "resnet34" 26 | }, 27 | "data": { 28 | "name": "cifar100", 29 | "root": "data", 30 | "transform": { 31 | "name": "cifar", 32 | "norm_mean": [0.5071, 0.4867, 0.4408], 33 | "norm_std": [0.2675, 0.2565, 0.2761] 34 | }, 35 | "validation": true, 36 | "train_size": 0.9, 37 | "input_dim": 3072, 38 | "num_workers": 0, 39 | "class_count": 100 40 | }, 41 | "cuda": true, 42 | "task": "classify", 43 | "output_root": "out/cifar_100", 44 | "exp_name": "classify" 45 | } -------------------------------------------------------------------------------- /src/configs/templates/optim/adam.json: -------------------------------------------------------------------------------- 1 | "name": "adam", 2 | "lr": 0.001, 3 | "beta1": 0.9, 4 | "beta2": 0.999 -------------------------------------------------------------------------------- /src/configs/templates/optim/aggmo.json: -------------------------------------------------------------------------------- 1 | "name": "aggmo", 2 | "lr": 0.01, 3 | "betas": [0.0, 0.9, 0.99] 4 | -------------------------------------------------------------------------------- /src/configs/templates/optim/exp_aggmo.json: -------------------------------------------------------------------------------- 1 | "name": "aggmo_exp", 2 | "lr": 0.01, 3 | "a": 0.1, 4 | "K": 3 5 | -------------------------------------------------------------------------------- /src/configs/templates/optim/nesterov.json: -------------------------------------------------------------------------------- 1 | "name": "nesterov", 2 | "lr": 0.01, 3 | "momentum": 0.9 4 | -------------------------------------------------------------------------------- /src/configs/templates/optim/sgd.json: -------------------------------------------------------------------------------- 1 | "name": "sgd", 2 | "lr": 0.01, 3 | "momentum": 0.9 4 | -------------------------------------------------------------------------------- /src/engine.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Based on code from https://github.com/pytorch/tnt/blob/master/torchnet/engine/engine.py 3 | 4 | Edited by Jake Snell 5 | 6 | (Minor tweaks by James Lucas) 7 | 8 | ''' 9 | class Engine(object): 10 | def __init__(self): 11 | self.hooks = {} 12 | 13 | def hook(self, name, state): 14 | if name in self.hooks: 15 | self.hooks[name](state) 16 | 17 | def train(self, model, iterator, maxepoch, optimizer): 18 | state = { 19 | 'model': model, 20 | 'iterator': iterator, 21 | 'maxepoch': maxepoch, 22 | 'optimizer': optimizer, 23 | 'epoch': 0, 24 | 't': 0, 25 | 'train': True, 26 | 'stop': False 27 | } 28 | model.train() 29 | self.hook('on_start', state) 30 | while state['epoch'] < state['maxepoch'] and not state['stop']: 31 | self.hook('on_start_epoch', state) 32 | for sample in state['iterator']: 33 | state['sample'] = sample 34 | self.hook('on_sample', state) 35 | 36 | def closure(): 37 | loss, output = state['model'].loss(state['sample']) 38 | state['output'] = output 39 | state['loss'] = loss 40 | loss.backward() 41 | self.hook('on_forward', state) 42 | # to free memory in save_for_backward 43 | # state['output'] = None 44 | # state['loss'] = None 45 | return loss 46 | 47 | state['optimizer'].zero_grad() 48 | state['optimizer'].step(closure) 49 | self.hook('on_update', state) 50 | state['t'] += 1 51 | 
state['epoch'] += 1 52 | self.hook('on_end_epoch', state) 53 | self.hook('on_end', state) 54 | return state 55 | 56 | def test(self, model, iterator): 57 | state = { 58 | 'model': model, 59 | 'iterator': iterator, 60 | 't': 0, 61 | 'train': False, 62 | } 63 | 64 | model.eval() 65 | self.hook('on_start', state) 66 | for sample in state['iterator']: 67 | state['sample'] = sample 68 | self.hook('on_sample', state) 69 | 70 | def closure(): 71 | loss, output = state['model'].loss(state['sample'], test=True) 72 | state['output'] = output 73 | state['loss'] = loss 74 | self.hook('on_forward', state) 75 | # to free memory in save_for_backward 76 | # state['output'] = None 77 | # state['loss'] = None 78 | 79 | closure() 80 | state['t'] += 1 81 | self.hook('on_end', state) 82 | # Put back into training mode! 83 | model.train() 84 | return state 85 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import os, errno 2 | import sys 3 | import json 4 | 5 | 6 | class Logger(object): 7 | ''' 8 | Base Logger object 9 | 10 | Initializes the log directory and creates log files given by name in arguments. 11 | Can be used to append future log values to each file. 12 | ''' 13 | 14 | def __init__(self, log_dir, *args): 15 | self.log_dir = log_dir 16 | 17 | try: 18 | os.makedirs(log_dir) 19 | except OSError as e: 20 | if e.errno != errno.EEXIST: 21 | raise 22 | 23 | with open(os.path.join(self.log_dir, 'cmd.txt'), 'w') as f: 24 | f.write(" ".join(sys.argv)) 25 | 26 | self.log_names = [a for a in args] 27 | for arg in self.log_names: 28 | setattr(self, 'log_{}'.format(arg), lambda epoch,value,name=arg: self.log(name, epoch, value)) 29 | self.init_logfile(arg) 30 | 31 | def log_config(self, config): 32 | with open(os.path.join(self.log_dir, 'config.json'), 'w') as f: 33 | json.dump(config, f) 34 | 35 | def init_logfile(self, name): 36 | fname = self.get_log_fname(name) 37 | 38 | with open(fname, 'w') as log_file: 39 | log_file.write("epoch,{}\n".format(name)) 40 | 41 | def get_log_fname(self, name): 42 | return os.path.join(self.log_dir, '{}.log'.format(name)) 43 | 44 | def log(self, name, epoch, value): 45 | if name not in self.log_names: 46 | self.init_logfile(name) 47 | self.log_names.append(name) 48 | fname = self.get_log_fname(name) 49 | 50 | with open(fname, 'a') as log_file: 51 | log_file.write("{},{}\n".format(epoch, value)) 52 | 53 | def log_test_value(self, name, value): 54 | test_name = 'test_' + name 55 | self.init_logfile(test_name) 56 | self.log(test_name, 0, value) 57 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from logger import Logger 4 | from config import process_config 5 | from engine import Engine 6 | from models import get_model 7 | 8 | from utils import * 9 | 10 | from functools import partial 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | import datetime 16 | import os, errno 17 | 18 | def get_experiment_name(config): 19 | now = datetime.datetime.now() 20 | 21 | base_exp_name = config['exp_name'] 22 | task_name = config['task'] 23 | data_name = config['data']['name'] 24 | optim_name = config['optim']['optimizer']['name'] 25 | 26 | exp_name = "{}_{}_{}_{}_{}".format(base_exp_name, task_name, data_name, optim_name, now.strftime("%Y%m%d_%H-%M-%S-%f")) 27 | return exp_name 28 | 29 | def 
train(model, loaders, config): 30 | exp_dir = os.path.join(config['output_root'], get_experiment_name(config)) 31 | log_dir = os.path.join(exp_dir, 'logs') 32 | model_dir = os.path.join(exp_dir, 'checkpoints') 33 | best_model_path = os.path.join(model_dir, 'best_model.pt') 34 | img_dir = os.path.join(exp_dir, 'imgs') 35 | 36 | model.cuda() 37 | 38 | optimizer = get_optimizer(config, model.parameters()) 39 | scheduler = get_scheduler(config, optimizer) 40 | 41 | logger = Logger(log_dir) 42 | logger.log_config(config) 43 | 44 | engine = Engine() 45 | 46 | if config['optim']['criterion']['minmax'] == 'min': 47 | best_val = np.inf 48 | else: 49 | best_val = -np.inf 50 | 51 | def log_meters(prefix, logger, state): 52 | if 'epoch' in state: 53 | epoch = state['epoch'] 54 | else: 55 | epoch = 0 56 | for tag, meter in state['model'].meters.items(): 57 | file_id = '{}_{}'.format(prefix, tag) 58 | logger.log(file_id, epoch, meter.value()[0]) 59 | 60 | def save_best_model(state, best_val): 61 | criterion = config['optim']['criterion'] 62 | new_best = False 63 | for tag, meter in state['model'].meters.items(): 64 | if tag == criterion['tag']: 65 | new_val = meter.value()[0] 66 | if criterion['minmax'] == 'min': 67 | if new_val < best_val: 68 | best_val = new_val 69 | new_best = True 70 | else: 71 | if new_val > best_val: 72 | best_val = new_val 73 | new_best = True 74 | break 75 | if new_best: 76 | print('Saving new best model') 77 | save_model(state['model'], best_model_path) 78 | return best_val, new_best 79 | 80 | 81 | def on_sample(state): 82 | if config['cuda']: 83 | state['sample'] = [x.cuda() for x in state['sample']] 84 | 85 | def on_forward(state): 86 | state['model'].add_to_meters(state) 87 | 88 | def on_start(state): 89 | state['loader'] = state['iterator'] 90 | 91 | def on_start_epoch(state): 92 | state['model'].reset_meters() 93 | state['iterator'] = tqdm(state['loader'], desc='Epoch {}'.format(state['epoch'])) 94 | 95 | def on_end_epoch(hook_state, state): 96 | scheduler.step() 97 | print("Training loss: {:.4f}".format(state['model'].meters['loss'].value()[0])) 98 | log_meters('train', logger, state) 99 | 100 | if ('reconstruction' in state['output']) and (state['epoch'] % 20 == 0): 101 | recon = state['output']['reconstruction'].data.view(-1, 28, 28).unsqueeze(1) 102 | save_imgs(recon, 'reconstruction_{}.jpg'.format(state['epoch']), img_dir) 103 | 104 | if state['epoch'] % 20 == 0: 105 | save_path = os.path.join(model_dir, "model_{}.pt".format(state['epoch'])) 106 | save_model(model, save_path) 107 | 108 | # do validation at the end of each epoch 109 | if config['data']['validation']: 110 | state['model'].reset_meters() 111 | engine.test(model, loaders['validation']) 112 | print("Val loss: {:.4f}".format(state['model'].meters['loss'].value()[0])) 113 | log_meters('val', logger, state) 114 | hook_state['best_val'], new_best = save_best_model(state, hook_state['best_val']) 115 | if new_best: 116 | hook_state['wait'] = 0 117 | else: 118 | hook_state['wait'] += 1 119 | if hook_state['wait'] > config['optim']['patience']: 120 | state['stop'] = True 121 | 122 | if state['epoch'] == (config['optim']['epochs'] - config['optim']['finetune']['epochs']): 123 | print('Momentum fine tuning for last stage') 124 | if config['optim']['finetune']['warm']: 125 | finetune_mom = config['optim']['finetune']['final_mom'] 126 | for group in state['optimizer'].param_groups: 127 | mom = [0.0, finetune_mom] if config['optim']['optimizer'].lower() == 'aggmo' else finetune_mom 128 | group['momentum'] = mom 129 | 
else: 130 | #TODO: The cold-start finetuning does not change the momentum 131 | state['optimizer'] = get_optimizer(config, state['model'].parameters()) 132 | for tag, meter in state['model'].meters.items(): 133 | file_id = 'pre_finetune_{}'.format(tag) 134 | logger.log(file_id, state['epoch'], meter.value()[0]) 135 | 136 | engine.hooks['on_start'] = on_start 137 | engine.hooks['on_sample'] = on_sample 138 | engine.hooks['on_forward'] = on_forward 139 | engine.hooks['on_start_epoch'] = on_start_epoch 140 | engine.hooks['on_end_epoch'] = partial(on_end_epoch, {'best_val': best_val, 'wait': 0}) 141 | engine.train(model, loaders['train'], maxepoch=config['optim']['epochs'], optimizer=optimizer) 142 | 143 | model.reset_meters() 144 | if os.path.exists(best_model_path): 145 | model.load_state_dict(torch.load(best_model_path)) 146 | if loaders['test'] is not None: 147 | log_meters('test', logger, engine.test(model, loaders['test'])) 148 | return model 149 | 150 | if __name__ == '__main__': 151 | config = process_config() 152 | model_init = get_model(config) 153 | loaders = load_data(config) 154 | model = train(model_init, loaders, config) 155 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | MODEL_REGISTRY = {} 2 | 3 | def register_model(model_name): 4 | def decorator(f): 5 | MODEL_REGISTRY[model_name] = f 6 | return f 7 | 8 | return decorator 9 | 10 | def get_model(config): 11 | model_name = config['model']['name'] 12 | if model_name in MODEL_REGISTRY: 13 | return MODEL_REGISTRY[model_name](config) 14 | else: 15 | raise ValueError("Unknown model {:s}".format(model_name)) 16 | 17 | import models.ae 18 | import models.nnet 19 | import models.resnet 20 | -------------------------------------------------------------------------------- /src/models/ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.base import AutoencoderModel 6 | from models.nnet import FCNet 7 | 8 | class AutoEncoder(torch.nn.Module): 9 | 10 | def __init__(self, encoder, decoder): 11 | super(AutoEncoder, self).__init__() 12 | 13 | self.encoder = encoder 14 | self.decoder = decoder 15 | self.input_size = self.encoder.input_dim 16 | 17 | def forward(self, x): 18 | x = self.encoder(x) 19 | x = self.decoder(x) 20 | return F.sigmoid(x) 21 | 22 | class FCEncoder(FCNet): 23 | 24 | def __init__(self, config): 25 | super(FCEncoder, self).__init__(config['model']['layers'].copy(), 26 | config['data']['input_dim'], 27 | config['model']['activation']) 28 | 29 | from models import register_model 30 | 31 | @register_model('ce_fc_ae') 32 | def load_ce_fc_ae(config): 33 | encoder = FCEncoder(config) 34 | 35 | layer_rev = config['model']['layers'][::-1] 36 | decoder_in_dim = layer_rev.pop(0) 37 | layer_rev.append(config['data']['input_dim']) 38 | decoder = FCNet(layer_rev, decoder_in_dim, config['model']['activation']) 39 | 40 | autoencoder = AutoEncoder(encoder, decoder) 41 | return AutoencoderModel(autoencoder) 42 | 43 | 44 | -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import torchnet as tnt 8 | 9 | from 
torch.autograd import Variable 10 | 11 | def batch_mean_mse(recon, inputs): 12 | return torch.sum(torch.mean((recon - inputs) ** 2, 0)) 13 | 14 | class ExperimentModel(nn.Module): 15 | 16 | def __init__(self, model): 17 | super(ExperimentModel, self).__init__() 18 | self.model = model 19 | self.init_meters() 20 | 21 | def forward(self, x): 22 | return self.model(x) 23 | 24 | def loss(self, sample, test=False): 25 | raise NotImplementedError 26 | 27 | def input_size(self): 28 | return self.model.input_size 29 | 30 | def init_meters(self): 31 | self.meters = OrderedDict([('loss', tnt.meter.AverageValueMeter())]) 32 | 33 | def reset_meters(self): 34 | for meter in self.meters.values(): 35 | meter.reset() 36 | 37 | def add_to_meters(self, state): 38 | self.meters['loss'].add(state['loss'].data.item()) 39 | 40 | class ClassificationModel(ExperimentModel): 41 | 42 | def init_meters(self): 43 | super(ClassificationModel, self).init_meters() 44 | self.meters['acc'] = tnt.meter.ClassErrorMeter(accuracy=True) 45 | 46 | def loss(self, sample, test=False): 47 | inputs = sample[0] 48 | targets = sample[1] 49 | o = self.model.forward(inputs) 50 | return F.cross_entropy(o, targets), {'logits': o} 51 | 52 | def add_to_meters(self, state): 53 | self.meters['loss'].add(state['loss'].data.item()) 54 | self.meters['acc'].add(state['output']['logits'].data, state['sample'][1]) 55 | 56 | class AutoencoderModel(ExperimentModel): 57 | 58 | def init_meters(self): 59 | super(AutoencoderModel, self).init_meters() 60 | self.meters['mse'] = tnt.meter.AverageValueMeter() 61 | 62 | def loss(self, sample, test=False): 63 | inputs = sample[0] 64 | reconstruction = self.model.forward(inputs) 65 | 66 | mse = batch_mean_mse(reconstruction, inputs.view(-1, self.input_size())) 67 | loss = F.binary_cross_entropy(reconstruction, inputs.view(-1, self.input_size())) 68 | return loss, {'reconstruction': reconstruction, 'mse': mse} 69 | 70 | def add_to_meters(self, state): 71 | self.meters['loss'].add(state['loss'].data.item()) 72 | self.meters['mse'].add(state['output']['mse'].data.item()) 73 | -------------------------------------------------------------------------------- /src/models/nnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CIFAR-10 Classification 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from models.base import ClassificationModel 9 | 10 | class FCNet(torch.nn.Module): 11 | def __init__(self, layers, input_dim, activation='relu'): 12 | super(FCNet, self).__init__() 13 | 14 | self.layer_sizes = layers.copy() 15 | self.input_dim = input_dim 16 | 17 | self.layer_sizes.insert(0, self.input_dim) 18 | 19 | if activation == 'relu': 20 | act_func = nn.ReLU 21 | elif activation == 'sigmoid': 22 | act_func = nn.Sigmoid 23 | else: 24 | raise Exception('Unexpected activation function. 
ReLU or Sigmoid supported.')
25 |
26 |
27 |         layers = [nn.Linear(self.layer_sizes[0], self.layer_sizes[1])]
28 |
29 |         for i in range(2, len(self.layer_sizes)):
30 |             layers.append(act_func())
31 |             layers.append(nn.Linear(self.layer_sizes[i-1], self.layer_sizes[i]))
32 |
33 |
34 |         self.model = nn.Sequential(*layers)
35 |
36 |     def __len__(self):
37 |         return len(self.model)
38 |
39 |     def __getitem__(self, idx):
40 |         return self.model[idx]
41 |
42 |     def forward(self, x):
43 |         x = x.view(-1, self.input_dim)
44 |         return self.model(x)
45 |
46 | class AlexNet(nn.Module):
47 |
48 |     def __init__(self, class_count):
49 |         super(AlexNet, self).__init__()
50 |         self.features = nn.Sequential(
51 |             nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
52 |             nn.ReLU(inplace=True),
53 |             nn.MaxPool2d(kernel_size=2, stride=2),
54 |             nn.Conv2d(64, 192, kernel_size=5, padding=2),
55 |             nn.ReLU(inplace=True),
56 |             nn.MaxPool2d(kernel_size=2, stride=2),
57 |             nn.Conv2d(192, 384, kernel_size=3, padding=1),
58 |             nn.ReLU(inplace=True),
59 |             nn.Conv2d(384, 256, kernel_size=3, padding=1),
60 |             nn.ReLU(inplace=True),
61 |             nn.Conv2d(256, 256, kernel_size=3, padding=1),
62 |             nn.ReLU(inplace=True),
63 |             nn.MaxPool2d(kernel_size=2, stride=2),
64 |         )
65 |         self.fc = nn.Linear(256, class_count)
66 |
67 |     def forward(self, x):
68 |         x = x.view(-1, 3, 32, 32)
69 |         x = self.features(x)
70 |         x = x.view(x.size(0), -1)
71 |         x = self.fc(x)
72 |         return x
73 |
74 |
75 | from models import register_model
76 |
77 | @register_model('alexnet_cifar')
78 | def load_alexnet_classification(config):
79 |     model = AlexNet(config['data']['class_count'])
80 |     return ClassificationModel(model)
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Credit: github/kuangliu - https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
3 | Modified by James Lucas.
4 |
5 | The code here can be used to create basic ResNets and Wide ResNets.
6 |
7 |
8 | ResNet in PyTorch.
9 |
10 | Reference:
11 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
12 |     Deep Residual Learning for Image Recognition.
arXiv:1512.03385 13 | ''' 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | from torch.autograd import Variable 19 | 20 | from models.base import ClassificationModel 21 | import torchvision.models as tvmodels 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, in_planes, planes, stride=1): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | 33 | self.shortcut = nn.Sequential() 34 | if stride != 1 or in_planes != self.expansion*planes: 35 | self.shortcut = nn.Sequential( 36 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 37 | nn.BatchNorm2d(self.expansion*planes) 38 | ) 39 | 40 | def forward(self, x): 41 | out = F.relu(self.bn1(self.conv1(x))) 42 | out = self.bn2(self.conv2(out)) 43 | out += self.shortcut(x) 44 | out = F.relu(out) 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, in_planes, planes, stride=1): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 59 | 60 | self.shortcut = nn.Sequential() 61 | if stride != 1 or in_planes != self.expansion*planes: 62 | self.shortcut = nn.Sequential( 63 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 64 | nn.BatchNorm2d(self.expansion*planes) 65 | ) 66 | 67 | def forward(self, x): 68 | out = F.relu(self.bn1(self.conv1(x))) 69 | out = F.relu(self.bn2(self.conv2(out))) 70 | out = self.bn3(self.conv3(out)) 71 | out += self.shortcut(x) 72 | out = F.relu(out) 73 | return out 74 | 75 | 76 | class ResNet(nn.Module): 77 | def __init__(self, block, block_config, num_classes=10): 78 | super(ResNet, self).__init__() 79 | 80 | num_blocks = block_config['num_blocks'] 81 | channels = block_config['num_channels'] 82 | assert len(channels) == len(num_blocks) 83 | 84 | self.in_planes = channels[0] 85 | 86 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, padding=1, bias=False) 87 | self.bn1 = nn.BatchNorm2d(channels[0]) 88 | 89 | 90 | self.layers = [self._make_layer(block, channels[0], num_blocks[0], stride=1)] 91 | 92 | for i in range(1,len(channels)): 93 | self.layers.append(self._make_layer(block, channels[i], num_blocks[i], stride=2)) 94 | self.layers = nn.Sequential(*self.layers) 95 | self.avgpool = nn.AvgPool2d(block_config['pool_size']) 96 | self.linear = nn.Linear(channels[-1]*block.expansion, num_classes) 97 | 98 | def _make_layer(self, block, planes, num_blocks, stride): 99 | strides = [stride] + [1]*(num_blocks-1) 100 | layers = [] 101 | for stride in strides: 102 | layers.append(block(self.in_planes, planes, stride)) 103 | self.in_planes = planes * block.expansion 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | out = F.relu(self.bn1(self.conv1(x))) 108 | out = self.layers(out) 109 | out = self.avgpool(out) 110 | out = out.view(out.size(0), -1) 111 | out = 
self.linear(out) 112 | return out 113 | 114 | from models import register_model 115 | 116 | #### Imagenet Models #### 117 | @register_model('resnet18') 118 | def ResNet18(config): 119 | block_config = { 120 | "num_blocks": [2,2,2,2], 121 | "num_channels": [64, 128, 256, 512], 122 | "pool_size": 4 123 | } 124 | return ClassificationModel(ResNet(BasicBlock, block_config, config['data']['class_count'])) 125 | 126 | @register_model('resnet34') 127 | def ResNet34(config): 128 | block_config = { 129 | "num_blocks": [3,4,6,3], 130 | "num_channels": [64, 128, 256, 512], 131 | "pool_size": 4 132 | } 133 | return ClassificationModel(ResNet(BasicBlock, block_config, config['data']['class_count'])) 134 | 135 | @register_model('resnet50') 136 | def ResNet50(config): 137 | block_config = { 138 | "num_blocks": [3,4,6,3], 139 | "num_channels": [64, 128, 256, 512], 140 | "pool_size": 4 141 | } 142 | return ClassificationModel(ResNet(Bottleneck, block_config, config['data']['class_count'])) 143 | 144 | @register_model('resnet101') 145 | def ResNet101(config): 146 | block_config = { 147 | "num_blocks": [3,4,23,3], 148 | "num_channels": [64, 128, 256, 512], 149 | "pool_size": 4 150 | } 151 | return ClassificationModel(ResNet(Bottleneck, block_config, config['data']['class_count'])) 152 | 153 | @register_model('resnet152') 154 | def ResNet152(config): 155 | block_config = { 156 | "num_blocks": [3,8,36,3], 157 | "num_channels": [64, 128, 256, 512], 158 | "pool_size": 4 159 | } 160 | return ClassificationModel(ResNet(Bottleneck, block_config, config['data']['class_count'])) 161 | 162 | 163 | #### CIFAR MODELS #### 164 | @register_model('resnet20') 165 | def CifarResNet20(config): 166 | block_config = { 167 | "num_blocks": [3,3,3], 168 | "num_channels": [16, 32, 64], 169 | "pool_size": 8 170 | } 171 | return ClassificationModel(ResNet(BasicBlock, block_config, config['data']['class_count'])) 172 | 173 | @register_model('resnet32') 174 | def CifarResNet32(config): 175 | block_config = { 176 | "num_blocks": [5,5,5], 177 | "num_channels": [16, 32, 64], 178 | "pool_size": 8 179 | } 180 | return ClassificationModel(ResNet(BasicBlock, block_config, config['data']['class_count'])) 181 | 182 | @register_model('resnet44') 183 | def CifarResNet44(config): 184 | block_config = { 185 | "num_blocks": [7,7,7], 186 | "num_channels": [16, 32, 64], 187 | "pool_size": 8 188 | } 189 | return ClassificationModel(ResNet(Bottleneck, block_config, config['data']['class_count'])) 190 | 191 | @register_model('resnet56') 192 | def CifarResNet56(config): 193 | block_config = { 194 | "num_blocks": [9,9,9], 195 | "num_channels": [16, 32, 64], 196 | "pool_size": 8 197 | } 198 | return ClassificationModel(ResNet(Bottleneck, block_config, config['data']['class_count'])) 199 | 200 | ##### Torchvision imagenet models #### 201 | @register_model('imagenet-resnet18') 202 | def ImageNetResNet18(config): 203 | return ClassificationModel(nn.DataParallel(tvmodels.resnet18(False))) 204 | 205 | @register_model('imagenet-resnet34') 206 | def ImageNetResNet34(config): 207 | return ClassificationModel(nn.DataParallel(tvmodels.resnet34(False))) 208 | 209 | @register_model('imagenet-resnet50') 210 | def ImageNetResNet50(config): 211 | return ClassificationModel(nn.DataParallel(tvmodels.resnet50(False))) 212 | 213 | @register_model('imagenet-resnet101') 214 | def ImageNetResNet101(config): 215 | return ClassificationModel(nn.DataParallel(tvmodels.resnet101(False))) 216 | 217 | @register_model('imagenet-resnet152') 218 | def ImageNetResNet152(config): 219 | 
return ClassificationModel(nn.DataParallel(tvmodels.resnet152(False))) 220 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | from torch.utils.data import DataLoader 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | 7 | import torchnet as tnt 8 | import torchvision 9 | import torchvision.datasets as datasets 10 | import torchvision.transforms as transforms 11 | from torchvision.utils import save_image 12 | 13 | 14 | from aggmo import AggMo 15 | 16 | import datetime 17 | import os, errno 18 | 19 | def get_optimizer(config, params): 20 | optim_cfg = config['optim']['optimizer'] 21 | optim_name = optim_cfg['name'] 22 | lr = optim_cfg['lr'] 23 | 24 | if optim_name == 'sgd': 25 | optimizer = torch.optim.SGD(params, lr, momentum=optim_cfg['momentum'], weight_decay=config['optim']['wdecay']) 26 | elif optim_name == 'nesterov': 27 | optimizer = torch.optim.SGD(params, lr, momentum=optim_cfg['momentum'], nesterov=True, 28 | weight_decay=config['optim']['wdecay']) 29 | elif optim_name == 'aggmo': 30 | optimizer = AggMo(params, lr, betas=optim_cfg['betas'], weight_decay=config['optim']['wdecay']) 31 | elif optim_name == 'aggmo_exp': 32 | optimizer = AggMo.from_exp_form(params, lr, a=optim_cfg['a'], k=optim_cfg['K'], weight_decay=config['optim']['wdecay']) 33 | elif optim_name =='adam': 34 | optimizer = torch.optim.Adam(params, lr, betas=(optim_cfg['beta1'], optim_cfg['beta2']), weight_decay=config['optim']['wdecay']) 35 | else: 36 | raise Exception('Unknown optimizer') 37 | return optimizer 38 | 39 | def get_scheduler(config, optimizer): 40 | lr_schedule_conf = config['optim']['lr_schedule'] 41 | if lr_schedule_conf['name'] == 'exp': 42 | return lr_scheduler.ExponentialLR(optimizer, lr_schedule_conf['lr_decay'], lr_schedule_conf['last_epoch']) 43 | elif lr_schedule_conf['name'] == 'step': 44 | return lr_scheduler.MultiStepLR(optimizer, lr_schedule_conf['milestones'], lr_schedule_conf['lr_decay']) 45 | 46 | def save_model(model, save_path): 47 | try: 48 | os.makedirs(os.path.dirname(save_path)) 49 | except OSError as e: 50 | if e.errno != errno.EEXIST: 51 | raise 52 | 53 | torch.save(model.state_dict(), save_path) 54 | 55 | def save_imgs(tensor, fname, save_dir): 56 | try: 57 | os.makedirs(save_dir) 58 | except OSError as e: 59 | if e.errno != errno.EEXIST: 60 | raise 61 | 62 | save_image(tensor, os.path.join(save_dir, fname)) 63 | 64 | def get_data_transforms(config): 65 | train_transform = None 66 | test_transform = None 67 | 68 | transform_cfg = config['data']['transform'] 69 | if transform_cfg: 70 | name = transform_cfg['name'] 71 | normalize = transforms.Normalize(mean=transform_cfg['norm_mean'], 72 | std=transform_cfg['norm_std']) 73 | if name == 'cifar': 74 | train_transform = transforms.Compose([ 75 | transforms.RandomCrop(32, padding=4), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ToTensor(), 78 | normalize 79 | ]) 80 | 81 | test_transform = transforms.Compose([ 82 | transforms.ToTensor(), 83 | normalize 84 | ]) 85 | 86 | elif name == 'imagenet': 87 | train_transform = transforms.Compose([ 88 | transforms.RandomResizedCrop(224), 89 | transforms.RandomHorizontalFlip(), 90 | transforms.ToTensor(), 91 | normalize, 92 | ]) 93 | test_transform = transforms.Compose([ 94 | transforms.Resize(256), 95 | transforms.CenterCrop(224), 96 | transforms.ToTensor(), 97 | 
normalize, 98 | ]) 99 | else: 100 | train_transform = transforms.ToTensor() 101 | if test_transform is None: 102 | test_transform = transforms.ToTensor() 103 | return train_transform, test_transform 104 | 105 | def load_data(config): 106 | data_name = config['data']['name'].lower() 107 | path = os.path.join(config['data']['root'], data_name) 108 | 109 | train_transform, test_transform = get_data_transforms(config) 110 | 111 | if data_name == 'mnist': 112 | train_data = datasets.MNIST(path, download=True, transform=train_transform) 113 | val_data = datasets.MNIST(path, download=True, transform=test_transform) 114 | test_data = datasets.MNIST(path, train=False, download=True, transform=test_transform) 115 | elif data_name == 'cifar10': 116 | train_data = datasets.CIFAR10(path, download=True, transform=train_transform) 117 | val_data = datasets.CIFAR10(path, download=True, transform=test_transform) 118 | test_data = datasets.CIFAR10(path, train=False, download=True, transform=test_transform) 119 | elif data_name == 'cifar100': 120 | train_data = datasets.CIFAR100(path, download=True, transform=train_transform) 121 | val_data = datasets.CIFAR100(path, download=True, transform=test_transform) 122 | test_data = datasets.CIFAR100(path, train=False, download=True, transform=test_transform) 123 | elif data_name == 'fashion-mnist': 124 | train_data = datasets.FashionMNIST(path, download=True, transform=train_transform) 125 | val_data = datasets.FashionMNIST(path, download=True, transform=test_transform) 126 | test_data = datasets.FashionMNIST(path, train=False, download=True, transform=test_transform) 127 | elif data_name == 'imagenet-torchvision': 128 | train_data = datasets.ImageFolder(os.path.join(path, 'train'), transform=train_transform) 129 | val_data = datasets.ImageFolder(os.path.join(path, 'valid'), transform=test_transform) 130 | # Currently not loaded 131 | test_data = None 132 | else: 133 | raise NotImplementedError('Data name %s not supported' % data_name) 134 | 135 | # Manually readjust train/val size for memory saving 136 | if data_name != 'imagenet-torchvision': 137 | data_size = len(train_data) 138 | train_size = int(data_size * config['data']['train_size']) 139 | 140 | train_data.train_data = train_data.train_data[:train_size] 141 | train_data.train_labels = train_data.train_labels[:train_size] 142 | 143 | if config['data']['train_size'] != 1: 144 | val_data.train_data = val_data.train_data[train_size:] 145 | val_data.train_labels = val_data.train_labels[train_size:] 146 | else: 147 | val_data = None 148 | 149 | batch_size = config['optim']['batch_size'] 150 | num_workers = config['data']['num_workers'] 151 | loaders = { 152 | 'train': DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers), 153 | 'validation': DataLoader(val_data, batch_size=batch_size, num_workers=num_workers), 154 | 'test': DataLoader(test_data, batch_size=batch_size, num_workers=num_workers) 155 | } 156 | 157 | return loaders -------------------------------------------------------------------------------- /tensorflow/aggmo.py: -------------------------------------------------------------------------------- 1 | """AggMo for TensorFlow.""" 2 | 3 | from tensorflow.python.eager import context 4 | from tensorflow.python.framework import ops 5 | from tensorflow.python.ops import control_flow_ops 6 | from tensorflow.python.ops import math_ops 7 | from tensorflow.python.ops import resource_variable_ops 8 | from tensorflow.python.ops import state_ops 9 | from tensorflow.python.ops import 
variable_scope
10 | from tensorflow.python.training import optimizer
11 |
12 |
13 | class AggMo(optimizer.Optimizer):
14 |     def __init__(self, learning_rate=0.1, betas=[0, 0.9, 0.99], use_locking=False, name="AggMo"):
15 |         super(AggMo, self).__init__(use_locking, name)
16 |         self._lr = learning_rate
17 |         self._betas = betas
18 |
19 |     @classmethod
20 |     def from_exp_form(cls, learning_rate, a=0.1, k=3, use_locking=False, name="AggMo"):
21 |         betas = [1.0 - a**i for i in range(k)]
22 |         return cls(learning_rate, betas, use_locking, name)
23 |
24 |     def _create_slots(self, var_list):
25 |         # Create slots for each momentum component
26 |         for v in var_list:
27 |             for i in range(len(self._betas)):
28 |                 self._zeros_slot(v, "momentum_{}".format(i), self._name)
29 |
30 |     def _prepare(self):
31 |         learning_rate = self._lr
32 |         if callable(learning_rate):
33 |             learning_rate = learning_rate()
34 |         self._lr_tensor = ops.convert_to_tensor(learning_rate, name="learning_rate")
35 |
36 |         betas = self._betas
37 |         if callable(betas):
38 |             betas = betas()
39 |         self._betas_tensor = ops.convert_to_tensor(betas, name="betas")
40 |
41 |     def _apply_dense(self, grad, var):
42 |         lr = math_ops.cast(self._lr_tensor / len(self._betas), var.dtype.base_dtype)
43 |         betas = math_ops.cast(self._betas_tensor, var.dtype.base_dtype)
44 |
45 |         momentums = []
46 |         summed_momentum = 0.0
47 |         for i in range(len(self._betas)):
48 |             m = self.get_slot(var, "momentum_{}".format(i))
49 |             m_t = state_ops.assign(m, betas[i] * m + grad)
50 |             summed_momentum += m_t
51 |             momentums.append(m_t)
52 |         var_update = state_ops.assign_sub(var, lr * summed_momentum, use_locking=self._use_locking)
53 |         return control_flow_ops.group(*[var_update, *momentums])
54 |
55 |     def _resource_apply_dense(self, grad, var):
56 |         var = var.handle
57 |         lr = math_ops.cast(self._lr_tensor / len(self._betas), var.dtype.base_dtype)
58 |         betas = math_ops.cast(self._betas_tensor, var.dtype.base_dtype)
59 |
60 |         momentums = []
61 |         summed_momentum = 0.0
62 |         for i in range(len(self._betas)):
63 |             m = self.get_slot(var, "momentum_{}".format(i))
64 |             m_t = state_ops.assign(m, betas[i] * m + grad)
65 |             summed_momentum += m_t
66 |             momentums.append(m_t)
67 |         var_update = state_ops.assign_sub(var, lr * summed_momentum, use_locking=self._use_locking)
68 |         return control_flow_ops.group(*[var_update, *momentums])
69 |
70 |     def _apply_sparse_shared(self, grad, var, indices, scatter_add):
71 |         lr = math_ops.cast(self._lr_tensor / len(self._betas), var.dtype.base_dtype)
72 |         betas = math_ops.cast(self._betas_tensor, var.dtype.base_dtype)
73 |
74 |         momentums = []
75 |         summed_momentum = 0.0
76 |         for i in range(len(self._betas)):
77 |             m = self.get_slot(var, "momentum_{}".format(i))
78 |             m_t = state_ops.assign(m, betas[i] * m, use_locking=self._use_locking)
79 |             with ops.control_dependencies([m_t]):
80 |                 m_t = scatter_add(m, indices, grad)
81 |             momentums.append(m_t)
82 |             summed_momentum += m_t
83 |         var_update = state_ops.assign_sub(var, lr * summed_momentum, use_locking=self._use_locking)
84 |         return control_flow_ops.group(*[var_update, *momentums])
85 |
86 |     def _apply_sparse(self, grad, var):
87 |         return self._apply_sparse_shared(
88 |             grad.values, var, grad.indices,
89 |             lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
90 |                 x, i, v, use_locking=self._use_locking))
91 |
92 |     def _resource_scatter_add(self, x, i, v):
93 |         with ops.control_dependencies(
94 |             [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
95 |             return x.value()
96 |
97 |     def
_resource_apply_sparse(self, grad, var, indices): 98 | return self._apply_sparse_shared( 99 | grad, var, indices, self._resource_scatter_add) 100 | --------------------------------------------------------------------------------
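The following is a minimal graph-mode smoke test for the TensorFlow optimizer above. It is not part of the repository (the repository's own checks live in `tensorflow/AggMo-Test.ipynb`, which is not reproduced here) and assumes a TensorFlow 1.x environment with `tensorflow/aggmo.py` importable as `aggmo`.

```python
import tensorflow as tf  # assumes TensorFlow 1.x graph mode

from aggmo import AggMo

# Toy quadratic objective: drive w towards zero.
w = tf.Variable([2.0, -3.0])
loss = tf.reduce_sum(tf.square(w))

opt = AggMo(learning_rate=0.05, betas=[0.0, 0.9, 0.99])
train_op = opt.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(200):
        _, cur_loss = sess.run([train_op, loss])
    print("final loss:", cur_loss)  # should have decreased towards zero
```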