├── .gitignore ├── README.md ├── curves ├── imdb-curves │ ├── acclip-imdb-lstm │ ├── adam-imdb-lstm │ └── sgd-imdb-lstm ├── mrpc-curves │ ├── acclip-mrpc-bert │ ├── adam-mrpc-bert │ └── sgd-mrpc-bert ├── squad-curves │ ├── acclip-squad-bert │ └── adam-squad-bert └── sst-curves │ ├── acclip-sst-lstm │ ├── adam-sst-lstm │ └── sgd-sst-lstm ├── download_mrpc_data.py ├── models └── lstm_attn.py ├── noise_plot.ipynb ├── optimizers └── ACClip.py ├── plot.ipynb ├── run_mrpc.py ├── run_mrpc.sh ├── run_text_classifier.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ACClip-Pytorch 2 | The PyTorch implementation of the ICLR 2020 submission: Why ADAM Beats SGD for Attention Models (https://openreview.net/pdf?id=SJx37TEtDH) 3 | 4 | GitHub Repo: https://github.com/rivercold/ACClip-Pytorch 5 | 6 | ## Requirements 7 | We strongly recommend using an Anaconda environment.
(https://www.anaconda.com/distribution/) 8 | * Python >= 3.6 9 | * PyTorch >= 1.0 (https://pytorch.org/get-started/locally/) 10 | * torchtext >= 0.4.0 (conda install -c pytorch torchtext) 11 | * numpy >= 1.16.4 12 | * matplotlib >= 3.1.0 13 | * transformers >= 2.1.1 (pip install transformers) 14 | 15 | ## Core Implementations 16 | `ACClip` optimizer: https://github.com/rivercold/ACClip-Pytorch/blob/master/optimizers/ACClip.py 17 | `print_noise` function for LSTM: https://github.com/rivercold/ACClip-Pytorch/blob/master/run_text_classifier.py#L22 18 | `print_noise` function for BERT: https://github.com/rivercold/ACClip-Pytorch/blob/master/run_mrpc.py#L224 19 | 20 | ## Setup 21 | 22 | ### Text Classification 23 | 24 | #### IMDB 25 | ```shell script 26 | $ python run_text_classifier.py --optimizer=acclip --lr=0.001 --epoch=30 27 | ``` 28 | The optimizer can be chosen from `acclip`, `adam` and `sgd`. 29 | 30 | ##### Plot Noise 31 | ```shell script 32 | $ python run_text_classifier.py --optimizer=sgd --lr=0.1 --epoch=30 --mode=plot 33 | ``` 34 | 35 | #### SST 36 | ```shell script 37 | $ python run_text_classifier.py --optimizer=acclip --lr=0.001 --epoch=20 --dataset=sst 38 | ``` 39 | 40 | ### GLUE task 41 | * Download the datasets 42 | ```shell script 43 | $ python download_mrpc_data.py 44 | ``` 45 | 46 | * Run the BERT model; the optimizer and learning rate can be set in ```run_mrpc.sh``` 47 | ```shell script 48 | $ sh ./run_mrpc.sh 49 | ``` 50 | 51 | * To plot noise, 52 | replace `do_train` with `do_plot` in `./run_mrpc.sh` 53 | 54 | ## Plot for visualization 55 | For training, the evaluation results will be written to the `curves` folder. 56 | Install Jupyter notebook to run ```plot.ipynb```. 57 | 58 | When plotting the noise norm, the results are written to the `noises` folder.
59 | Run ```noise_plot.ipynb``` 60 | -------------------------------------------------------------------------------- /curves/imdb-curves/acclip-imdb-lstm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/imdb-curves/acclip-imdb-lstm -------------------------------------------------------------------------------- /curves/imdb-curves/adam-imdb-lstm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/imdb-curves/adam-imdb-lstm -------------------------------------------------------------------------------- /curves/imdb-curves/sgd-imdb-lstm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/imdb-curves/sgd-imdb-lstm -------------------------------------------------------------------------------- /curves/mrpc-curves/acclip-mrpc-bert: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/mrpc-curves/acclip-mrpc-bert -------------------------------------------------------------------------------- /curves/mrpc-curves/adam-mrpc-bert: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/mrpc-curves/adam-mrpc-bert -------------------------------------------------------------------------------- /curves/mrpc-curves/sgd-mrpc-bert: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/mrpc-curves/sgd-mrpc-bert -------------------------------------------------------------------------------- /curves/squad-curves/acclip-squad-bert: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/squad-curves/acclip-squad-bert -------------------------------------------------------------------------------- /curves/squad-curves/adam-squad-bert: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/squad-curves/adam-squad-bert -------------------------------------------------------------------------------- /curves/sst-curves/acclip-sst-lstm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/sst-curves/acclip-sst-lstm -------------------------------------------------------------------------------- /curves/sst-curves/adam-sst-lstm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/sst-curves/adam-sst-lstm -------------------------------------------------------------------------------- /curves/sst-curves/sgd-sst-lstm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rivercold/ACClip-Pytorch/8141bc87b1493e31c0a29999db2e315a4787721f/curves/sst-curves/sgd-sst-lstm -------------------------------------------------------------------------------- /download_mrpc_data.py: -------------------------------------------------------------------------------- 1 | ''' Script for 
downloading all GLUE data. 2 | Note: for legal reasons, we are unable to host MRPC. 3 | You can either use the version hosted by the SentEval team, which is already tokenized, 4 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 5 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 6 | You should then rename and place specific files in a folder (see below for an example). 7 | mkdir MRPC 8 | cabextract MSRParaphraseCorpus.msi -d MRPC 9 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 10 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 11 | rm MRPC/_* 12 | rm MSRParaphraseCorpus.msi 13 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now. 14 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray! 
15 | ''' 16 | 17 | import os 18 | import sys 19 | import shutil 20 | import argparse 21 | import tempfile 22 | import urllib.request 23 | import zipfile 24 | # TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 25 | TASKS = ["MRPC"] 26 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 27 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 28 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 29 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 30 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 31 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 32 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 33 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 34 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 35 | 
"WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 36 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 37 | 38 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 39 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 40 | 41 | def download_and_extract(task, data_dir): 42 | print("Downloading and extracting %s..." 
% task) 43 | data_file = "%s.zip" % task 44 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 45 | with zipfile.ZipFile(data_file) as zip_ref: 46 | zip_ref.extractall(data_dir) 47 | os.remove(data_file) 48 | print("\tCompleted!") 49 | 50 | def format_mrpc(data_dir, path_to_data): 51 | print("Processing MRPC...") 52 | mrpc_dir = os.path.join(data_dir, "MRPC") 53 | if not os.path.isdir(mrpc_dir): 54 | os.mkdir(mrpc_dir) 55 | if path_to_data: 56 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 57 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 58 | else: 59 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 60 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 61 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 62 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 63 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 64 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 65 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 66 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 67 | 68 | dev_ids = [] 69 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 70 | for row in ids_fh: 71 | dev_ids.append(row.strip().split('\t')) 72 | 73 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 74 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 75 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 76 | header = data_fh.readline() 77 | train_fh.write(header) 78 | dev_fh.write(header) 79 | for row in data_fh: 80 | label, id1, id2, s1, s2 = row.strip().split('\t') 81 | if [id1, id2] in dev_ids: 82 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 83 | else: 84 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 85 | 86 
| with open(mrpc_test_file, encoding="utf8") as data_fh, \ 87 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 88 | header = data_fh.readline() 89 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 90 | for idx, row in enumerate(data_fh): 91 | label, id1, id2, s1, s2 = row.strip().split('\t') 92 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 93 | print("\tCompleted!") 94 | 95 | def download_diagnostic(data_dir): 96 | print("Downloading and extracting diagnostic...") 97 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 98 | os.mkdir(os.path.join(data_dir, "diagnostic")) 99 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 100 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 101 | print("\tCompleted!") 102 | return 103 | 104 | def get_tasks(task_names): 105 | task_names = task_names.split(',') 106 | if "all" in task_names: 107 | tasks = TASKS 108 | else: 109 | tasks = [] 110 | for task_name in task_names: 111 | assert task_name in TASKS, "Task %s not found!" 
% task_name 112 | tasks.append(task_name) 113 | return tasks 114 | 115 | def main(arguments): 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 118 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 119 | type=str, default='all') 120 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt', 121 | type=str, default='') 122 | args = parser.parse_args(arguments) 123 | 124 | if not os.path.isdir(args.data_dir): 125 | os.mkdir(args.data_dir) 126 | tasks = get_tasks(args.tasks) 127 | 128 | for task in tasks: 129 | if task == 'MRPC': 130 | format_mrpc(args.data_dir, args.path_to_mrpc) 131 | elif task == 'diagnostic': 132 | download_diagnostic(args.data_dir) 133 | else: 134 | download_and_extract(task, args.data_dir) 135 | 136 | 137 | if __name__ == '__main__': 138 | sys.exit(main(sys.argv[1:])) -------------------------------------------------------------------------------- /models/lstm_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn import functional as F 5 | 6 | class LSTM_Attn(torch.nn.Module): 7 | def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights): 8 | super(LSTM_Attn, self).__init__() 9 | 10 | """ 11 | Arguments 12 | --------- 13 | batch_size : Size of the batch which is same as the batch_size of the data returned by the TorchText BucketIterator 14 | output_size : 2 = (pos, neg) 15 | hidden_size : Size of the hidden_state of the LSTM 16 | vocab_size : Size of the vocabulary containing unique words 17 | embedding_length : Embedding dimension of GloVe word embeddings 18 | weights : Pre-trained GloVe word_embeddings which we will use to create our
word_embedding look-up table 19 | 20 | -------- 21 | """ 22 | self.batch_size = batch_size 23 | self.output_size = output_size 24 | self.hidden_size = hidden_size 25 | self.vocab_size = vocab_size 26 | self.embedding_length = embedding_length 27 | 28 | self.word_embeddings = nn.Embedding(vocab_size, embedding_length) 29 | self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False) # nn.Embedding reads its table from `.weight`; assigning to `.weights` left the GloVe vectors unused 30 | self.lstm = nn.LSTM(embedding_length, hidden_size) 31 | self.classify = nn.Linear(hidden_size, output_size) 32 | 33 | # self.attn_fc_layer = nn.Linear() 34 | 35 | def attention_net(self, lstm_output, final_state): 36 | 37 | """ 38 | Arguments 39 | --------- 40 | 41 | lstm_output : Final output of the LSTM which contains hidden layer outputs for each sequence. 42 | final_state : Final time-step hidden state (h_n) of the LSTM 43 | --------- 44 | 45 | Returns : It performs the attention mechanism by first computing weights for each of the sequences present in lstm_output and then computing the 46 | new hidden state. 47 | 48 | Tensor Size : 49 | hidden.size() = (batch_size, hidden_size) 50 | attn_weights.size() = (batch_size, num_seq) 51 | soft_attn_weights.size() = (batch_size, num_seq) 52 | new_hidden_state.size() = (batch_size, hidden_size) 53 | """ 54 | 55 | hidden = final_state.squeeze(0) 56 | attn_weights = torch.bmm(lstm_output, hidden.unsqueeze(2)).squeeze(2) 57 | soft_attn_weights = F.softmax(attn_weights, 1) 58 | new_hidden_state = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2) 59 | 60 | return new_hidden_state 61 | 62 | def forward(self, input_sentences): 63 | 64 | """ 65 | Parameters 66 | ---------- 67 | input_sentences: input_sentences of shape = (batch_size, num_sequences) 68 | batch_size : default = None.
Used only for prediction on a single sentence after training (batch_size = 1) 69 | 70 | Returns 71 | ------- 72 | Output of the linear layer containing logits for pos & neg class which receives its input as the new_hidden_state which is basically the output of the Attention network. 73 | final_output.shape = (batch_size, output_size) 74 | 75 | """ 76 | input = self.word_embeddings(input_sentences) 77 | num_ins = input.size(0) 78 | input = input.permute(1, 0, 2) 79 | 80 | if torch.cuda.is_available(): 81 | h_0 = Variable(torch.zeros(1, num_ins, self.hidden_size)).cuda() 82 | c_0 = Variable(torch.zeros(1, num_ins, self.hidden_size)).cuda() 83 | else: 84 | h_0 = Variable(torch.zeros(1, num_ins, self.hidden_size)) 85 | c_0 = Variable(torch.zeros(1, num_ins, self.hidden_size)) 86 | 87 | output, (final_hidden_state, final_cell_state) = self.lstm(input, ( 88 | h_0, c_0)) # final_hidden_state.size() = (1, batch_size, hidden_size) 89 | output = output.permute(1, 0, 2) # output.size() = (batch_size, num_seq, hidden_size) 90 | 91 | attn_output = self.attention_net(output, final_hidden_state) 92 | logits = self.classify(attn_output) 93 | 94 | return logits -------------------------------------------------------------------------------- /noise_plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "** The Plot script for all optimizers and comparisons **" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 151, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import matplotlib.mlab as mlab\n", 18 | "import matplotlib.pyplot as plt" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 152, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import pickle\n", 28 | "def read_pickle(path):\n", 29 | " with open(path, \"rb\") as f:\n", 30 | " d = 
pickle.load(f)\n", 31 | " return d" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 370, 37 | "metadata": { 38 | "scrolled": false 39 | }, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "20\n", 46 | "1\n" 47 | ] 48 | }, 49 | { 50 | "data": { 51 | "image/png": "\n", 52 | "text/plain": [ 53 | "
" 54 | ] 55 | }, 56 | "metadata": { 57 | "needs_background": "light" 58 | }, 59 | "output_type": "display_data" 60 | }, 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "10\n" 66 | ] 67 | }, 68 | { 69 | "data": { 70 | "image/png": "\n", 71 | "text/plain": [ 72 | "
" 73 | ] 74 | }, 75 | "metadata": { 76 | "needs_background": "light" 77 | }, 78 | "output_type": "display_data" 79 | }, 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "19\n" 85 | ] 86 | }, 87 | { 88 | "data": { 89 | "image/png": "\n", 90 | "text/plain": [ 91 | "
" 92 | ] 93 | }, 94 | "metadata": { 95 | "needs_background": "light" 96 | }, 97 | "output_type": "display_data" 98 | } 99 | ], 100 | "source": [ 101 | "# lstm plot\n", 102 | "d = read_pickle(\"./sgd-imdb-lstm_noise\")\n", 103 | "print(len(d))\n", 104 | "for i in [1,10,19]:#[3,9,15,19]:\n", 105 | " noises=d[i]\n", 106 | " num_bins = len(noises)\n", 107 | " plt.yticks(np.linspace(0,3.5,6))\n", 108 | " plt.xticks(np.linspace(0,3,4))\n", 109 | " plt.xlim(0,3)\n", 110 | " plt.ylim(0,3.5)\n", 111 | " n, bins, patches = plt.hist(noises, 500, density=1, facecolor='blue')\n", 112 | " plt.ylabel('Density (%)')\n", 113 | " plt.xlabel('Noise norm')\n", 114 | " plt.xticks(fontsize=20)\n", 115 | " plt.yticks(fontsize=20)\n", 116 | " plt.rc('axes', labelsize=20)\n", 117 | " #plt.savefig('text.png')\n", 118 | " plt.tight_layout()\n", 119 | " print(i)\n", 120 | " \n", 121 | " plt.savefig('lstm'+str(i)+'.png',dpi=300)\n", 122 | " plt.show()" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 337, 128 | "metadata": { 129 | "scrolled": false 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "image/png": "\n", 135 | "text/plain": [ 136 | "
" 137 | ] 138 | }, 139 | "metadata": { 140 | "needs_background": "light" 141 | }, 142 | "output_type": "display_data" 143 | }, 144 | { 145 | "data": { 146 | "image/png": "\n", 147 | "text/plain": [ 148 | "
" 149 | ] 150 | }, 151 | "metadata": { 152 | "needs_background": "light" 153 | }, 154 | "output_type": "display_data" 155 | }, 156 | { 157 | "data": { 158 | "image/png": "\n", 159 | "text/plain": [ 160 | "
" 161 | ] 162 | }, 163 | "metadata": { 164 | "needs_background": "light" 165 | }, 166 | "output_type": "display_data" 167 | }, 168 | { 169 | "data": { 170 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deZxcVZn/8c83LCIBEpYwOGIIMISIC8u0bBHIopkMyjIs83McUUAGGcCgwoyMOJIw4y6LQRFhBERUZEBRR8UoQdAAahAUMSQQ6ACGJRDZk7Dk+f1xTiWXSlVX3erq7krq+3696nXTdznn1CXph3PuuedRRGBmZtZphg11A8zMzGpxgDIzs47kAGVmZh3JAcrMzDqSA5SZmXUkBygzM+tI6w91A9YVW221VYwZM2aom2Fmtta57bbbHo+IUdX7OyZASdoWOAuYCmwJPAxcC8yIiL+0WOb+wA2knuInI+Ljdc7bF/g4sDewEXAvcAlwfkS83ExdY8aMYe7cua0008ysq0laVGt/RwzxSdoRuA04BvgNcC5wH3AKcIukLVsoc1Pg68DzDc47BLgJ2B/4HvBlYMPchivL1mtmZu3REQEKuADYGpgWEYdGxOkRMYkUJHYGPtlCmV8ERgCfrneCpM2Ai4GXgQkR8f6I+DdgN+AW4AhJ72qhbjMz66chD1CSdgCmAL2k3kvRmcBzwFGShpco8xBSb2wasLiPU48ARgFXRsSq8bmIWE4a8gP412brNTOz9hnyAAVMyttZEbGyeCAingHmABuTng81JGlrUq/o2oi4osm6r6tx7CbS8OC+kl7VTN1mZtY+nRCgds7bBXWO35O3Y5ss7yLS9zqhP3VHxEvA/aSJJDs0WbeZmbVJJwSoEXn7VJ3jlf0jGxUk6VjgEODEiHh0oOuWdLykuZLmLlmypInqzMysWZ0QoBpR3vaZF0TSGOA84H8j4qrBqDsiLoqInojoGTVqjSn8ZmbWD50QoCq9lBF1jm9WdV49lwDLgBOHoG4zM2uzTghQ8/O23jOmnfK23jOqij1IU9WXSIrKB7g0Hz8j77u2mbolrQ9sD7xEeifLzMwGUSesJHFD3k6RNKw4ky+/bDue1DO6tUE5l5Nm+1XbifQS7h2kl4FvLxybDfwzafWKb1ddt38u76aIWNHcV2kvafWfnfjYzLrNkAeoiFgoaRbpXaiTgPMLh2cAw4GvRsRzlZ2SxuVr7y6UM61W+ZKOJgWbH9VY6uhq4LPAuySdX3kXStJGwH/nc77S+rczM7NWDXmAyk4EbgZmSpoMzAP2AiaShvbOqDp/Xt6KfoiIpyX9CylQ/ULSlcBS4GDSFPSrge/0pw4zM2tNJzyDIiIWAj3AZaTAdCqwIzAT2CcinhjAuq8FDiC9mHs48EHgReAjwLsiPLhmZjYU5N+/7dHT0xPtXs3cz6DMrBtIui0ieqr3d0QPyszMrJoDlJmZdSQHKDMz60gOUGZm1pEcoMzMrCM5QJmZWUdygDIzs47kAGVmZh3JAcrMzDqSA5SZmXUkBygzM+tIHROgJG0r6RJJiyWtkNQr6TxJm5co498k/Thf+6ykpyXdKekcSdvWuSb6+DTKQWVmZgOkI9JtSNqRlG5ja+D7wN3AnsApwFRJ45tc0fwDwLPAjcCjwAbA7sCHgfdLmhARt9e4bhFpJfVqD5X8KmZm1iYdEaCAC0jBaVpErEpYKOkcUnD5JHBCE+W8MSKWV+/MOZ8uyuUcWOO63oiY3kK7zcxsgAz5EJ+kHUjZdHuBL1cdPhN4DjhK0vBGZdUKTtlVebtTi800M
gV2ZnV220Y+C7yRlKzwoEYnS1oPuBTYGDgiIt4dER8lLXp7DSnd/IfLfhkzM+u/RgFqM+DJEuU9SZo40bQ8828Kq1PJF51JSvFxVK3U7NUi4paIuCsimk35cQDweuCmYnLCiFgJ/Hv+8QRJNfNPmZnZwGkUoDakXH6nlfmaMibl7awcGFaJiGeAOaQezt4lyy1T93XVByLiPlK6+e0AT583MxtkzbyoO9DrDuyctwvqHL8nb8d2Wt2Sjpc0V9LcJUsaPqYzM7MSmglQ0yW93MyHNLGgrBF5+1Sd45X9I1soe0DrjoiLIqInInpGjRrV9saZmXWzZtbiK/v8pd09rkr9Q7GC3FDW3W9+cmZma7M+A1RElMm426pKL2VEneObVZ23rtRtZmZ9GIwA1Mj8vK33jGmnvK33nGhtrdvMzPrQCQHqhrydktf+W0XSpqR3kZYBtw5A3bPzdmr1gTz9fSywiLSCu5mZDaIhD1ARsRCYRVps9qSqwzOA4cDlxVTyksZJGteG6m8E5gH7Szq4UP4w0ku/ABf2tcyTmZkNjJYy6g6AE0lLHc2UNJkUNPYCJpKG186oOn9e3r5iGoCktwLH5R83ydudJF1WOSciji78+WVJx5B6UldLuhp4AJhMWiR3Dmk1CzMzG2QdEaAiYqGkHuAs0nDbgcDDwExgRkQsbbKovwHeV7Vv66p9R1fV/WtJbyH11qaQVsJYlNvyGa/DZ2Y2NOTRq/bo6emJuXPn9rucwZga7v/kZtZJJN0WET3V+4f8GZSZmVktDlBmZtaRHKDMzKwjOUCZmVlH6ohZfDa4ihMxPGHCzDqVe1BmZtaRHKDMzKwjOUCZmVlHcoAyM7OO5ABlZmYdyQHKzMw6UscEKEnbSrpE0mJJKyT1SjpP0uYly9kiX9eby1mcy922zvm9kqLO55H2fDszMyurI96DkrQjKd3G1sD3gbuBPYFTgKmSxkfEE02Us2UuZywphcaVwDjgGOAdkvaJiFrJB58Czqux/9kWvs5axe9EmVmn6ogABVxACk7TIuL8yk5J5wAfBj4JnNBEOZ8iBadzI+IjhXKmAV/M9ayRPRd4MiKmt9x6MzNruyFPt5FTqy8EeoEdI2Jl4dimpLxQArYuZtWtUc5wYAmwEnhNRDxTODYs1zEm13Ff4VgvQESM6c/3WJvSbdTjHpSZDYVOTrcxKW9nFYMTQA4yc4CNgb0blLMP8GpgTjE45XJWktLKQ8rSW+1Vkt4j6WOSTpE0UdJ6Zb+ImZm1TycM8e2ctwvqHL+HlOl2LPDmyN0AABAySURBVHB9P8shl1NtG+AbVfvul3RMRNzYR51dx8+szGywdEIPakTePlXneGX/yAEq51JgMilIDQfeBHyVNBz4E0m71qtQ0vGS5kqau2TJkgbNW3tJqz9mZoOlEwJUI5Vfi/39//Wa5UTEjIiYHRGPRsTzEfHHiDgBOIc0ZDi9XoERcVFE9EREz6hRo/rZvKHnQGRmnaQTAlSlZzOizvHNqs4b6HIqLszb/Zs838zM2qgTnkHNz9taz4YAdsrbes+W2l1OxWN5O7zJ89cp7kWZ2VDrhB7UDXk7JU8HXyVPMx8PLANubVDOrfm88fm6YjnDSBMtivU1sk/e1nqx18zMBtiQB6iIWEiaAj4GOKnq8AxSD+by4jtQksZJGldVzrOkmXjDWfO50cm5/J9WvQP1BklbVLdJ0nbAl/KPV5T+UmZm1m+dMMQHcCJpiaKZkiYD84C9SO8sLQDOqDp/Xt5WD0R9DJgAfETSbsBvgNcDh5CG7KoD4JHA6ZJuAO4HngF2BN4BbAT8GPhCP7+bmZm1oCMCVEQslNQDnEVaiuhA0goSM4EZEbG0yXKekLQPcCZwKLAf8ARpKvknIuKhqktuIL0/tTtpSG848CTwK1Jv7Bsx1EttmJl1qSFf6mhdsS4sdVSW/+qYWTt08lJHZmZma3CAMjOzjtQRz6Bs3eW1+8ysVQ5Q1hYORGbWbh7iM
zOzjuQelLVsbZpxaGZrHwcoG3IeHjSzWjzEZ2ZmHck9qA6wrg2VrWvfx8yGhgOUrTU8FGjWXTzEZ2ZmHaljApSkbSVdImmxpBWSeiWdJ2nzkuVska/rzeUszuVuO9B1W9+cUt7MyuiIIT5JO5LSbWwNfB+4G9gTOAWYKml8RDzRRDlb5nLGArOBK4FxwDHAOyTtU8wH1c66rT0GahhvoIcHPfxo1n6d0oO6gBQgpkXEoRFxekRMAs4lpcP4ZJPlfIoUnM6NiMm5nENJwWbrXM9A1W0lNNObKp5TfV71sTI9M/fkzNYOQ55uQ9IOwEKgF9gxIlYWjm1KygslYOtiVt0a5QwHlgArgddExDOFY8NyHWNyHfe1s27oX7oN/6Jsn+Jf57L3tT//FNyDMmtdJ6fbmJS3s4oBAiAHmTnAxsDeDcrZB3g1MKcYnHI5K0lp5SFl6W133dYhBrN3NBg9N/f2rJt1QoDaOW8X1Dl+T96OHYBy2lW3rQPqDRuWHU5sZfiyXUOW/QloA92GZodrbWC1614Pxn+zTpgkMSJvn6pzvLJ/5ACU06+6JR0PHJ9/fFbS/AZtrNgKeLzJc7vVOnGP2vmPt0ZZfd6jdtXdn3LqXdsoeLfROvH3aKBI7bk/bfhvtl2tnZ0QoBqpfPX+juy3Uk6f10TERcBFpRsiza013mqr+R415nvUmO9R3zr9/nTCEF+llzKizvHNqs5rZzntqtvMzNqsEwJUZVis3nOenfK23nOi/pTTrrrNzKzNOiFA3ZC3U/J08FXyVO/xwDLg1gbl3JrPG5+vK5YzDJhSVV876y6r9LBgF/I9asz3qDHfo7519P0Z8gAVEQtJU8DHACdVHZ4BDAcuL76HJGmcpHFV5TwLfCOfP72qnJNz+T8triTRSt3tkJ9dWR98jxrzPWrM96hvnX5/hvxFXai53NA8YC/SO0sLgH2Lyw1JCoCIUFU51Usd/QZ4PXAI8FguZ2F/6jYzs8HREQEKQNLrgLOAqcCWpFUcrgVmRMTSqnNrBqh8bAvgTOBQ4DXAE8BPgE9ExEP9rdvMzAZHxwQoMzOzoiF/BtUtnNIDJB0h6XxJv5T0tKSQdEWDa/aV9GNJSyU9L+kPkj4kab3BavdgkbSlpOMkfU/SvZKWSXpK0q8kvb96Ik/huq65RwCSPivpekkP5nu0VNLtks7Mw/y1rumqe1SLpKPyv7mQdFydc94p6Rf5792zkn4t6X2D3dZV7XEPauD1kdJjImmqe1ek9JB0B7Ar8CzwECkVyjcj4j11zj8EuAZYDnwHWAocRFqi6uqIOHIw2j1YJJ0AfIU0xHwD8ADwV8BhpHf1rgGOjMI/2m67RwCSXgB+B/yJ9Gx5OGm9zB5gMbB3RDxYOL/r7lG1/BjjTmA9YBPgXyLif6rOORk4n/RY5DvAC8ARwLbA2RFx2qA2GiAi/BngD/BT0moUH6zaf07ef+FQt3GQ7sNE0rtlAibk735FnXM3I/3yWQH0FPZvRAr2AbxrqL9Tm+/PJNIvzmFV+7chBasADu/me1T5fnX2fzJ/5wu6/R5V3RcBPydlbvh8/s7HVZ0zhhTAnwDGFPZvDtybr9lnsNvuIb4BllN6TCGl9Phy1eEzgeeAo3K6kHVaRNwQEfdE/pvfwBHAKODKiFiVxyQilgMfzz/+6wA0c8hExOyI+GGsubL+I8CF+ccJhUNdd49g1fer5aq83amwryvvUZVppP/5OYb0+6aWY4FXAV+KiN7Kzoj4CynPHsAJA9jGmhygBp5TerSmct+uq3HsJuB5YF9Jrxq8Jg2pF/P2pcI+36NXOihv/1DY19X3SNLrgc8AX4yIm/o4ta/79JOqcwaNA9TAc0qP1tS9bxHxEnA/abHjHQazUUNB0vrAe/OPxV8gXX2PJJ0mabqkcyX9EvgvUnD6TOG0rr1H+e/NN0jDwx9rcHpf9+lhUs9rW0kbt7WRDawNq5mv7dqVTqTb+L6t9hngjcCPI+Knhf3df
o9OI00iqbgOODoilhT2dfM9+gSwO/DWiFjW4Nxm7tPwfN7z7WleY+5BDb12pRPpNl1x3yRNA04lzfw8quzlebtO3qOI2CbSy/rbkGY67gDcLmmPEsWsk/dI0p6kXtPZEXFLO4rM20G9Tw5QA88pPVrT9fdN0knAF0nTqSfGmquadP09AoiIRyPie6TJSFsClxcOd909KgztLQD+s8nLmr1PT/ejaaU5QA08p/RoTd37lv8Bbk+aMHBf9fF1gaQPAV8C/kgKTo/UOK2r71G1iFhECuZvkLRV3t2N92gT0vd9PbC88HJukGYOA1yc952Xf+7rPr2GNLz3UEQM2vAeOEANhqFK6bG2m523U2sc25808/HmiFgxeE0aHJI+CpwL3EEKTo/VObVr71Ef/jpvX87bbrxHK4Cv1fncns/5Vf65MvzX1336+6pzBs9Qv0TWDR/8om6tezKBxi/qLqHLXrAkDckEMBfYosG5XXePSKuPbFNj/zBWv6g7p5vvUYP7N53aL+puTwe+qOuljgaBU3okkg4lrTIP6cH235GGVn6Z9z0eheVU8vlXk/7hXElaouZg8hI1wD/GOvQXOK95dhnp//7Pp/Zzkd6IuKxwTbfdow+RVkO4ibQywhOkmXwHkCZJPAJMjog/Fa7pqnvUF0nTScN8tZY6+iAwEy911H0f4HXApaR11l4AFpEegPf5f8nr0ofV//dW79Nb45rxwI+Bv5CGQu8EPgysN9TfZwjuTwC/6PJ79EbSiix3AI+Tnh89Bfw237+a/5666R41+XfsuDrHDwJuBJ4hvfv0W+B9Q9Ve96DMzKwjeZKEmZl1JAcoMzPrSA5QZmbWkRygzMysIzlAmZlZR3KAMjOzjuQAZWZmHckBymyQSeqV1DvU7TDrdA5QZkBhxedFkjaqc05vPseJPs0GgQOU2SuNBj40wHVMzh8z64OXOjIj9aBI67QFsD6wY0Q8XnVOL7AdsEFEvDTojTTrMu5Bma32PPBfpBQNZzY49xUk/aOkmyQ9JWmZpDsl/YekV9U4d41nUJI2lDRN0u8k/UXS8/m870t6W40yxkm6TNKDklZIelTStyTtXKLNE/KQ5XRJu0n6kaQnc903Stq3znUjJH1a0nxJy3N7f1qnncU69sx1LM37xhTvh6RNJJ2bv9MySXfklciRtL6kj0m6J9e5UNLJzX5XWzu5B2XGqh7Un0l5ceaRhvreGBELCuf0UqMHJelTwH+QVte+GniWlOTtDaSVod8eES9WlUNEjCns+xbwT6QMurNJK27/NfBW4LvxyjQkU4HvAhsAPyTl69kWOIyU92hiRPyuie88gZRQ80fAJFLyutvzdz+ctOr+bhExv3DNSGAOsAtppevZwFbAP5Iyuf5rRHy1Rh2zSDnAfpXr2Ar4WEQszvdjA+ABYAvgZ8CG+X5sTErlfiIpRc1P8nc8kpS+5l0R8Z1G39XWUkO9/Ls//nTChzS091D+8xH55+9WndOb969f2LdP3vcAhUR6pGHCH+ZjH6tRTm/h5xHASlKSwjXSPwBbFv68OWko8nFgl6rz3kAKjr9r8jtPYHUaj6Orjn0g77+gav9X8/6vkv8HN+/fiZT2YgWvTHhXrOMDddpRua8/BF5V2L9f3r+UFAxHFo7tQAqgtw/13x1/Bu7jIT6zKhFxNak38Q+S3trg9GPz9r8j4pFCGS8Bp5ICz3GNqgRE+uW+skZ7isks3wuMBM6MQlK+fN5dwMXA7pJ2aVBn0ZwoJEHMLiHlWtqzskPSBsB7SEHwPyJi1fBLRNxDSna3YW5jtTui0LOq40NRSL0eEb8E7icF5Y9GxJOFY/eRenJvkrRew29oayVPlzWr7VRSFuSzJe1d/GVcZY+8nV19ICIWSHoI2F7SyOIv2Krznpb0Q1KyuDskXUPKMvzriHi+6vR98nbXnB212ti8fT3wpxrHa5lbo00vSnqUFBwqxpGG3OZExNIa5cwGPg7sXuPYbxq04cmIWFhj/
2LSsOttNY79GViPlJ35zw3Kt7WQA5RZDRFxi6SrScN9/0hKgV3LiLx9uM7xh0nPdEYANQNU9v+AjwLvBmbkfctzG06LiEfzvi3z9l8afIVNGhwvqteul0gBoKKZ7wqph1ftkRr7imqlt6+0gYiodbzyHHCDBmXbWspDfGb1nQ68CHxa0oZ1zqn84tymzvHXVJ1XU0Qsi4jpETGWFNDeQ5pQ8B7SxIvq+naNCPXx+XqD79aK/nxXz8ay0hygzOrIQ04XkIaYPljntNvzdkL1AUl/Q5pdd3+94b069T4YEd8E/g64B3irpErP6da83a/Z8tpoPmkq/m6SNq9xfGLeNpxBaNYMByizvp1FGgI7g9rDZpfk7ccljarszA/uv0D6N/a1viqQNErSXjUODQc2JQ1lvZD3XZrbc6akPasvkDQsT+1uu4h4Afgm6T6cVVXvjsA0Uo/zGwNRv3UfP4My60NELM3vOX2uzvGbJX0O+Hfgj/mZ0XOk96DeSBqm+3yDal4L3CppHqn38SDpZeF3kobTZkbEM7m+JyQdAXwvX3M9cBdp9t9o0iSKLYGa6wm2wemk3tvJkt5Cesep8h7UpsDJEXH/ANVtXcYByqyxmaQXRcfUOhgRH5V0O3AyaYr1BsBC0oy2s3PPoy+9pJUrJpCGybYivfsznxQQrqyq73pJbwZOIw0D7kfqYS0mzaS7puT3a1oO2PuQXkw+DPgI6aXi3wCfj4hZA1W3dR+vJGFmZh3Jz6DMzKwjOUCZmVlHcoAyM7OO5ABlZmYdyQHKzMw6kgOUmZl1JAcoMzPrSA5QZmbWkRygzMysI/1/OLVlNf2S27EAAAAASUVORK5CYII=\n", 171 | "text/plain": [ 172 | "
" 173 | ] 174 | }, 175 | "metadata": { 176 | "needs_background": "light" 177 | }, 178 | "output_type": "display_data" 179 | }, 180 | { 181 | "data": { 182 | "image/png": "\n", 183 | "text/plain": [ 184 | "
" 185 | ] 186 | }, 187 | "metadata": { 188 | "needs_background": "light" 189 | }, 190 | "output_type": "display_data" 191 | }, 192 | { 193 | "data": { 194 | "image/png": "\n", 195 | "text/plain": [ 196 | "
" 197 | ] 198 | }, 199 | "metadata": { 200 | "needs_background": "light" 201 | }, 202 | "output_type": "display_data" 203 | } 204 | ], 205 | "source": [ 206 | "# lstm plot\n", 207 | "d = read_pickle(\"./acclip-sst-lstm_noise\")\n", 208 | "for i in range(7,13):\n", 209 | " noises=d[i]\n", 210 | " num_bins = len(noises)\n", 211 | " #plt.yticks(np.linspace(0,0.8,3))\n", 212 | " #plt.xticks(np.linspace(0,20,5))\n", 213 | " #plt.xlim(0,20)\n", 214 | " #plt.ylim(0,0.8)\n", 215 | " n, bins, patches = plt.hist(noises, 100, density=1, facecolor='blue')\n", 216 | " plt.ylabel('Density (%)')\n", 217 | " plt.xlabel('Noise norm')\n", 218 | " plt.xticks(fontsize=20)\n", 219 | " plt.yticks(fontsize=20)\n", 220 | " plt.rc('axes', labelsize=20)\n", 221 | " plt.tight_layout()\n", 222 | " \n", 223 | " plt.savefig('sst'+str(i)+'.png',dpi=300)\n", 224 | " plt.show()" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 347, 230 | "metadata": { 231 | "scrolled": true 232 | }, 233 | "outputs": [ 234 | { 235 | "data": { 236 | "image/png": "\n", 237 | "text/plain": [ 238 | "
" 239 | ] 240 | }, 241 | "metadata": { 242 | "needs_background": "light" 243 | }, 244 | "output_type": "display_data" 245 | } 246 | ], 247 | "source": [ 248 | "# lstm plot\n", 249 | "d = read_pickle(\"./adam-mrpc-bert_noise\")\n", 250 | "for i in [1]:\n", 251 | " noises=d[i]\n", 252 | " num_bins = len(noises)\n", 253 | " plt.yticks(np.linspace(0,0.5,6))\n", 254 | " plt.xticks(np.linspace(0,40,5))\n", 255 | " plt.xlim(0,40)\n", 256 | " #lt.ylim(0,0.8)\n", 257 | " n, bins, patches = plt.hist(noises, 100, density=1, facecolor='blue')\n", 258 | " plt.ylabel('Density (%)')\n", 259 | " plt.xlabel('Noise norm')\n", 260 | " plt.xticks(fontsize=20)\n", 261 | " plt.yticks(fontsize=20)\n", 262 | " plt.rc('axes', labelsize=20)\n", 263 | " plt.tight_layout()\n", 264 | " \n", 265 | " #plt.savefig('bert'+str(i)+'.png',dpi=300)\n", 266 | " plt.show()" 267 | ] 268 | } 269 | ], 270 | "metadata": { 271 | "kernelspec": { 272 | "display_name": "Python 3", 273 | "language": "python", 274 | "name": "python3" 275 | }, 276 | "language_info": { 277 | "codemirror_mode": { 278 | "name": "ipython", 279 | "version": 3 280 | }, 281 | "file_extension": ".py", 282 | "mimetype": "text/x-python", 283 | "name": "python", 284 | "nbconvert_exporter": "python", 285 | "pygments_lexer": "ipython3", 286 | "version": "3.7.4" 287 | }, 288 | "pycharm": { 289 | "stem_cell": { 290 | "cell_type": "raw", 291 | "metadata": { 292 | "collapsed": false 293 | }, 294 | "source": [] 295 | } 296 | } 297 | }, 298 | "nbformat": 4, 299 | "nbformat_minor": 2 300 | } 301 | -------------------------------------------------------------------------------- /optimizers/ACClip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | import math 4 | 5 | class ACClip(Optimizer): 6 | 7 | 8 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), eps=1e-5, 9 | weight_decay=1e-5, alpha=2, mod=1): 10 | if not 0.0 <= lr: 11 | raise ValueError("Invalid 
learning rate: {}".format(lr)) 12 | if not 0.0 <= eps: 13 | raise ValueError("Invalid epsilon value: {}".format(eps)) 14 | if not 0.0 <= betas[0] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 16 | if not 0.0 <= betas[1] < 1.0: 17 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 18 | if not 1.0 <= alpha <= 2.0: 19 | raise ValueError("Invalid alpha parameter: {}".format(alpha)) 20 | 21 | defaults = dict(lr=lr, betas=betas, eps=eps, 22 | weight_decay=weight_decay, alpha=alpha, mod=mod) 23 | super(ACClip, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(ACClip, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | for p in group['params']: 36 | if p.grad is None: 37 | continue 38 | grad = p.grad.data 39 | if grad.is_sparse: 40 | raise RuntimeError('ACClip does not support sparse gradients') 41 | 42 | state = self.state[p] 43 | 44 | # State initialization 45 | if len(state) == 0: 46 | # the momentum term, i.e., m_0 47 | state['momentum'] = torch.zeros_like(p.data) 48 | # the clipping value, i.e., \tau_0^{\alpha} 49 | state['clip'] = torch.zeros_like(p.data) 50 | # the second moment, i.e., v_t 51 | state['second_moment'] = torch.zeros_like(p.data) 52 | # the total number of steps taken 53 | state['step'] = 0 54 | 55 | state['step'] += 1 56 | momentum, clip, second_moment = state['momentum'], state['clip'], state['second_moment'] 57 | beta1, beta2 = group['betas'] 58 | # bias_decay1 = 1 - beta1 ** state['step'] 59 | # bias_decay2 = 1 - beta2 ** state['step'] 60 | 61 | alpha = group['alpha'] 62 | 63 | if group['weight_decay'] != 0: 64 | grad.add_(group['weight_decay'], p.data) 65 | 66 | # update momentum and clip 67 | momentum.mul_(beta1).add_(1 - beta1, grad) 68 | clip.mul_(beta2).add_(1 - beta2, grad.abs().pow(alpha)) 69 | 
second_moment.mul_(beta2).addcmul_(1-beta2, grad, grad) 70 | 71 | # truncate large gradient 72 | denom = clip.pow(1/alpha).div(momentum.abs().add(group['eps'])).clamp(min=0.0, max=1.0) 73 | 74 | # calculate eta_t 75 | if group['mod'] == 1: 76 | denom.div_((second_moment.mul(beta2).sqrt()).add(group['eps'])) 77 | step_size = group['lr'] 78 | p.data.addcmul_(-step_size, denom, momentum) 79 | 80 | return loss 81 | 82 | -------------------------------------------------------------------------------- /run_mrpc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa).""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import glob 22 | import logging 23 | import os 24 | import random 25 | import pickle 26 | 27 | import numpy as np 28 | import torch 29 | from torch.optim import Adam, SGD 30 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 31 | TensorDataset) 32 | from torch.utils.data.distributed import DistributedSampler 33 | 34 | try: 35 | from torch.utils.tensorboard import SummaryWriter 36 | except ImportError: 37 | from tensorboardX import SummaryWriter 38 | 39 | from tqdm import tqdm, trange 40 | 41 | from transformers import (WEIGHTS_NAME, BertConfig, 42 | BertForSequenceClassification, BertTokenizer, 43 | RobertaConfig, 44 | RobertaForSequenceClassification, 45 | RobertaTokenizer, 46 | XLMConfig, XLMForSequenceClassification, 47 | XLMTokenizer, XLNetConfig, 48 | XLNetForSequenceClassification, 49 | XLNetTokenizer, 50 | DistilBertConfig, 51 | DistilBertForSequenceClassification, 52 | DistilBertTokenizer) 53 | 54 | from transformers import AdamW, WarmupLinearSchedule 55 | 56 | from transformers import glue_compute_metrics as compute_metrics 57 | from transformers import glue_output_modes as output_modes 58 | from transformers import glue_processors as processors 59 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 60 | from optimizers.ACClip import ACClip 61 | 62 | logger = logging.getLogger(__name__) 63 | 64 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, 65 | RobertaConfig, DistilBertConfig)), ()) 66 | 67 | MODEL_CLASSES = { 68 | 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 69 | 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 70 | 'xlm': (XLMConfig, XLMForSequenceClassification, 
XLMTokenizer), 71 | 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 72 | 'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer) 73 | } 74 | 75 | 76 | def set_seed(args): 77 | random.seed(args.seed) 78 | np.random.seed(args.seed) 79 | torch.manual_seed(args.seed) 80 | if args.n_gpu > 0: 81 | torch.cuda.manual_seed_all(args.seed) 82 | 83 | 84 | def train(args, train_dataset, model, tokenizer): 85 | """ Train the model """ 86 | train_loss, test_loss, test_f1, test_acc = [], [], [], [] 87 | model_name = "{}-{}-{}".format(args.optimizer.lower(), args.task_name, args.model_type) 88 | if args.local_rank in [-1, 0]: 89 | tb_writer = SummaryWriter() 90 | 91 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 92 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 93 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 94 | 95 | if args.max_steps > 0: 96 | t_total = args.max_steps 97 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 98 | else: 99 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 100 | 101 | # Prepare optimizer and schedule (linear warmup and decay) 102 | no_decay = ['bias', 'LayerNorm.weight'] 103 | optimizer_grouped_parameters = [ 104 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 105 | 'weight_decay': args.weight_decay}, 106 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 107 | ] 108 | if args.optimizer.lower() == "adamw": 109 | print ("We use AdamW optimizer!") 110 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 111 | elif args.optimizer.lower() == "adam": 112 | print("We use Adam optimizer!") 113 | 
optimizer = Adam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 114 | elif args.optimizer.lower() == "sgd": 115 | print("We use SGD optimizer!") 116 | optimizer = SGD(optimizer_grouped_parameters, lr=args.learning_rate, momentum=0.9) 117 | elif "acclip" in args.optimizer.lower(): 118 | print ("We use ACClip optimizer!") 119 | optimizer = ACClip(optimizer_grouped_parameters, lr=args.learning_rate) 120 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 121 | 122 | # multi-gpu training (should be after apex fp16 initialization) 123 | if args.n_gpu > 1: 124 | model = torch.nn.DataParallel(model) 125 | 126 | # Distributed training (should be after apex fp16 initialization) 127 | if args.local_rank != -1: 128 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 129 | output_device=args.local_rank, 130 | find_unused_parameters=True) 131 | # Train! 132 | logger.info("***** Running training *****") 133 | logger.info(" Num examples = %d", len(train_dataset)) 134 | logger.info(" Num Epochs = %d", args.num_train_epochs) 135 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 136 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 137 | args.train_batch_size * args.gradient_accumulation_steps * ( 138 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 139 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 140 | logger.info(" Total optimization steps = %d", t_total) 141 | 142 | global_step = 0 143 | tr_loss, logging_loss = 0.0, 0.0 144 | model.zero_grad() 145 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 146 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 147 | for _ in train_iterator: 148 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 149 | for step, batch in enumerate(epoch_iterator): 150 | model.train() 151 | batch = tuple(t.to(args.device) for t in batch) 152 | inputs = {'input_ids': batch[0], 153 | 'attention_mask': batch[1], 154 | 'labels': batch[3]} 155 | if args.model_type != 'distilbert': 156 | inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 157 | 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids 158 | outputs = model(**inputs) 159 | loss = outputs[0] # model outputs are always tuple in transformers (see doc) 160 | 161 | if args.n_gpu > 1: 162 | loss = loss.mean() # mean() to average on multi-gpu parallel training 163 | if args.gradient_accumulation_steps > 1: 164 | loss = loss / args.gradient_accumulation_steps 165 | 166 | loss.backward() 167 | 168 | tr_loss += loss.item() 169 | if (step + 1) % args.gradient_accumulation_steps == 0: 170 | if "acclip" not in args.optimizer.lower(): # make sure we don't clip for acclip (same substring test used to select the optimizer) 171 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) # 172 | optimizer.step() 173 | scheduler.step() # Update learning rate schedule 174 | model.zero_grad() 175 | global_step += 1 176 | 177 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % 
args.logging_steps == 0: 178 | # Log metrics 179 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 180 | results, eval_loss = evaluate(args, model, tokenizer) 181 | for key, value in results.items(): 182 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 183 | print ('eval_{}'.format(key), value, global_step) 184 | print ('eval_loss', eval_loss, global_step) 185 | test_f1.append(results['f1']*100) 186 | test_acc.append(results['acc']*100) 187 | test_loss.append(eval_loss) 188 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 189 | iter_train_loss = (tr_loss - logging_loss) / args.logging_steps 190 | tb_writer.add_scalar('loss', iter_train_loss, global_step) 191 | print ("eval_Training_Loss_{}".format(iter_train_loss), global_step) 192 | logging_loss = tr_loss 193 | train_loss.append(iter_train_loss) 194 | 195 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 196 | # Save model checkpoint 197 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 198 | if not os.path.exists(output_dir): 199 | os.makedirs(output_dir) 200 | model_to_save = model.module if hasattr(model, 201 | 'module') else model # Take care of distributed/parallel training 202 | model_to_save.save_pretrained(output_dir) 203 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 204 | logger.info("Saving model checkpoint to %s", output_dir) 205 | 206 | if args.max_steps > 0 and global_step > args.max_steps: 207 | epoch_iterator.close() 208 | break 209 | if args.max_steps > 0 and global_step > args.max_steps: 210 | train_iterator.close() 211 | break 212 | 213 | if args.local_rank in [-1, 0]: 214 | tb_writer.close() 215 | 216 | if not os.path.exists("./curves"): 217 | os.mkdir("./curves") 218 | with open(os.path.join('./curves', model_name), "wb") as f: 219 | pickle.dump({'train_loss': train_loss, 'test_loss': 
test_loss, 220 | 'test_f1': test_f1, 'test_acc': test_acc}, f) 221 | return global_step, tr_loss / global_step 222 | 223 | 224 | def plot_noise(args, train_dataset, model, tokenizer): 225 | """ Train the model + Plot the noise""" 226 | # train_loss, test_loss, test_f1, test_acc = [], [], [], [] 227 | model_name = "{}-{}-{}".format(args.optimizer.lower(), args.task_name, args.model_type) 228 | # if args.local_rank in [-1, 0]: 229 | # tb_writer = SummaryWriter() 230 | 231 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 232 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 233 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 234 | 235 | if args.max_steps > 0: 236 | t_total = args.max_steps 237 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 238 | else: 239 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 240 | 241 | # Prepare optimizer and schedule (linear warmup and decay) 242 | no_decay = ['bias', 'LayerNorm.weight'] 243 | optimizer_grouped_parameters = [ 244 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 245 | 'weight_decay': args.weight_decay}, 246 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 247 | ] 248 | if args.optimizer.lower() == "adamw": 249 | print("We use AdamW optimizer!") 250 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 251 | elif args.optimizer.lower() == "adam": 252 | print("We use Adam optimizer!") 253 | optimizer = Adam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 254 | elif args.optimizer.lower() == "sgd": 255 | print("We use SGD optimizer!") 256 | optimizer = SGD(optimizer_grouped_parameters, 
lr=args.learning_rate, momentum=0.9) 257 | elif args.optimizer.lower() == "acclip": 258 | print("We use ACClip optimizer!") 259 | optimizer = ACClip(optimizer_grouped_parameters, lr=args.learning_rate) 260 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 261 | 262 | # multi-gpu training (should be after apex fp16 initialization) 263 | if args.n_gpu > 1: 264 | model = torch.nn.DataParallel(model) 265 | 266 | # Distributed training (should be after apex fp16 initialization) 267 | if args.local_rank != -1: 268 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 269 | output_device=args.local_rank, 270 | find_unused_parameters=True) 271 | # Train! 272 | logger.info("***** Running training *****") 273 | logger.info(" Num examples = %d", len(train_dataset)) 274 | logger.info(" Num Epochs = %d", args.num_train_epochs) 275 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 276 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 277 | args.train_batch_size * args.gradient_accumulation_steps * ( 278 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 279 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 280 | logger.info(" Total optimization steps = %d", t_total) 281 | 282 | global_step = 0 283 | tr_loss, logging_loss = 0.0, 0.0 284 | model.zero_grad() 285 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 286 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 287 | 288 | epoch_num = 0 289 | diction = {} 290 | 291 | for _ in train_iterator: 292 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 293 | # i = 0 --> calculate average gradient 294 | # i = 1 --> calculate noises 295 | # i = 2 --> optimizer.step() 296 | avg_grad = None 297 | noise_sample = list() 298 | for i in range(3): 299 | num = 0 300 | for step, batch in enumerate(epoch_iterator): 301 | model.train() 302 | batch = tuple(t.to(args.device) for t in batch) 303 | inputs = {'input_ids': batch[0], 304 | 'attention_mask': batch[1], 305 | 'labels': batch[3]} 306 | if args.model_type != 'distilbert': 307 | inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 308 | 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids 309 | outputs = model(**inputs) 310 | loss = outputs[0] # model outputs are always tuple in transformers (see doc) 311 | 312 | if args.n_gpu > 1: 313 | loss = loss.mean() # mean() to average on multi-gpu parallel training 314 | if args.gradient_accumulation_steps > 1: 315 | loss = loss / args.gradient_accumulation_steps 316 | 317 | loss.backward() 318 | 319 | tr_loss += loss.item() 320 | if (step + 1) % args.gradient_accumulation_steps == 0: 321 | # if args.optimizer.lower() != "acclip": # make sure we don't clip for acclip 322 | # 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) # 323 | if (i <= 1): 324 | emb_grad = None 325 | for j in model.parameters(): 326 | if (emb_grad is None): 327 | emb_grad = j.grad.data.view(-1) 328 | else: 329 | emb_grad = torch.cat((emb_grad, j.grad.data.view(-1))) 330 | if (i == 0): 331 | if (avg_grad is None): 332 | avg_grad = emb_grad 333 | else: 334 | avg_grad += emb_grad 335 | num += 1 336 | if (i == 1): 337 | diff = emb_grad - avg_grad 338 | l2_norm = torch.norm(diff, p=2) 339 | noise_sample.append(l2_norm.item()) 340 | if (i == 2): 341 | optimizer.step() 342 | scheduler.step() # Update learning rate schedule 343 | global_step += 1 344 | if args.local_rank in [-1, 345 | 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 346 | # Log metrics 347 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 348 | results, eval_loss = evaluate(args, model, tokenizer) 349 | for key, value in results.items(): 350 | print('eval_{}'.format(key), value, global_step) 351 | print('eval_loss', eval_loss, global_step) 352 | model.zero_grad() 353 | 354 | if args.max_steps > 0 and global_step > args.max_steps: 355 | epoch_iterator.close() 356 | break 357 | if (i == 0): 358 | avg_grad /= num 359 | if (i == 1): 360 | diction[epoch_num] = noise_sample 361 | epoch_num += 1 362 | 363 | if args.max_steps > 0 and global_step > args.max_steps: 364 | train_iterator.close() 365 | break 366 | if not os.path.exists("./noises"): 367 | os.mkdir("./noises") 368 | with open(os.path.join('./noises', model_name + '_noise'), "wb") as f: 369 | pickle.dump(diction, f) 370 | 371 | return global_step, tr_loss / global_step 372 | 373 | 374 | def evaluate(args, model, tokenizer, prefix=""): 375 | # Loop to handle MNLI double evaluation (matched, mis-matched) 376 | eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) 377 | eval_outputs_dirs = 
(args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,) 378 | 379 | results = {} 380 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): 381 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) 382 | 383 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 384 | os.makedirs(eval_output_dir) 385 | 386 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 387 | # Note that DistributedSampler samples randomly 388 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 389 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 390 | 391 | # Eval! 392 | logger.info("***** Running evaluation {} *****".format(prefix)) 393 | logger.info(" Num examples = %d", len(eval_dataset)) 394 | logger.info(" Batch size = %d", args.eval_batch_size) 395 | eval_loss = 0.0 396 | nb_eval_steps = 0 397 | preds = None 398 | out_label_ids = None 399 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 400 | model.eval() 401 | batch = tuple(t.to(args.device) for t in batch) 402 | 403 | with torch.no_grad(): 404 | inputs = {'input_ids': batch[0], 405 | 'attention_mask': batch[1], 406 | 'labels': batch[3]} 407 | if args.model_type != 'distilbert': 408 | inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 409 | 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids 410 | outputs = model(**inputs) 411 | tmp_eval_loss, logits = outputs[:2] 412 | 413 | eval_loss += tmp_eval_loss.mean().item() 414 | nb_eval_steps += 1 415 | if preds is None: 416 | preds = logits.detach().cpu().numpy() 417 | out_label_ids = inputs['labels'].detach().cpu().numpy() 418 | else: 419 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 420 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 421 | 422 | 
eval_loss = eval_loss / nb_eval_steps 423 | if args.output_mode == "classification": 424 | preds = np.argmax(preds, axis=1) 425 | elif args.output_mode == "regression": 426 | preds = np.squeeze(preds) 427 | result = compute_metrics(eval_task, preds, out_label_ids) 428 | results.update(result) 429 | 430 | output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") 431 | with open(output_eval_file, "w") as writer: 432 | logger.info("***** Eval results {} *****".format(prefix)) 433 | for key in sorted(result.keys()): 434 | logger.info(" %s = %s", key, str(result[key])) 435 | writer.write("%s = %s\n" % (key, str(result[key]))) 436 | 437 | return results, eval_loss 438 | 439 | 440 | def load_and_cache_examples(args, task, tokenizer, evaluate=False): 441 | if args.local_rank not in [-1, 0] and not evaluate: 442 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 443 | 444 | processor = processors[task]() 445 | output_mode = output_modes[task] 446 | # Load data features from cache or dataset file 447 | cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format( 448 | 'dev' if evaluate else 'train', 449 | list(filter(None, args.model_name_or_path.split('/'))).pop(), 450 | str(args.max_seq_length), 451 | str(task))) 452 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 453 | logger.info("Loading features from cached file %s", cached_features_file) 454 | features = torch.load(cached_features_file) 455 | else: 456 | logger.info("Creating features from dataset file at %s", args.data_dir) 457 | label_list = processor.get_labels() 458 | if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']: 459 | # HACK(label indices are swapped in RoBERTa pretrained model) 460 | label_list[1], label_list[2] = label_list[2], label_list[1] 461 | examples = processor.get_dev_examples(args.data_dir) if evaluate else 
processor.get_train_examples( 462 | args.data_dir) 463 | features = convert_examples_to_features(examples, 464 | tokenizer, 465 | label_list=label_list, 466 | max_length=args.max_seq_length, 467 | output_mode=output_mode, 468 | pad_on_left=bool(args.model_type in ['xlnet']), 469 | # pad on the left for xlnet 470 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 471 | pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, 472 | ) 473 | if args.local_rank in [-1, 0]: 474 | logger.info("Saving features into cached file %s", cached_features_file) 475 | torch.save(features, cached_features_file) 476 | 477 | if args.local_rank == 0 and not evaluate: 478 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 479 | 480 | # Convert to Tensors and build dataset 481 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 482 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 483 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 484 | if output_mode == "classification": 485 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 486 | elif output_mode == "regression": 487 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float) 488 | 489 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 490 | return dataset 491 | 492 | 493 | def main(): 494 | parser = argparse.ArgumentParser() 495 | 496 | ## Required parameters 497 | parser.add_argument("--data_dir", default=None, type=str, required=True, 498 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 499 | parser.add_argument("--model_type", default=None, type=str, required=True, 500 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 501 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 502 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join( 503 | ALL_MODELS)) 504 | parser.add_argument("--task_name", default=None, type=str, required=True, 505 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) 506 | parser.add_argument("--output_dir", default=None, type=str, required=True, 507 | help="The output directory where the model predictions and checkpoints will be written.") 508 | parser.add_argument("--optimizer", default=None, type=str, required=True, 509 | help="Name of the optimizer to use (run_mrpc.sh passes acclip).") 510 | 511 | ## Other parameters 512 | parser.add_argument("--config_name", default="", type=str, 513 | help="Pretrained config name or path if not the same as model_name") 514 | parser.add_argument("--tokenizer_name", default="", type=str, 515 | help="Pretrained tokenizer name or path if not the same as model_name") 516 | parser.add_argument("--cache_dir", default="", type=str, 517 | help="Where do you want to store the pre-trained models downloaded from s3") 518 | parser.add_argument("--max_seq_length", default=128, type=int, 519 | help="The maximum total input sequence length after tokenization.
Sequences longer " 520 | "than this will be truncated, sequences shorter will be padded.") 521 | parser.add_argument("--do_train", action='store_true', 522 | help="Whether to run training.") 523 | parser.add_argument("--do_eval", action='store_true', 524 | help="Whether to run eval on the dev set.") 525 | parser.add_argument("--evaluate_during_training", action='store_true', 526 | help="Run evaluation during training at each logging step.") 527 | parser.add_argument("--do_lower_case", action='store_true', 528 | help="Set this flag if you are using an uncased model.") 529 | parser.add_argument("--do_plot", action='store_true', help="Whether to run plot.") 530 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 531 | help="Batch size per GPU/CPU for training.") 532 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 533 | help="Batch size per GPU/CPU for evaluation.") 534 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 535 | help="Number of update steps to accumulate before performing a backward/update pass.") 536 | parser.add_argument("--learning_rate", default=5e-5, type=float, 537 | help="The initial learning rate for Adam.") 538 | parser.add_argument("--weight_decay", default=0.0, type=float, 539 | help="Weight decay if we apply some.") 540 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 541 | help="Epsilon for Adam optimizer.") 542 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 543 | help="Max gradient norm.") 544 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 545 | help="Total number of training epochs to perform.") 546 | parser.add_argument("--max_steps", default=-1, type=int, 547 | help="If > 0: set total number of training steps to perform.
Override num_train_epochs.") 548 | parser.add_argument("--warmup_steps", default=0, type=int, 549 | help="Linear warmup over warmup_steps.") 550 | parser.add_argument('--logging_steps', type=int, default=10, 551 | help="Log every X updates steps.") 552 | parser.add_argument('--save_steps', type=int, default=100, 553 | help="Save checkpoint every X updates steps.") 554 | parser.add_argument("--eval_all_checkpoints", action='store_true', 555 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") 556 | parser.add_argument("--no_cuda", action='store_true', 557 | help="Avoid using CUDA when available") 558 | parser.add_argument('--overwrite_output_dir', action='store_true', 559 | help="Overwrite the content of the output directory") 560 | parser.add_argument('--overwrite_cache', action='store_true', 561 | help="Overwrite the cached training and evaluation sets") 562 | parser.add_argument('--seed', type=int, default=42, 563 | help="random seed for initialization") 564 | parser.add_argument("--local_rank", type=int, default=-1, 565 | help="For distributed training: local_rank") 566 | args = parser.parse_args() 567 | 568 | if os.path.exists(args.output_dir) and os.listdir( 569 | args.output_dir) and args.do_train and not args.overwrite_output_dir: 570 | raise ValueError( 571 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 572 | args.output_dir)) 573 | 574 | # Setup CUDA, GPU & distributed training 575 | if args.local_rank == -1 or args.no_cuda: 576 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 577 | args.n_gpu = torch.cuda.device_count() 578 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 579 | torch.cuda.set_device(args.local_rank) 580 | device = torch.device("cuda", args.local_rank) 581 | torch.distributed.init_process_group(backend='nccl') 582 | args.n_gpu = 1 583 | args.device = device 584 | 585 | # Setup logging 586 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 587 | datefmt='%m/%d/%Y %H:%M:%S', 588 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 589 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s", 590 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1)) 591 | 592 | # Set seed 593 | set_seed(args) 594 | 595 | # Prepare GLUE task 596 | args.task_name = args.task_name.lower() 597 | if args.task_name not in processors: 598 | raise ValueError("Task not found: %s" % (args.task_name)) 599 | processor = processors[args.task_name]() 600 | args.output_mode = output_modes[args.task_name] 601 | label_list = processor.get_labels() 602 | num_labels = len(label_list) 603 | 604 | # Load pretrained model and tokenizer 605 | if args.local_rank not in [-1, 0]: 606 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 607 | 608 | args.model_type = args.model_type.lower() 609 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 610 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, 611 | num_labels=num_labels, 612 | finetuning_task=args.task_name, 613 | cache_dir=args.cache_dir if args.cache_dir else None)
614 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 615 | do_lower_case=args.do_lower_case, 616 | cache_dir=args.cache_dir if args.cache_dir else None) 617 | model = model_class.from_pretrained(args.model_name_or_path, 618 | from_tf=bool('.ckpt' in args.model_name_or_path), 619 | config=config, 620 | cache_dir=args.cache_dir if args.cache_dir else None) 621 | 622 | if args.local_rank == 0: 623 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 624 | 625 | model.to(args.device) 626 | 627 | logger.info("Training/evaluation parameters %s", args) 628 | 629 | # Plotting 630 | if args.do_plot: 631 | if args.do_train: 632 | print ("Conflict Arguments!") 633 | print ("Do Plot and Do Train") 634 | raise ValueError("--do_plot and --do_train cannot be used together") 635 | else: 636 | print ("======= Plotting Noise Start =======\n") 637 | print ("====================================\n") 638 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) 639 | global_step, tr_loss = plot_noise(args, train_dataset, model, tokenizer) 640 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 641 | 642 | # Training 643 | if args.do_train: 644 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) 645 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 646 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 647 | 648 | # Saving best-practices: if you use default names for the model, you can reload it using from_pretrained() 649 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 650 | # Create output directory if needed 651 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 652 | os.makedirs(args.output_dir) 653 | 654 | logger.info("Saving model checkpoint to %s", args.output_dir) 655 | # Save a trained model, configuration and
tokenizer using `save_pretrained()`. 656 | # They can then be reloaded using `from_pretrained()` 657 | model_to_save = model.module if hasattr(model, 658 | 'module') else model # Take care of distributed/parallel training 659 | model_to_save.save_pretrained(args.output_dir) 660 | tokenizer.save_pretrained(args.output_dir) 661 | 662 | # Good practice: save your training arguments together with the trained model 663 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 664 | 665 | # Load a trained model and vocabulary that you have fine-tuned 666 | model = model_class.from_pretrained(args.output_dir) 667 | tokenizer = tokenizer_class.from_pretrained(args.output_dir) 668 | model.to(args.device) 669 | 670 | # Evaluation 671 | results = {} 672 | if args.do_eval and args.local_rank in [-1, 0]: 673 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 674 | checkpoints = [args.output_dir] 675 | if args.eval_all_checkpoints: 676 | checkpoints = list( 677 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 678 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 679 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 680 | for checkpoint in checkpoints: 681 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 682 | prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else "" 683 | 684 | model = model_class.from_pretrained(checkpoint) 685 | model.to(args.device) 686 | result, eval_loss = evaluate(args, model, tokenizer, prefix=prefix) 687 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) 688 | results.update(result) 689 | return results 690 | 691 | 692 | if __name__ == "__main__": 693 | main() 694 | -------------------------------------------------------------------------------- /run_mrpc.sh: 
-------------------------------------------------------------------------------- 1 | export GLUE_DIR=./glue_data 2 | 3 | python run_mrpc.py \ 4 | --model_type bert \ 5 | --model_name_or_path bert-base-cased \ 6 | --task_name MRPC \ 7 | --optimizer acclip \ 8 | --do_plot \ 9 | --do_eval \ 10 | --do_lower_case \ 11 | --data_dir $GLUE_DIR/MRPC/ \ 12 | --max_seq_length 128 \ 13 | --per_gpu_train_batch_size 32 \ 14 | --learning_rate 2e-5 \ 15 | --num_train_epochs 5.0 \ 16 | --evaluate_during_training \ 17 | --overwrite_output_dir \ 18 | --output_dir ./outputs 19 | -------------------------------------------------------------------------------- /run_text_classifier.py: -------------------------------------------------------------------------------- 1 | from models.lstm_attn import LSTM_Attn 2 | import argparse 3 | import torch 4 | import torch.nn.functional as F 5 | from utils import load_dataset 6 | import sys, os, pickle 7 | from optimizers.ACClip import ACClip 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--mode', type=str, default='train', help='mode=train or mode=plot') 11 | parser.add_argument('--model', type=str, default='lstm', help='name of the model') 12 | parser.add_argument('--checkpoint', type=str, default='', help='load existing checkpoint') 13 | parser.add_argument('--dataset', type=str, default='imdb', help='the dataset used for text classification') 14 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs') 15 | parser.add_argument('--optimizer', type=str, default='sgd', help='Type of optimizer') 16 | parser.add_argument('--lr', type=float, default=0.001) 17 | parser.add_argument('--momentum', type=float, default=0.9) 18 | parser.add_argument('--batch_size', type=int, default=32) 19 | parser.add_argument('--hidden_size', type=int, default=256) 20 | parser.add_argument('--emb_dim', type=int, default=300) 21 | args = parser.parse_args() 22 | 23 | 24 | def print_noise(model, optimizer, train_batches,
eval_batches, total_epoch, model_name): 25 | step = 0 26 | model.train() 27 | train_loss, test_loss, test_acc = [], [], [] 28 | diction = {} 29 | for epoch in range(total_epoch): 30 | # i = 0 -- calc average gradient 31 | # i = 1 -- calc noise 32 | # i = 2 -- update 33 | noise_sample = list() 34 | for i in range(3): 35 | local_step = 0 36 | total_loss = 0.0 37 | for batch in train_batches: 38 | optimizer.zero_grad() 39 | inputs, target = batch.text, batch.label 40 | if torch.cuda.is_available(): 41 | inputs, target = inputs.cuda(), target.cuda() 42 | pred = model(inputs) 43 | loss = F.cross_entropy(pred, target) 44 | total_loss += loss.item() 45 | loss.backward() 46 | params = list(model.parameters()) 47 | if (i <= 1): 48 | emb_grad = params[0].grad.data.view(-1) 49 | for j in range(2, 8): # params[1] is none and we need to skip 50 | emb_grad = torch.cat((emb_grad, params[j].grad.data.view(-1))) 51 | # the 0 th round -- calc average gradient 52 | if (i == 0): 53 | if (local_step == 0): 54 | avg_grad = emb_grad 55 | else: 56 | avg_grad += emb_grad 57 | # the 1 th round -- output noise 58 | if (i == 1): 59 | diff = emb_grad - avg_grad 60 | l2_norm = torch.norm(diff, p=2) 61 | # print(noise_sample,l2_norm) 62 | noise_sample.append(l2_norm.item()) 63 | # print(emb_grad.size()) 64 | if (i == 2): 65 | optimizer.step() 66 | step += 1 67 | local_step += 1 68 | if (i == 0): 69 | avg_grad /= local_step 70 | diction[epoch] = noise_sample 71 | eval_loss, eval_acc = eval(model, eval_batches) 72 | print ("validation loss {:.4f} and acc {:.4f}".format(eval_loss, eval_acc)) 73 | 74 | if not os.path.exists("./noises"): 75 | os.mkdir("./noises") 76 | with open(os.path.join('./noises', model_name + '_noise'), "wb") as f: 77 | pickle.dump(diction, f) 78 | 79 | def train(model, optimizer, train_batches, eval_batches, total_epoch, model_name): 80 | step = 0 81 | model.train() 82 | train_loss, test_loss, test_acc = [], [], [] 83 | for epoch in range(total_epoch): 84 | total_loss = 0.0 85 
| for batch in train_batches: 86 | optimizer.zero_grad() 87 | inputs, target = batch.text, batch.label 88 | if torch.cuda.is_available(): 89 | inputs, target = inputs.cuda(), target.cuda() 90 | pred = model(inputs) 91 | loss = F.cross_entropy(pred, target) 92 | total_loss += loss.item() 93 | loss.backward() 94 | optimizer.step() 95 | step += 1 96 | if step % 100 == 0: 97 | sys.stdout.write('\rBatch[{}] - loss: {:.4f}\n'.format(step, loss.item())) 98 | train_loss.append(total_loss/len(train_batches)) 99 | eval_loss, eval_acc = eval(model, eval_batches) 100 | test_loss.append(eval_loss) 101 | test_acc.append(eval_acc) 102 | print ("validation loss {:.4f} and acc {:.4f}".format(eval_loss, eval_acc)) 103 | 104 | if not os.path.exists("./curves"): 105 | os.mkdir("./curves") 106 | with open(os.path.join('./curves', model_name), "wb") as f: 107 | pickle.dump({'train_loss': train_loss, 'test_loss': test_loss, 'test_acc': test_acc}, f) 108 | 109 | def eval(model, eval_batches): 110 | model.eval() 111 | total_epoch_loss = 0 112 | total_epoch_acc = 0 113 | model.eval() 114 | with torch.no_grad(): 115 | for idx, batch in enumerate(eval_batches): 116 | inputs = batch.text 117 | target = batch.label 118 | if torch.cuda.is_available(): 119 | inputs, target = inputs.cuda(), target.cuda() 120 | target = torch.autograd.Variable(target).long() 121 | prediction = model(inputs) 122 | loss = F.cross_entropy(prediction, target) 123 | num_corrects = (torch.max(prediction, 1)[1].view(target.size()).data == target.data).sum() 124 | acc = 100.0 * num_corrects/len(batch) 125 | total_epoch_loss += loss.item() 126 | total_epoch_acc += acc.item() 127 | model.train() 128 | return total_epoch_loss/len(eval_batches), total_epoch_acc/len(eval_batches) 129 | 130 | if __name__ == "__main__": 131 | print ("Start loading and processing dataset") 132 | train_batches, dev_batches, test_batches, class_num, vocab_size, word_embeds \ 133 | = load_dataset(dataset=args.dataset, batch_size=args.batch_size, 
word_dim=args.emb_dim) 134 | 135 | if args.model == "lstm": 136 | model = LSTM_Attn(args.batch_size, class_num, args.hidden_size, vocab_size, args.emb_dim, word_embeds) 137 | if torch.cuda.is_available(): 138 | model.cuda() 139 | 140 | if "sgd" in args.optimizer.lower(): 141 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 142 | elif args.optimizer.lower() == "adam": 143 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 144 | elif "acclip" in args.optimizer.lower(): 145 | optimizer = ACClip(model.parameters(), lr=args.lr, weight_decay=0) 146 | 147 | model_name = "{}-{}-{}".format(args.optimizer, args.dataset, args.model) 148 | if args.mode.lower() == 'plot': 149 | print ("Start Plotting Noise Norm!") 150 | print_noise(model, optimizer, train_batches, dev_batches, args.epoch, model_name) 151 | else: 152 | print ("Start training models!") 153 | train(model, optimizer, train_batches, dev_batches, args.epoch, model_name) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from torchtext import data 5 | from torchtext import datasets 6 | from torchtext.vocab import Vectors, GloVe 7 | 8 | 9 | def load_dataset(dataset="imdb", batch_size=32, word_dim=300): 10 | """ 11 | tokenizer : Breaks sentences into a list of words. If sequential=False, no tokenization is applied 12 | Field : A class that stores information about the way of preprocessing 13 | fix_length : An important property of TorchText is that we can let the input be variable length, and TorchText will 14 | dynamically pad each sequence to the longest sequence in that "batch". But here we are using fix_length which 15 | will pad each sequence to have a fixed length of 200.
16 | 17 | build_vocab : It will first make a vocabulary or dictionary mapping all the unique words present in the train_data to an 18 | idx and then after it will use GloVe word embedding to map the index to the corresponding word embedding. 19 | 20 | vocab.vectors : This returns a torch tensor of shape (vocab_size x embedding_dim) containing the pre-trained word embeddings. 21 | BucketIterator : Defines an iterator that batches examples of similar lengths together to minimize the amount of padding needed. 22 | """ 23 | 24 | tokenize = lambda x: x.split() 25 | TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, include_lengths=False, batch_first=True, 26 | fix_length=200) 27 | LABEL = data.LabelField(sequential=False) 28 | if dataset == "imdb": 29 | train_data, test_data = datasets.IMDB.splits(TEXT, LABEL) 30 | elif dataset == "sst": 31 | train_data, dev_data, test_data = datasets.SST.splits(TEXT, LABEL, 32 | fine_grained=False) 33 | print ("after split, we start to build vocab") 34 | TEXT.build_vocab(train_data, vectors=GloVe(name='6B', dim=word_dim), min_freq=10) 35 | LABEL.build_vocab(train_data) 36 | 37 | word_embeddings = TEXT.vocab.vectors 38 | print ("Length of Text Vocabulary: " + str(len(TEXT.vocab))) 39 | print ("Vector size of Text Vocabulary: ", TEXT.vocab.vectors.size()) 40 | print ("Label Length: " + str(len(LABEL.vocab))) 41 | 42 | train_data, valid_data = train_data.split() # Further splitting of train & validation 43 | print ("train: {}; dev: {}; test: {}".format(len(train_data), len(valid_data), len(test_data))) 44 | train_iter, valid_iter, test_iter = data.BucketIterator.splits((train_data, valid_data, test_data), 45 | batch_size=batch_size, 46 | sort_key=lambda x: len(x.text), repeat=False, 47 | shuffle=True) 48 | 49 | '''Alternatively we can also use the default configurations''' 50 | # train_iter, test_iter = datasets.IMDB.iters(batch_size=32) 51 | vocab_size = len(TEXT.vocab) 52 | class_num = len(LABEL.vocab) 53 | 54 | return 
train_iter, valid_iter, test_iter, class_num, vocab_size, word_embeddings --------------------------------------------------------------------------------
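The noise measurement in print_noise (run_text_classifier.py) and plot_noise (run_mrpc.py) runs in passes over the training batches: pass 0 averages the flattened gradients, pass 1 records the L2 norm of each batch gradient minus that average, and pass 2 applies the optimizer updates. The arithmetic of the first two passes can be sketched in plain Python as below. This is only an illustration under stated assumptions: the gradients are passed in as lists of floats rather than recomputed from a model, and the name gradient_noise_norms is hypothetical and does not appear in the repository.

```python
import math


def gradient_noise_norms(batch_grads):
    """Return ||g_b - mean(g)||_2 for every per-batch gradient g_b.

    batch_grads is a list of equally sized lists of floats, a stand-in
    (assumption) for the flattened torch gradients used in the scripts.
    """
    if not batch_grads:
        return []
    num_batches = len(batch_grads)
    dim = len(batch_grads[0])
    # Pass 0: average the flattened gradients over all batches.
    avg_grad = [sum(g[k] for g in batch_grads) / num_batches for k in range(dim)]
    # Pass 1: L2 distance of each batch gradient from that average.
    return [
        math.sqrt(sum((g[k] - avg_grad[k]) ** 2 for k in range(dim)))
        for g in batch_grads
    ]


# Three toy "batch gradients" whose average is [1.0, 1.0].
norms = gradient_noise_norms([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
print(norms)
```

In the scripts the resulting list corresponds to noise_sample, which is stored per epoch in diction and pickled under ./noises.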