├── README.md ├── callbacks.py ├── fit_stanford_cars.py ├── get_data.sh ├── lm_val_fns.py ├── prepare_data.py ├── train_cifar10.py ├── train_model.py ├── train_rnn.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Experiments with Adam 2 | 3 | This repo contains the scripts used to perform the experiments in [this blog post](http://www.fast.ai/2018/07/02/adam-weight-decay/). If you're using this code or our results, please cite it appropriately. 4 | 5 | You will find 6 | - a script to train [cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) to >94% accuracy in 30 epochs without Test Time Augmentation, or 18 epochs with it. 7 | - a script to finetune a pretrained resnet50 on the [Stanford cars dataset](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) to 90% accuracy in 60 epochs. 8 | - a script to train an AWD LSTM (or QRNN) to the same perplexity as [the Salesforce team](https://github.com/salesforce/awd-lstm-lm) that created them, but in just 90 epochs. 9 | 10 | ## Requirements 11 | 12 | - the [fastai library](https://github.com/fastai/fastai) is necessary to run all the models. If you didn't pip-install it, don't forget to have a symlink pointing to it in the directory where you clone this repo. 13 | - PyTorch 0.4.0 is necessary to have the implementation of amsgrad inside Adam. 14 | - additionally, the QRNN models require the [cupy library](https://github.com/cupy/cupy). 15 | 16 | ## To install 17 | 18 | Run the script get_data.sh; it will download and organize the data needed for the models. 19 | 20 | ## Experiments 21 | 22 | ### Cifar10 dataset 23 | 24 | - this should train cifar10 to 94.25% accuracy (on average): 25 | ``` 26 | python train_cifar10.py 3e-3 --wd=0.1 --wd_loss=False 27 | ``` 28 | - this should train cifar10 to 94% accuracy (on average) but faster: 29 | ``` 30 | python train_cifar10.py 3e-3 --wd=0.1 --wd_loss=False --cyc_len=18 --tta=True 31 | ``` 32 | 33 | ### Stanford cars dataset 34 | 35 | - this should get 90% accuracy (on average) without TTA, 91% with: 36 | ``` 37 | python fit_stanford_cars.py '(1e-2,3e-3)' --wd=1e-3 --tta=True 38 | ``` 39 | 40 | ### Language models 41 | 42 | - this should train an AWD LSTM to 68.7/65.5 perplexity without the cache pointer, 52.9/50.9 with it: 43 | ``` 44 | python train_rnn.py 5e-3 --wd=1.2e-6 --alpha=3 --beta=1.5 45 | ``` 46 | 47 | - this should train an AWD QRNN to 69.6/66.7 perplexity without the cache pointer, 53.6/51.7 with it: 48 | ``` 49 | python train_rnn.py 5e-3 --wd=1e-6 --qrnn=True 50 | ``` 51 | 52 | 53 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | from fastai.conv_learner import * 2 | 3 | 4 | class LogResults(Callback): 5 | """ 6 | Callback to log all the results of the training: 7 | - at the end of each epoch: training loss, validation loss and metrics 8 | """ 9 | 10 | def __init__(self, learn, fname): 11 | super().__init__() 12 | self.learn, self.fname = learn, fname 13 | 14 | def on_train_begin(self): 15 | self.logs, self.epoch, self.n = "", 0, 0 16 | names = ["epoch", "trn_loss", "val_loss", "accuracy"] 17 | layout = "{!s:10} " * len(names) 18 | self.logs += layout.format(*names) + "\n" 19 | 20 | def on_batch_end(self, metrics): 21 | self.loss = metrics 22 | 23 | def on_epoch_end(self, metrics): 24 | self.save_stats(self.epoch, [self.loss] + metrics) 25 | self.epoch += 1 26 | 27 | def save_stats(self, epoch, values, decimals=6):
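# Writes one row per epoch: the epoch index centered in 10 characters, then each value (training loss, validation loss, metrics) rounded to `decimals` digits and left-aligned in 10 characters, matching the header written in on_train_begin.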
layout = "{!s:^10}" + " {!s:10}" * len(values) 29 | values = [epoch] + list(np.round(values, decimals)) 30 | self.logs += layout.format(*values) + "\n" 31 | 32 | def on_train_end(self): 33 | with open(self.fname, 'a') as f: f.write(self.logs) 34 | -------------------------------------------------------------------------------- /fit_stanford_cars.py: -------------------------------------------------------------------------------- 1 | from fastai.conv_learner import * 2 | from utils import get_opt_fn, get_phases, log_msg 3 | from callbacks import * 4 | import fire 5 | 6 | def main_train(lr, moms, wd, wd_loss, opt_fn, bs, cyc_len, beta2, amsgrad, div, pct, lin_end, tta, div_lr, fname): 7 | """ 8 | Trains a Language Model 9 | 10 | lr (float): maximum learning rate 11 | moms (float/tuple): value of the momentum/beta1. If tuple, cyclical momentums will be used 12 | wd (float): weight decay to be used 13 | wd_loss (bool): weight decay computed inside the loss if True (l2 reg) else outside (true wd) 14 | opt_fn (optimizer): name of the optim function to use (should be SGD, RMSProp or Adam) 15 | bs (int): batch size 16 | cyc_len (int): length of the cycle 17 | beta2 (float): beta2 parameter of Adam or alpha parameter of RMSProp 18 | amsgrad (bool): for Adam, sues amsgrad or not 19 | div (float): value to divide the maximum learning rate by 20 | pct (float): percentage to leave for the annealing at the end 21 | lin_end (bool): if True, the annealing phase goes from the minimum lr to 1/100th of it linearly 22 | if False, uses a cosine annealing to 0 23 | tta (bool): if True, uses Test Time Augmentation to evaluate the model 24 | div_lr (float): number to divide the diffential lrs 25 | """ 26 | arch, sz = resnet50, 224 27 | PATH = Path("data/cars/") 28 | val_idxs = list(range(8144, 16185)) 29 | tfms =tfms_from_model(arch, sz, aug_tfms = transforms_side_on, max_zoom=1.05) 30 | data = ImageClassifierData.from_csv(PATH, '', PATH/'annots.csv', bs, tfms, val_idxs=val_idxs) 31 | learn = ConvLearner.pretrained(arch, data, metrics=[accuracy]) 32 | mom = moms[0] if isinstance(moms, Iterable) else moms 33 | opt_fn = get_opt_fn(opt_fn, mom, beta2, amsgrad) 34 | learn.opt_fn = opt_fn 35 | nbs = [cyc_len[0] * (1-pct) / 2, cyc_len[0] * (1-pct) / 2, cyc_len[0] * pct] 36 | phases = get_phases(lr[0], moms, opt_fn, div, list(nbs), wd, lin_end, wd_loss) 37 | learn.fit_opt_sched(phases, callbacks=[LogResults(learn, fname)]) 38 | learn.unfreeze() 39 | lrs = np.array([lr[1]/(div_lr**2), lr[1]/div_lr, lr[1]]) 40 | nbs = [cyc_len[1] * (1-pct) / 2, cyc_len[1] * (1-pct) / 2, cyc_len[1] * pct] 41 | phases = get_phases(lrs, moms, opt_fn, div, list(nbs), wd, lin_end, wd_loss) 42 | learn.fit_opt_sched(phases, callbacks=[LogResults(learn, fname)]) 43 | if tta: 44 | preds, targs = learn.TTA() 45 | probs = np.exp(preds) 46 | probs = np.mean(probs,0) 47 | acc = learn.metrics[0](V(probs), V(targs)) 48 | loss = learn.crit(V(np.log(probs)), V(targs)).item() 49 | log_msg(open(fname, 'a'), f'Final loss: {loss}, Final accuracy: {acc}') 50 | 51 | def train_lm(lr, moms=(0.95,0.85), wd=1.2e-6, wd_loss=True, opt_fn='Adam', bs=128, cyc_len=(20,40), beta2=0.99, amsgrad=False, 52 | div=10, pct=0.1, lin_end=True, tta=False, div_lr=3, name='', cuda_id=0, nb_exp=1): 53 | """ 54 | Launches the trainings. 55 | 56 | See main_train for the description of all the arguments. 
57 | name (string): name to be added to the log file 58 | cuda_id (int): index of the GPU to use 59 | nb_exp (int): number of experiments to run in a row 60 | """ 61 | torch.cuda.set_device(cuda_id) 62 | init_text = f'{name}_{cuda_id}' + '\n' 63 | init_text += f'lr: {lr}; moms: {moms}; wd: {wd}; wd_loss: {wd_loss}; opt_fn: {opt_fn}; bs: {bs}; cyc_len: {cyc_len};' 64 | init_text += f'beta2: {beta2}; amsgrad: {amsgrad}; div: {div}; pct: {pct}; lin_end: {lin_end}; tta: {tta}; div_lr: {div_lr}' 65 | fname = f'logs_{name}_{cuda_id}.txt' 66 | log_msg(open(fname, 'w'), init_text) 67 | for i in range(nb_exp): 68 | log_msg(open(fname, 'a'), '\n' + f'Experiment {i+1}') 69 | main_train(lr, moms, wd, wd_loss, opt_fn, bs, cyc_len, beta2, amsgrad, div, pct, lin_end, tta, div_lr, fname) 70 | 71 | if __name__ == '__main__': fire.Fire(train_lm) -------------------------------------------------------------------------------- /get_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ROOT="data" 3 | CAR_DIR="${ROOT}/cars" 4 | C10_DIR="${ROOT}/cifar10" 5 | WT2_DIR="${ROOT}/wikitext" 6 | 7 | mkdir -p "${ROOT}" 8 | mkdir -p "${CAR_DIR}" 9 | mkdir -p "${C10_DIR}" 10 | mkdir -p "${WT2_DIR}" 11 | 12 | echo "Downloading the datasets" 13 | wget -c "http://imagenet.stanford.edu/internal/car196/car_ims.tgz" -P "${CAR_DIR}" 14 | wget -c "http://imagenet.stanford.edu/internal/car196/cars_annos.mat" -P "${CAR_DIR}" 15 | wget -c "http://pjreddie.com/media/files/cifar.tgz" -P "${C10_DIR}" 16 | wget -c "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" -P "${WT2_DIR}" 17 | 18 | 19 | echo "Uncompressing the datasets" 20 | cd "${CAR_DIR}" 21 | tar -xzf "car_ims.tgz" 22 | cd ../.. 23 | cd "${C10_DIR}" 24 | tar -xzf "cifar.tgz" 25 | cd ../.. 26 | cd "${WT2_DIR}" 27 | unzip -q "wikitext-2-v1.zip" 28 | cd ../.. 29 | 30 | echo "Preparing the datasets" 31 | python prepare_data.py 32 | 33 | 34 | -------------------------------------------------------------------------------- /lm_val_fns.py: -------------------------------------------------------------------------------- 1 | from fastai.text import * 2 | 3 | class TextReader(): 4 | """ Returns a language model iterator that iterates through batches that are of length N(bptt,5) 5 | The first batch returned is always bptt+25; the max possible width. This is done because of the way that pytorch 6 | allocates cuda memory in order to prevent multiple buffers from being created as the batch width grows.
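(N(bptt,5) denotes a normal distribution with mean bptt and standard deviation 5, so the sequence length drawn for each batch varies slightly around bptt.)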
7 | """ 8 | def __init__(self, nums, bptt, backwards=False): 9 | self.bptt,self.backwards = bptt,backwards 10 | self.data = self.batchify(nums) 11 | self.i,self.iter = 0,0 12 | self.n = len(self.data) 13 | 14 | def __iter__(self): 15 | self.i,self.iter = 0,0 16 | while self.i < self.n-1 and self.iter 0: 79 | targ_cache = targ_history[:start+i] if start + i <= window else targ_history[start+i-window:start+i] 80 | hid_cache = hid_history[:start+i] if start + i <= window else hid_history[start+i-window:start+i] 81 | all_dot_prods = torch.mv(theta * hid_cache, hiddens[i]) 82 | exp_dot_prods = F.softmax(all_dot_prods).unsqueeze(1) 83 | p_cache = (exp_dot_prods.expand_as(targ_cache) * targ_cache).sum(0).squeeze() 84 | p = (1-lambd) * pv + lambd * p_cache 85 | targ_pred = p[targets[i]] 86 | total_loss -= torch.log(targ_pred.detach()) 87 | targ_history = targ_history[-window:] 88 | hid_history = hid_history[-window:] 89 | mean = total_loss / (bptt * len(data_source)) 90 | return mean, np.exp(mean) -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | import pandas as pd 5 | from scipy.io import loadmat 6 | import fire 7 | 8 | def prepare_cars(): 9 | PATH = Path('data/cars') 10 | annots = loadmat(PATH/'cars_annos.mat') 11 | trn_ids, trn_classes, val_ids, val_classes = [], [], [], [] 12 | for annot in annots['annotations'][0]: 13 | if int(annot[6]) == 1: 14 | val_classes.append(int(annot[5])) 15 | val_ids.append(annot[0][0]) 16 | else: 17 | trn_classes.append(int(annot[5])) 18 | trn_ids.append(annot[0][0]) 19 | df_trn = pd.DataFrame({'fname': trn_ids, 'class': trn_classes}, columns=['fname', 'class']) 20 | df_val = pd.DataFrame({'fname': val_ids, 'class': val_classes}, columns=['fname', 'class']) 21 | combined = df_trn.append(df_val) 22 | combined.reset_index(inplace=True) 23 | combined.drop(['index'], 1, inplace=True) 24 | combined.to_csv(PATH/'annots.csv', index=False) 25 | 26 | def prepare_cifar10(): 27 | PATH = Path('data/cifar10') 28 | TMP_PATH = PATH/'cifar' 29 | shutil.move(TMP_PATH/'train', PATH/'train') 30 | shutil.move(TMP_PATH/'test', PATH/'test') 31 | classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 32 | for d in ['train', 'test']: 33 | for clas in classes: 34 | os.mkdir(PATH/d/clas) 35 | fnames = list(PATH.glob(f'{d}/*{clas}.png')) 36 | for fname in fnames: shutil.move(fname, PATH/d/clas/str(fname)[len(str(PATH/d))+1:]) 37 | shutil.rmtree(TMP_PATH) 38 | 39 | def prepare_wt2(): 40 | PATH = Path('data/wikitext') 41 | TMP_PATH = PATH/'wikitext-2' 42 | for name in ['train', 'valid', 'test']: 43 | shutil.move(TMP_PATH/f'wiki.{name}.tokens', PATH/f'wiki.{name}.tokens') 44 | shutil.rmtree(TMP_PATH) 45 | 46 | 47 | def prepare_data(): 48 | prepare_cars() 49 | prepare_cifar10() 50 | prepare_wt2() 51 | 52 | if __name__ == '__main__': fire.Fire(prepare_data) -------------------------------------------------------------------------------- /train_cifar10.py: -------------------------------------------------------------------------------- 1 | from fastai.conv_learner import * 2 | from fastai.models.cifar10.wideresnet import wrn_22 3 | from utils import get_opt_fn, get_phases, log_msg 4 | from callbacks import * 5 | import fire 6 | 7 | def main_train(lr, moms, wd, wd_loss, opt_fn, bs, cyc_len, beta2, amsgrad, div, pct, lin_end, tta, fname): 8 | """ 9 | Trains a 
wide ResNet (wrn_22) on the CIFAR-10 dataset 10 | 11 | lr (float): maximum learning rate 12 | moms (float/tuple): value of the momentum/beta1. If tuple, cyclical momentums will be used 13 | wd (float): weight decay to be used 14 | wd_loss (bool): weight decay computed inside the loss if True (l2 reg) else outside (true wd) 15 | opt_fn (optimizer): name of the optim function to use (should be SGD, RMSProp or Adam) 16 | bs (int): batch size 17 | cyc_len (int): length of the cycle 18 | beta2 (float): beta2 parameter of Adam or alpha parameter of RMSProp 19 | amsgrad (bool): for Adam, uses amsgrad or not 20 | div (float): value to divide the maximum learning rate by 21 | pct (float): percentage to leave for the annealing at the end 22 | lin_end (bool): if True, the annealing phase goes from the minimum lr to 1/100th of it linearly 23 | if False, uses a cosine annealing to 0 24 | tta (bool): if True, uses Test Time Augmentation to evaluate the model 25 | """ 26 | stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159])) 27 | sz=32 28 | PATH = Path("data/cifar10/") 29 | tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomCrop(sz), RandomFlip()], pad=sz//8) 30 | data = ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs) 31 | m = wrn_22() 32 | learn = ConvLearner.from_model_data(m, data) 33 | learn.crit = nn.CrossEntropyLoss() 34 | learn.metrics = [accuracy] 35 | mom = moms[0] if isinstance(moms, Iterable) else moms 36 | opt_fn = get_opt_fn(opt_fn, mom, beta2, amsgrad) 37 | learn.opt_fn = opt_fn 38 | nbs = [cyc_len * (1-pct) / 2, cyc_len * (1-pct) / 2, cyc_len * pct] 39 | phases = get_phases(lr, moms, opt_fn, div, list(nbs), wd, lin_end, wd_loss) 40 | learn.fit_opt_sched(phases, callbacks=[LogResults(learn, fname)]) 41 | if tta: 42 | preds, targs = learn.TTA() 43 | probs = np.exp(preds)/np.exp(preds).sum(2)[:,:,None] 44 | probs = np.mean(probs,0) 45 | acc = learn.metrics[0](V(probs), V(targs)) 46 | loss = learn.crit(V(np.log(probs)), V(targs)).item() 47 | log_msg(open(fname, 'a'), f'Final loss: {loss}, Final accuracy: {acc}') 48 | 49 | def train_lm(lr, moms=(0.95,0.85), wd=1.2e-6, wd_loss=True, opt_fn='Adam', bs=128, cyc_len=30, beta2=0.99, amsgrad=False, 50 | div=10, pct=0.075, lin_end=True, tta=False, name='', cuda_id=0, nb_exp=1): 51 | """ 52 | Launches the trainings. 53 | 54 | See main_train for the description of all the arguments.
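Example invocation (from the README): python train_cifar10.py 3e-3 --wd=0.1 --wd_loss=False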
55 | name (string): name to be added to the log file 56 | cuda_id (int): index of the GPU to use 57 | nb_exp (int): number of experiments to run in a row 58 | """ 59 | torch.cuda.set_device(cuda_id) 60 | init_text = f'{name}_{cuda_id}' + '\n' 61 | init_text += f'lr: {lr}; moms: {moms}; wd: {wd}; wd_loss: {wd_loss}; opt_fn: {opt_fn}; bs: {bs}; cyc_len: {cyc_len};' 62 | init_text += f'beta2: {beta2}; amsgrad: {amsgrad}; div: {div}; pct: {pct}; lin_end: {lin_end}; tta: {tta}' 63 | fname = f'logs_{name}_{cuda_id}.txt' 64 | log_msg(open(fname, 'w'), init_text) 65 | for i in range(nb_exp): 66 | log_msg(open(fname, 'a'), '\n' + f'Experiment {i+1}') 67 | main_train(lr, moms, wd, wd_loss, opt_fn, bs, cyc_len, beta2, amsgrad, div, pct, lin_end, tta, fname) 68 | 69 | if __name__ == '__main__': fire.Fire(train_lm) -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | from fastai.conv_learner import * 2 | from fastai.models.cifar10.wideresnet import wrn_22 3 | from train_phases import * 4 | from callbacks import * 5 | from sklearn.metrics import fbeta_score 6 | import fire 7 | 8 | def f2(preds, targs, start=0.17, end=0.24, step=0.01): 9 | with warnings.catch_warnings(): 10 | warnings.simplefilter("ignore") 11 | return max([fbeta_score(targs, (preds>th), 2, average='samples') 12 | for th in np.arange(start,end,step)]) 13 | 14 | def get_data(data, bs): 15 | PATH = Path(f'../data/{data}/') 16 | if data=='dogscats': 17 | sz, arch = 224, resnet34 18 | tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.05) 19 | data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs) 20 | learn = ConvLearner.pretrained(arch, data) 21 | frozen, log_probs, to_sf = True, True, False 22 | elif data=='planet': 23 | sz, arch = 128, resnet34 24 | label_csv = PATH/'train_v2.csv' 25 | n = len(list(open(label_csv)))-1 26 | val_idxs = get_cv_idxs(n) 27 | tfms = tfms_from_model(arch, sz, aug_tfms=transforms_top_down, max_zoom=1.05) 28 | data = ImageClassifierData.from_csv(PATH, 'train-jpg', label_csv, tfms=tfms, 29 | suffix='.jpg', val_idxs=val_idxs, test_name='test-jpg') 30 | data = data.resize(int(sz*1.3), 'tmp') 31 | learn = ConvLearner.pretrained(arch, data, metrics=[f2]) 32 | frozen, log_probs, to_sf = True, False, False 33 | elif data=='cifar10': 34 | stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159])) 35 | sz=32 36 | tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomCrop(sz), RandomFlip()], pad=sz//8) 37 | data = ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs) 38 | m = wrn_22() 39 | learn = ConvLearner.from_model_data(m, data) 40 | learn.crit = nn.CrossEntropyLoss() 41 | learn.metrics = [accuracy] 42 | frozen, log_probs, to_sf = False, False, True 43 | return learn, frozen, log_probs, to_sf 44 | 45 | def get_opt_fn(opt_fn, mom, beta, eps, amsgrad): 46 | if opt_fn=='SGD': res = optim.SGD 47 | elif opt_fn=='RMSProp': res = optim.RMSprop if beta is None else partial(optim.RMSProp, alpha=beta) 48 | else: 49 | if beta is None: beta=0.999 50 | if eps is None: eps=1e-8 51 | res = partial(optim.Adam, betas=(mom,beta), eps=eps, amsgrad=amsgrad) 52 | return res 53 | 54 | def get_trn_phases(trn_met, lr, n_cyc, moms, opt_fn, cyc_len, cyc_mul, div, pct, wd, l2_reg, true_wd): 55 | if trn_met=='CAR': trn_phases = CAR_phases(lr, n_cyc, moms, opt_fn, cyc_len, cyc_mul, wd, l2_reg, true_wd) 56 | elif trn_met=='CLR': trn_phases = 
CLR_phases(lr, n_cyc, moms, opt_fn, cyc_len, div, wd, l2_reg, true_wd) 57 | else: trn_phases = OCY_phases(lr, moms, opt_fn, cyc_len, div, pct, wd, l2_reg, true_wd) 58 | return trn_phases 59 | 60 | def launch_training(lr, mom, bs=64, mom2=None, wd=0, trn_met='CAR', n_cyc=1, cyc_len=1, cyc_mul=1, div=10, pct = 0.1, 61 | opt_fn='Adam', beta=None, eps=None, eps2=None, true_wd=False, snap=False, swa=False, tta=False, amsgrad=False, cuda_id=0, name='', 62 | data='dogscats', freeze_first=None, div_diff_lr=None, l2_reg=True, init_text=''): 63 | assert trn_met in {'CAR', 'CLR', '1CY'}, 'trn_met should be CAR (Cos Anneal with restart), CLR (cyclical learning rates) or 1CY (1cycle)' 64 | assert opt_fn in {'SGD', 'RMSProp', 'Adam'}, 'optim should be SGD, RMSProp or Adam' 65 | torch.cuda.set_device(cuda_id) 66 | learn, frozen, log_probs, to_sf = get_data(data, bs) 67 | opt_fn = get_opt_fn(opt_fn, mom, beta, eps, amsgrad) 68 | learn.opt_fn = opt_fn 69 | moms = mom if mom2 is None else (mom,mom2) 70 | if freeze_first is None: freeze_first=frozen 71 | if freeze_first: 72 | if not isinstance(lr, tuple): lr = (lr, lr) 73 | if not isinstance(n_cyc,tuple): n_cyc = (n_cyc,n_cyc) 74 | if not isinstance(cyc_len,tuple): cyc_len = (cyc_len,cyc_len) 75 | if not isinstance(cyc_mul,tuple): cyc_mul = (cyc_mul,cyc_mul) 76 | trn_phases = get_trn_phases(trn_met, lr[0], n_cyc[0], moms, opt_fn, cyc_len[0], cyc_mul[0], div, pct, wd, l2_reg, true_wd) 77 | cbs = [LogResults(learn, f'logs_{name}_{cuda_id}.txt', init_text + '\n\nPhase1')] 78 | learn.fit_opt_sched(trn_phases, use_swa=swa, callbacks=cbs) 79 | learn.unfreeze() 80 | lr, n_cyc, cyc_len, cyc_mul = lr[1], n_cyc[1], cyc_len[1], cyc_mul[1] 81 | if div_diff_lr is None: 82 | div_diff_lr = 10 if data=='dogscats' else (3 if data=='planets' else 1) 83 | if div_diff_lr != 1: lrs = np.array([lr/(div_diff_lr**2), lr/div_diff_lr, lr]) 84 | else: lrs = lr 85 | trn_phases = get_trn_phases(trn_met, lrs, n_cyc, moms, opt_fn, cyc_len, cyc_mul, div, pct, wd, l2_reg, true_wd) 86 | nbs = [phase.epochs for phase in trn_phases] 87 | log_rec = LogResults(learn, f'logs_{name}_{cuda_id}.txt', '\nPhase2') 88 | cbs = [log_rec] 89 | if snap: cbs.append(SaveModel(learn, 'cycle')) 90 | if eps2: cbs.append(EpsScheduler(learn, sum(nbs), eps, eps2)) 91 | learn.fit_opt_sched(trn_phases, use_swa=swa, callbacks=cbs) 92 | #val_los, scale = validate_scale(learn) 93 | #print(f'Scaled loss: {val_los} scale: {scale}') 94 | #with open(f'logs_{name}_{cuda_id}.txt','a') as f: f.write('\n' + f'Scaled loss: {val_los} scale: {scale}') 95 | if tta or snap: 96 | probs, targs = get_probs(learn, n_cyc, tta, snap, log_probs, to_sf) 97 | acc = learn.metrics[0](V(probs), V(targs)) 98 | if log_probs: probs = np.log(probs) 99 | loss = learn.crit(V(probs), V(targs)).item() 100 | print(f'Final loss: {loss} Final metric: {acc}') 101 | with open(f'logs_{name}_{cuda_id}.txt','a') as f: f.write('\n' + f'Final loss: {loss} Final metric: {acc}') 102 | 103 | def get_probs(learn, n_cyc, tta, snap, logs, to_sf): 104 | if tta and not snap: 105 | preds, targs = learn.TTA() 106 | if logs: probs = np.exp(preds) 107 | elif to_sf: probs = np.exp(preds)/np.exp(preds).sum(2)[:,:,None] 108 | return np.mean(probs,0), targs 109 | probs = [] 110 | for i in range(1,n_cyc+1): 111 | learn.load('cycle' + str(i)) 112 | preds, targs = learn.predict_with_targs() if not tta else learn.TTA() 113 | if logs: preds = np.exp(preds) 114 | elif to_sf: preds = np.exp(preds)/np.exp(preds).sum(2)[:,:,None] if tta else np.exp(preds)/np.exp(preds).sum()[:,None] 
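# learn.TTA() returns predictions with the augmentation index on the first axis, so average over it before stacking this cycle's predictions.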
115 | if tta: preds = np.mean(preds,0) 116 | probs.append(preds[None]) 117 | probs = np.concatenate(probs, 0) 118 | return np.mean(probs, 0), targs 119 | 120 | def validate_scale(learn): 121 | batch_cnts = [] 122 | losses = np.zeros((100,len(learn.data.val_dl))) 123 | learn.model.eval() 124 | scales = np.linspace(0.5,1.5,100) 125 | with no_grad_context(): 126 | j=0 127 | for (*x,y) in iter(learn.data.val_dl): 128 | if isinstance(x,list): x = x[0] 129 | preds = learn.model(VV(x)) 130 | batch_cnts.append(len(x)) 131 | for i, scale in enumerate(scales): 132 | l = learn.crit(preds * scale, VV(y)) 133 | losses[i,j] = l 134 | j+=1 135 | final_losses = np.average(losses, 1, weights=batch_cnts) 136 | best_los = np.min(final_losses) 137 | idx = np.argmin(final_losses) 138 | return best_los, scales[idx] 139 | 140 | def train_model(lr, mom, bs=64, mom2=None, wd=0, trn_met='CAR', n_cyc=1, cyc_len=1, cyc_mul=1, div=10, pct = 0.1, opt_fn='Adam', 141 | beta=None, eps=None, eps2=None, true_wd=False, snap=False, swa=False, tta=False, amsgrad=False, cuda_id=0, name='', data='dogscats', 142 | freeze_first=None, div_diff_lr=None, l2_reg=True, nb_exp=5): 143 | if os.path.isfile(f'logs_{name}_{cuda_id}.txt'): 144 | os.remove(f'logs_{name}_{cuda_id}.txt') 145 | init_text = f'{name}_{cuda_id}\n' 146 | init_text = f'lr {lr}; max_mom {mom}; batch_size {bs}; min_mom {mom2}; weight_decay {wd} train_method {trn_met}; num_cycles {n_cyc}; ' 147 | init_text += f'cycle_len {cyc_len}; cycle_mult {cyc_mul}; lr_div {div}; pct_relax {pct}; optimizer {opt_fn}; beta {beta}; ' 148 | init_text += f'true_wd {true_wd}; snapshot_ensemble {snap}; use_swa {swa}; tta {tta}; amsgrad {amsgrad}; data {data}; ' 149 | init_text += f'freeze_first {freeze_first}' 150 | print(init_text) 151 | for i in range(0,nb_exp): 152 | print('\n' + f'Experiment {i+1}') 153 | launch_training(lr, mom, bs, mom2, wd, trn_met, n_cyc, cyc_len, cyc_mul, div, pct, opt_fn, 154 | beta, eps, eps2, true_wd, snap, swa, tta, amsgrad, cuda_id, name, data, freeze_first, div_diff_lr, l2_reg, init_text) 155 | 156 | if __name__ == '__main__': fire.Fire(train_model) 157 | 158 | -------------------------------------------------------------------------------- /train_rnn.py: -------------------------------------------------------------------------------- 1 | from fastai.text import * 2 | from utils import get_opt_fn, get_phases, log_msg 3 | from callbacks import * 4 | from lm_val_fns import * 5 | import fire 6 | 7 | EOS = '' 8 | PATH = Path('data/wikitext/') 9 | 10 | def read_file(filename): 11 | """ 12 | Reads the file in filemane and prepares the tokens. 13 | """ 14 | tokens = [] 15 | with open(PATH/filename) as f: 16 | for line in f: 17 | tokens.append(line.split() + [EOS]) 18 | return np.array(tokens) 19 | 20 | def main_train(lr, moms, wd, wd_loss, opt_fn, bs, bptt, drops, beta2, amsgrad, div, nbs, lin_end, clip, alpha, beta, qrnn, bias, fname): 21 | """ 22 | Trains a Language Model 23 | 24 | lr (float): maximum learning rate 25 | moms (float/tuple): value of the momentum/beta1. 
If tuple, cyclical momentums will be used 26 | wd (float): weight decay to be used 27 | wd_loss (bool): weight decay computed inside the loss if True (l2 reg) else outside (true wd) 28 | opt_fn (optimizer): name of the optim function to use (should be SGD, RMSProp or Adam) 29 | bs (int): batch size 30 | bptt (int): bptt parameter for the training 31 | drops (np.array of float): dropouts to use 32 | beta2 (float): beta2 parameter of Adam or alpha parameter of RMSProp 33 | amsgrad (bool): for Adam, sues amsgrad or not 34 | div (float): value to divide the maximum learning rate by 35 | nbs (list): number of epochs for each phase (ascending, constant if len==4, descending, annealing) 36 | lin_end (bool): if True, the annealing phase goes from the minimum lr to 1/100th of it linearly 37 | if False, uses a cosine annealing to 0 38 | clip (float): value of gradient clipping to use 39 | alpha (float): alpha parameter for the AR regularization function 40 | beta (float): beta parameter for the AR regularization function 41 | qrnn (bool): if True, will use QRNNs instead of LSTMs 42 | bias (bool): if True, the decoder in the LM has bias 43 | """ 44 | trn_tok = read_file('wiki.train.tokens') 45 | val_tok = read_file('wiki.valid.tokens') 46 | tst_tok = read_file('wiki.test.tokens') 47 | cnt = Counter(word for sent in trn_tok for word in sent) 48 | itos = [o for o,c in cnt.most_common()] 49 | itos.insert(0,'_pad_') 50 | vocab_size = len(itos) 51 | if qrnn: em_sz, nh, nl = 400, 1550, 4 52 | else: em_sz, nh, nl = 400, 1150, 3 53 | stoi = collections.defaultdict(lambda : 5, {w:i for i,w in enumerate(itos)}) 54 | trn_ids = np.array([([stoi[w] for w in s]) for s in trn_tok]) 55 | val_ids = np.array([([stoi[w] for w in s]) for s in val_tok]) 56 | tst_ids = np.array([([stoi[w] for w in s]) for s in tst_tok]) 57 | trn_dl = LanguageModelLoader(np.concatenate(trn_ids), bs, bptt) 58 | val_dl = LanguageModelLoader(np.concatenate(val_ids), bs, bptt) 59 | md = LanguageModelData(PATH, 0, vocab_size, trn_dl, val_dl, bs=bs, bptt=bptt) 60 | defaut_drops = np.array([0.6,0.4,0.5,0.1,0.2]) if not qrnn else np.array([0.4,0.4,0.1,0.1,0.2]) 61 | drops = defaut_drops if drops is None else np.array(list(drops)) 62 | mom = moms[0] if isinstance(moms, Iterable) else moms 63 | opt_fn = get_opt_fn(opt_fn, mom, beta2, amsgrad) 64 | learner= md.get_model(opt_fn, em_sz, nh, nl, dropouti=drops[0], dropout=drops[1], wdrop=drops[2], 65 | dropoute=drops[3], dropouth=drops[4], qrnn=qrnn, bias=bias) 66 | learner.metrics = [accuracy] 67 | learner.clip = clip 68 | learner.reg_fn = partial(seq2seq_reg, alpha=alpha, beta=beta) 69 | learner.unfreeze() 70 | phases = get_phases(lr, moms, opt_fn, div, list(nbs), wd, lin_end, wd_loss) 71 | learner.fit_opt_sched(phases, callbacks=[LogResults(learner, fname)]) 72 | val_los, val_pp = my_validate(learner.model, np.concatenate(val_ids)) 73 | log_msg(open(fname, 'a'), f'Validation loss: {val_los}, Validation perplexity: {val_pp}') 74 | tst_los, tst_pp = my_validate(learner.model, np.concatenate(tst_ids)) 75 | log_msg(open(fname, 'a'), f'Test loss: {tst_los}, Test perplexity: {tst_pp}') 76 | cache_vlos, cache_vpp = my_cache_pointer(learner.model, np.concatenate(val_ids), vocab_size) 77 | log_msg(open(fname, 'a'), f'Cache validation loss: {cache_vlos}, Cache validation perplexity: {cache_vpp}') 78 | cache_tlos, cache_tpp = my_cache_pointer(learner.model, np.concatenate(tst_ids), vocab_size) 79 | log_msg(open(fname, 'a'), f'Cache test loss: {cache_tlos}, Cache test perplexity: {cache_tpp}') 80 | 81 | def 
train_lm(lr, moms=(0.8,0.7), wd=1.2e-6, wd_loss=True, opt_fn='Adam', bs=100, bptt=70, drops=None, beta2=0.99, amsgrad=False, 82 | div=10, nbs=(7.5,37.5,37.5,7.5), lin_end=False, clip=0.12, alpha=2, beta=1, qrnn=False, bias=True, 83 | name='', cuda_id=0, nb_exp=1): 84 | """ 85 | Launches the trainings. 86 | 87 | See main_train for the description of all the arguments. 88 | name (string): name to be added to the log file 89 | cuda_id (int): index of the GPU to use 90 | nb_exp (int): number of experiments to run in a row 91 | """ 92 | torch.cuda.set_device(cuda_id) 93 | init_text = f'{name}_{cuda_id}' + '\n' 94 | init_text += f'lr: {lr}; moms: {moms}; wd: {wd}; wd_loss: {wd_loss}; opt_fn: {opt_fn}; bs: {bs}; bptt: {bptt}; drops: {drops};' 95 | init_text += f'beta2: {beta2}; amsgrad: {amsgrad}; div: {div}; nbs: {nbs}; lin_end: {lin_end}; clip: {clip}; alpha: {alpha}; beta: {beta}; ' 96 | init_text += f'qrnn: {qrnn}; bias: {bias}' 97 | fname = f'logs_{name}_{cuda_id}.txt' 98 | log_msg(open(fname, 'w'), init_text) 99 | for i in range(nb_exp): 100 | log_msg(open(fname, 'a'), '\n' + f'Experiment {i+1}') 101 | main_train(lr, moms, wd, wd_loss, opt_fn, bs, bptt, drops, beta2, amsgrad, div, nbs, lin_end, clip, alpha, beta, qrnn, bias, fname) 102 | 103 | if __name__ == '__main__': fire.Fire(train_lm) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from fastai.conv_learner import * 2 | 3 | def log_msg(file, msg): 4 | print(msg) 5 | file.write('\n' + msg) 6 | 7 | def get_opt_fn(opt_fn, mom, beta, amsgrad): 8 | """ 9 | Helper function to return a proper optim function from its name 10 | 11 | opt_fn (string): name of the optim function (should be SGD, RMSProp or Adam) 12 | mom (float): momentum to use (or beta1 in the case of Adam) 13 | beta (float): alpha parameter in RMSProp and beta2 in Adam 14 | amsgrad (bool): for Adam only, uses amsgrad or not 15 | """ 16 | assert opt_fn in {'SGD', 'RMSProp', 'Adam'}, 'optim should be SGD, RMSProp or Adam' 17 | if opt_fn=='SGD': res = optim.SGD 18 | elif opt_fn=='RMSProp': res = optim.RMSprop if beta is None else partial(optim.RMSprop, alpha=beta) 19 | else: res = partial(optim.Adam, amsgrad=amsgrad) if beta is None else partial(optim.Adam, betas=(mom,beta), amsgrad=amsgrad) 20 | return res 21 | 22 | def get_one_phase(nb, opt_fn, lr, lr_decay, moms, wd, wd_loss): 23 | """ 24 | Helper function to create one training phase. 25 | 26 | nb (int): number of epochs 27 | opt_fn (optimizer): the optim function to use 28 | lr (float/tuple): the learning rate(s) to use. If tuple, going from the first to the second 29 | lr_decay (DecayType): the decay type to go from lr1 to lr2 30 | moms (float/tuple): the momentum(s) to use.
If tuple, going from the first to the second linearly 31 | wd (float): weight decay 32 | wd_loss (bool): weight decay computed inside the loss if True (l2 reg) else outside (true wd) 33 | """ 34 | if isinstance(moms, Iterable): 35 | return TrainingPhase(nb, opt_fn, lr=lr, lr_decay=lr_decay, momentum=moms, 36 | momentum_decay=DecayType.LINEAR, wds=wd, wd_loss=wd_loss) 37 | else: 38 | return TrainingPhase(nb, opt_fn, lr=lr, lr_decay=lr_decay, momentum=moms, 39 | wds=wd, wd_loss=wd_loss) 40 | 41 | def get_phases(lr, moms, opt_fn, div, nbs, wd, lin_end=False, wd_loss=True): 42 | """ 43 | Creates the phases for a 1cycle policy (or a variant) 44 | 45 | lr (float): maximum learning rate 46 | moms (float/tuple): value of the momentum/beta1. If tuple, cyclical momentums will be used 47 | opt_fn (optimizer): the optim function to use 48 | div (float): value to divide the maximum learning rate by 49 | nbs (list): number of epochs for each phase (ascending, constant if len==4, descending, annealing) 50 | wd (float): weight decay 51 | lin_end (bool): if True, the annealing phase goes from the minimum lr to 1/100th of it linearly 52 | if False, uses a cosine annealing to 0 53 | wd_loss (bool): weight decay computed inside the loss if True (l2 reg) else outside (true wd) 54 | """ 55 | max_mom = moms[0] if isinstance(moms, Iterable) else moms 56 | min_mom = moms[1] if isinstance(moms, Iterable) else moms 57 | moms_r = (moms[1],moms[0]) if isinstance(moms, Iterable) else moms 58 | phases = [get_one_phase(nbs[0], opt_fn, (lr/div,lr), DecayType.LINEAR, moms, wd, wd_loss)] 59 | if len(nbs)==4: 60 | phases.append(get_one_phase(nbs[1], opt_fn, lr, DecayType.NO, min_mom, wd, wd_loss)) 61 | nbs = [nbs[0]] + nbs[2:] 62 | phases.append(get_one_phase(nbs[1], opt_fn, (lr,lr/div), DecayType.LINEAR, moms_r, wd, wd_loss)) 63 | if lin_end: 64 | phases.append(get_one_phase(nbs[2], opt_fn, (lr/div,lr/(100*div)), DecayType.LINEAR, max_mom, wd, wd_loss)) 65 | else: 66 | phases.append(get_one_phase(nbs[2], opt_fn, lr/div, DecayType.COSINE, max_mom, wd, wd_loss)) 67 | return phases --------------------------------------------------------------------------------
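
As a usage reference, here is a minimal sketch (not part of the original repo) of how these helpers are combined, mirroring the flow of train_cifar10.py with the README's first CIFAR-10 command; the fastai learner construction is elided:

```
from utils import get_opt_fn, get_phases

# Adam with beta1 = 0.95 (the maximum momentum), beta2 = 0.99, no amsgrad
opt_fn = get_opt_fn('Adam', mom=0.95, beta=0.99, amsgrad=False)

# 30-epoch 1cycle schedule with 7.5% of the budget kept for the final annealing:
# 13.875 epochs warm-up, 13.875 epochs cool-down, 2.25 epochs of annealing
cyc_len, pct, lr, div, wd = 30, 0.075, 3e-3, 10, 0.1
nbs = [cyc_len * (1 - pct) / 2, cyc_len * (1 - pct) / 2, cyc_len * pct]
phases = get_phases(lr, (0.95, 0.85), opt_fn, div, nbs, wd, lin_end=True, wd_loss=False)

# With a fastai learner built as in train_cifar10.py:
# learn.opt_fn = opt_fn
# learn.fit_opt_sched(phases)
```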