├── assets └── ftlm.png ├── analysis.py ├── LICENSE ├── .gitignore ├── datasets.py ├── utils.py ├── parameters_names.json ├── text_utils.py ├── README.md ├── opt.py ├── model_py.py └── train.py /assets/ftlm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/pytorch-openai-transformer-lm/master/assets/ftlm.png -------------------------------------------------------------------------------- /analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from sklearn.metrics import accuracy_score 7 | 8 | from datasets import _rocstories 9 | 10 | def rocstories(data_dir, pred_path, log_path): 11 | preds = pd.read_csv(pred_path, delimiter='\t')['prediction'].values.tolist() 12 | _, _, _, labels = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv')) 13 | test_accuracy = accuracy_score(labels, preds)*100. 14 | logs = [json.loads(line) for line in open(log_path)][1:] 15 | best_validation_index = np.argmax([log['va_acc'] for log in logs]) 16 | valid_accuracy = logs[best_validation_index]['va_acc'] 17 | print('ROCStories Valid Accuracy: %.2f'%(valid_accuracy)) 18 | print('ROCStories Test Accuracy: %.2f'%(test_accuracy)) 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Model, ipynb_checkpoints 2 | model 3 | save 4 | log 5 | submission 6 | 7 | .vscode 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import numpy as np 4 | 5 | from tqdm import tqdm 6 | 7 | from sklearn.utils import shuffle 8 | from sklearn.model_selection import train_test_split 9 | 10 | seed = 3535999445 11 | 12 | def _rocstories(path): 13 | with open(path) as f: 14 | f = csv.reader(f) 15 | st = [] 16 | ct1 = [] 17 | ct2 = [] 18 | y = [] 19 | for i, line in enumerate(tqdm(list(f), ncols=80, leave=False)): 20 | if i > 0: 21 | s = ' '.join(line[1:5]) 22 | c1 = line[5] 23 | c2 = line[6] 24 | st.append(s) 25 | ct1.append(c1) 26 | ct2.append(c2) 27 | y.append(int(line[-1])-1) 28 | return st, ct1, ct2, y 29 | 30 | def rocstories(data_dir, n_train=1497, n_valid=374): 31 | storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv')) 32 | teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv')) 33 | tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed) 34 | trX1, trX2, trX3 = [], [], [] 35 | trY = [] 36 | for s, c1, c2, y in zip(tr_storys, tr_comps1, tr_comps2, tr_ys): 37 | trX1.append(s) 38 | trX2.append(c1) 39 | trX3.append(c2) 40 | trY.append(y) 41 | 42 | vaX1, vaX2, vaX3 = [], [], [] 43 | vaY = [] 44 | for s, c1, c2, y in zip(va_storys, va_comps1, va_comps2, va_ys): 45 | vaX1.append(s) 46 | vaX2.append(c1) 47 | vaX3.append(c2) 48 | vaY.append(y) 49 | trY = np.asarray(trY, dtype=np.int32) 50 | vaY = np.asarray(vaY, dtype=np.int32) 51 | return (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) 52 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | from functools import partial 6 | import numpy as np 7 | # import tensorflow as tf 8 | # from tensorflow.python.framework import function 9 | from tqdm import tqdm 10 | 11 | def encode_dataset(*splits, encoder): 12 | encoded_splits = [] 13 | for split in splits[0]: 14 | fields = [] 15 | for field in split: 16 | if isinstance(field[0], str): 17 | field = encoder.encode(field) 18 | fields.append(field) 19 | encoded_splits.append(fields) 20 | return encoded_splits 21 | 22 | def stsb_label_encoding(labels, nclass=6): 23 | """ 24 | Label encoding from Tree LSTM paper (Tai, Socher, Manning) 25 | """ 26 | Y = np.zeros((len(labels), nclass)).astype(np.float32) 27 | for j, y in enumerate(labels): 28 | for i in range(nclass): 29 | if i == np.floor(y) + 1: 30 | Y[j,i] = y - np.floor(y) 31 | if i == np.floor(y): 32 | Y[j,i] = np.floor(y) - y + 1 33 | return Y 34 | 35 | def np_softmax(x, t=1): 36 | x = x/t 37 | x = x - np.max(x, axis=-1, keepdims=True) 38 | ex = np.exp(x) 39 | return ex/np.sum(ex, axis=-1, keepdims=True) 40 | 41 | def make_path(f): 42 | d = os.path.dirname(f) 43 | if d and not os.path.exists(d): 44 | os.makedirs(d) 45 | return f 46 | 47 | def _identity_init(shape, dtype, partition_info, scale): 48 | n = shape[-1] 49 | w = np.eye(n)*scale 50 | if len([s for s in shape if s != 1]) == 2: 51 | w = w.reshape(shape) 52 | return w.astype(np.float32) 53 | 54 | def identity_init(scale=1.0): 55 | return partial(_identity_init, scale=scale) 56 | 57 | def _np_init(shape, dtype, partition_info, w): 58 | return w 59 | 60 | def np_init(w): 61 | return partial(_np_init, w=w) 62 | 63 | class ResultLogger(object): 64 | def __init__(self, path, *args, **kwargs): 65 | if 'time' not in kwargs: 66 | kwargs['time'] = time.time() 67 | self.f_log = open(make_path(path), 'w') 68 | self.f_log.write(json.dumps(kwargs)+'\n') 69 | 70 | def log(self, **kwargs): 71 | if 'time' not in kwargs: 72 | kwargs['time'] = time.time() 73 | self.f_log.write(json.dumps(kwargs)+'\n') 74 | self.f_log.flush() 75 | 76 | def close(self): 77 | self.f_log.close() 78 | 79 | def flatten(outer): 80 | return [el for inner in outer for el in inner] 81 | 82 | def remove_none(l): 83 | return [e for e in l if e is not None] 84 | 85 | def iter_data(*datas, n_batch=128, truncate=False, verbose=False, max_batches=float("inf")): 86 | n = len(datas[0]) 87 | if truncate: 88 | n = (n//n_batch)*n_batch 89 | n = min(n, max_batches*n_batch) 90 | n_batches = 0 91 | if verbose: 92 | f = sys.stderr 93 | else: 94 | f = open(os.devnull, 'w') 95 | for i in tqdm(range(0, n, n_batch), total=n//n_batch, file=f, ncols=80, leave=False): 96 | if n_batches >= max_batches: raise StopIteration 97 | if len(datas) == 1: 98 | yield datas[0][i:i+n_batch] 99 | else: 100 | yield (d[i:i+n_batch] for d in datas) 101 | n_batches += 1 102 | -------------------------------------------------------------------------------- /parameters_names.json: -------------------------------------------------------------------------------- 1 | ["model/we:0", "model/h0/attn/c_attn/w:0", "model/h0/attn/c_attn/b:0", "model/h0/attn/c_proj/w:0", "model/h0/attn/c_proj/b:0", "model/h0/ln_1/g:0", "model/h0/ln_1/b:0", "model/h0/mlp/c_fc/w:0", "model/h0/mlp/c_fc/b:0", "model/h0/mlp/c_proj/w:0", "model/h0/mlp/c_proj/b:0", "model/h0/ln_2/g:0", "model/h0/ln_2/b:0", "model/h1/attn/c_attn/w:0", "model/h1/attn/c_attn/b:0", "model/h1/attn/c_proj/w:0", "model/h1/attn/c_proj/b:0", "model/h1/ln_1/g:0", "model/h1/ln_1/b:0", "model/h1/mlp/c_fc/w:0", "model/h1/mlp/c_fc/b:0", "model/h1/mlp/c_proj/w:0", "model/h1/mlp/c_proj/b:0", "model/h1/ln_2/g:0", "model/h1/ln_2/b:0", "model/h2/attn/c_attn/w:0", "model/h2/attn/c_attn/b:0", "model/h2/attn/c_proj/w:0", "model/h2/attn/c_proj/b:0", "model/h2/ln_1/g:0", "model/h2/ln_1/b:0", "model/h2/mlp/c_fc/w:0", "model/h2/mlp/c_fc/b:0", "model/h2/mlp/c_proj/w:0", "model/h2/mlp/c_proj/b:0", "model/h2/ln_2/g:0", "model/h2/ln_2/b:0", "model/h3/attn/c_attn/w:0", "model/h3/attn/c_attn/b:0", "model/h3/attn/c_proj/w:0", "model/h3/attn/c_proj/b:0", "model/h3/ln_1/g:0", "model/h3/ln_1/b:0", "model/h3/mlp/c_fc/w:0", "model/h3/mlp/c_fc/b:0", "model/h3/mlp/c_proj/w:0", "model/h3/mlp/c_proj/b:0", "model/h3/ln_2/g:0", "model/h3/ln_2/b:0", "model/h4/attn/c_attn/w:0", "model/h4/attn/c_attn/b:0", "model/h4/attn/c_proj/w:0", "model/h4/attn/c_proj/b:0", "model/h4/ln_1/g:0", "model/h4/ln_1/b:0", "model/h4/mlp/c_fc/w:0", "model/h4/mlp/c_fc/b:0", "model/h4/mlp/c_proj/w:0", "model/h4/mlp/c_proj/b:0", "model/h4/ln_2/g:0", "model/h4/ln_2/b:0", "model/h5/attn/c_attn/w:0", "model/h5/attn/c_attn/b:0", "model/h5/attn/c_proj/w:0", "model/h5/attn/c_proj/b:0", "model/h5/ln_1/g:0", "model/h5/ln_1/b:0", "model/h5/mlp/c_fc/w:0", "model/h5/mlp/c_fc/b:0", "model/h5/mlp/c_proj/w:0", "model/h5/mlp/c_proj/b:0", "model/h5/ln_2/g:0", "model/h5/ln_2/b:0", "model/h6/attn/c_attn/w:0", "model/h6/attn/c_attn/b:0", "model/h6/attn/c_proj/w:0", "model/h6/attn/c_proj/b:0", "model/h6/ln_1/g:0", "model/h6/ln_1/b:0", "model/h6/mlp/c_fc/w:0", "model/h6/mlp/c_fc/b:0", "model/h6/mlp/c_proj/w:0", "model/h6/mlp/c_proj/b:0", "model/h6/ln_2/g:0", "model/h6/ln_2/b:0", "model/h7/attn/c_attn/w:0", "model/h7/attn/c_attn/b:0", "model/h7/attn/c_proj/w:0", "model/h7/attn/c_proj/b:0", "model/h7/ln_1/g:0", "model/h7/ln_1/b:0", "model/h7/mlp/c_fc/w:0", "model/h7/mlp/c_fc/b:0", "model/h7/mlp/c_proj/w:0", "model/h7/mlp/c_proj/b:0", "model/h7/ln_2/g:0", "model/h7/ln_2/b:0", "model/h8/attn/c_attn/w:0", "model/h8/attn/c_attn/b:0", "model/h8/attn/c_proj/w:0", "model/h8/attn/c_proj/b:0", "model/h8/ln_1/g:0", "model/h8/ln_1/b:0", "model/h8/mlp/c_fc/w:0", "model/h8/mlp/c_fc/b:0", "model/h8/mlp/c_proj/w:0", "model/h8/mlp/c_proj/b:0", "model/h8/ln_2/g:0", "model/h8/ln_2/b:0", "model/h9/attn/c_attn/w:0", "model/h9/attn/c_attn/b:0", "model/h9/attn/c_proj/w:0", "model/h9/attn/c_proj/b:0", "model/h9/ln_1/g:0", "model/h9/ln_1/b:0", "model/h9/mlp/c_fc/w:0", "model/h9/mlp/c_fc/b:0", "model/h9/mlp/c_proj/w:0", "model/h9/mlp/c_proj/b:0", "model/h9/ln_2/g:0", "model/h9/ln_2/b:0", "model/h10/attn/c_attn/w:0", "model/h10/attn/c_attn/b:0", "model/h10/attn/c_proj/w:0", "model/h10/attn/c_proj/b:0", "model/h10/ln_1/g:0", "model/h10/ln_1/b:0", "model/h10/mlp/c_fc/w:0", "model/h10/mlp/c_fc/b:0", "model/h10/mlp/c_proj/w:0", "model/h10/mlp/c_proj/b:0", "model/h10/ln_2/g:0", "model/h10/ln_2/b:0", "model/h11/attn/c_attn/w:0", "model/h11/attn/c_attn/b:0", "model/h11/attn/c_proj/w:0", "model/h11/attn/c_proj/b:0", "model/h11/ln_1/g:0", "model/h11/ln_1/b:0", "model/h11/mlp/c_fc/w:0", "model/h11/mlp/c_fc/b:0", "model/h11/mlp/c_proj/w:0", "model/h11/mlp/c_proj/b:0", "model/h11/ln_2/g:0", "model/h11/ln_2/b:0", "model/clf/w:0", "model/clf/b:0"] -------------------------------------------------------------------------------- /text_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import ftfy 3 | import json 4 | import spacy 5 | 6 | from tqdm import tqdm 7 | 8 | def get_pairs(word): 9 | """ 10 | Return set of symbol pairs in a word. 11 | word is represented as tuple of symbols (symbols being variable-length strings) 12 | """ 13 | pairs = set() 14 | prev_char = word[0] 15 | for char in word[1:]: 16 | pairs.add((prev_char, char)) 17 | prev_char = char 18 | return pairs 19 | 20 | def text_standardize(text): 21 | """ 22 | fixes some issues the spacy tokenizer had on books corpus 23 | also does some whitespace standardization 24 | """ 25 | text = text.replace('—', '-') 26 | text = text.replace('–', '-') 27 | text = text.replace('―', '-') 28 | text = text.replace('…', '...') 29 | text = text.replace('´', "'") 30 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 31 | text = re.sub(r'\s*\n\s*', ' \n ', text) 32 | text = re.sub(r'[^\S\n]+', ' ', text) 33 | return text.strip() 34 | 35 | class TextEncoder(object): 36 | """ 37 | mostly a wrapper for a public python bpe tokenizer 38 | """ 39 | 40 | def __init__(self, encoder_path, bpe_path): 41 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 42 | self.encoder = json.load(open(encoder_path)) 43 | self.decoder = {v:k for k,v in self.encoder.items()} 44 | merges = open(bpe_path).read().split('\n')[1:-1] 45 | merges = [tuple(merge.split()) for merge in merges] 46 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 47 | self.cache = {} 48 | 49 | def bpe(self, token): 50 | word = tuple(token[:-1]) + ( token[-1] + '',) 51 | if token in self.cache: 52 | return self.cache[token] 53 | pairs = get_pairs(word) 54 | 55 | if not pairs: 56 | return token+'' 57 | 58 | while True: 59 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 60 | if bigram not in self.bpe_ranks: 61 | break 62 | first, second = bigram 63 | new_word = [] 64 | i = 0 65 | while i < len(word): 66 | try: 67 | j = word.index(first, i) 68 | new_word.extend(word[i:j]) 69 | i = j 70 | except: 71 | new_word.extend(word[i:]) 72 | break 73 | 74 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 75 | new_word.append(first+second) 76 | i += 2 77 | else: 78 | new_word.append(word[i]) 79 | i += 1 80 | new_word = tuple(new_word) 81 | word = new_word 82 | if len(word) == 1: 83 | break 84 | else: 85 | pairs = get_pairs(word) 86 | word = ' '.join(word) 87 | if word == '\n ': 88 | word = '\n' 89 | self.cache[token] = word 90 | return word 91 | 92 | def encode(self, texts, verbose=True): 93 | texts_tokens = [] 94 | if verbose: 95 | for text in tqdm(texts, ncols=80, leave=False): 96 | text = self.nlp(text_standardize(ftfy.fix_text(text))) 97 | text_tokens = [] 98 | for token in text: 99 | text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')]) 100 | texts_tokens.append(text_tokens) 101 | else: 102 | for text in texts: 103 | text = self.nlp(text_standardize(ftfy.fix_text(text))) 104 | text_tokens = [] 105 | for token in text: 106 | text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')]) 107 | texts_tokens.append(text_tokens) 108 | return texts_tokens 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of OpenAI's Finetuned Transformer Language Model 2 | 3 | This is a PyTorch implementation of the [TensorFlow code](https://github.com/openai/finetune-transformer-lm) provided with OpenAI's paper ["Improving Language Understanding by Generative Pre-Training"](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 4 | 5 | This implementation comprises **a script to load in the PyTorch model the weights pre-trained by the authors** with the TensorFlow implementation. 6 | 7 | ![Transformer Language Model](assets/ftlm.png) 8 | 9 | The model classes and loading script are located in [model_py.py](model_py.py). 10 | 11 | The names of the modules in the PyTorch model follow the names of the Variable in the TensorFlow implementation. This implementation tries to follow the original code as closely as possible to minimize the discrepancies. 12 | 13 | This implementation thus also comprises a modified Adam optimization algorithm as used in OpenAI's paper with: 14 | - fixed weights decay following the work of [Loshchilov et al.](https://arxiv.org/abs/1711.05101), and 15 | - scheduled learning rate as [commonly used for Transformers](http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer). 16 | 17 | ## Requirements 18 | To use the model it-self by importing [model_py.py](model_py.py), you just need: 19 | - PyTorch (version >=0.4) 20 | 21 | To run the classifier training script in [train.py](train.py) you will need in addition: 22 | - tqdm 23 | - sklearn 24 | - spacy 25 | - ftfy 26 | - pandas 27 | 28 | You can download the weights of the OpenAI pre-trained version by cloning [Alec Radford's repo](https://github.com/openai/finetune-transformer-lm) and placing the `model` folder containing the pre-trained weights in the present repo. 29 | 30 | ## Using the pre-trained model as a Transformer Language Model 31 | The model can be used as a transformer language model with OpenAI's pre-trained weights as follow: 32 | ```python 33 | from model_py import Model, load_openai_pretrained_model, DEFAULT_CONFIG 34 | 35 | args = DEFAULT_CONFIG 36 | vocab = 40000 # Size of your vocabulary 37 | model = Model(vocab, args) 38 | load_openai_pretrained_model(model) 39 | ``` 40 | 41 | This model generates Transformer's hidden states. You can use the `LMHead` class in [model.py](model.py) to add a decoder tied with the weights of the encoder and get a full language model. You can also use the `ClfHead` class in [model.py](model.py) to add a classifier on top of the transformer and get a classifier as described in OpenAI's publication. (see an example of both in the `__main__` function of [train.py](train.py)) 42 | 43 | To use the positional encoder of the transformer, you should encode your dataset using the `encode_dataset()` function of [utils.py](utils.py). Please refer to the beginning of the `__main__` function in [train.py](train.py) to see how to properly define the vocabulary and encode your dataset. 44 | 45 | ## Fine-tuning the pre-trained model on a classification task 46 | This model can also be integrated in a classifier as detailed in [OpenAI's paper](https://blog.openai.com/language-unsupervised/). An example of fine-tuning on the ROCStories Cloze task is included with the training code in [train.py](train.py) 47 | 48 | The ROCStories dataset can be downloaded from the associated [website](http://cs.rochester.edu/nlp/rocstories/). 49 | 50 | As with the [TensorFlow code](https://github.com/openai/finetune-transformer-lm), this code implements the ROCStories Cloze Test result reported in the paper which can be reproduced by running: 51 | 52 | ```bash 53 | python train.py --dataset rocstories --desc rocstories --submit --analysis --data_dir [path to data here] 54 | ``` 55 | 56 | #### Accuracy on the ROCStories test set 57 | Finetuning the PyTorch model for 3 Epochs on ROCStories takes 10 minutes to run on a single NVidia K-80. 58 | 59 | The test accuracy of this PyTorch version (with the default TensorFlow hyper-parameters) is 83.43%. 60 | 61 | The authors reports a median accuracy of 10 runs with the TensorFlow code of 85.8%. 62 | The paper reports a best accuracy of 86.5%. 63 | 64 | As noted by the author, the code can be non-deterministic due to various GPU ops. 65 | 66 | ### TO-DO list 67 | - [ ] Add Multi-GPU training logic 68 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | from torch.nn.utils import clip_grad_norm_ 5 | 6 | def warmup_cosine(x, warmup=0.002): 7 | s = 1 if x <= warmup else 0 8 | return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x))) 9 | 10 | def warmup_constant(x, warmup=0.002): 11 | s = 1 if x <= warmup else 0 12 | return s*(x/warmup) + (1-s)*1 13 | 14 | def warmup_linear(x, warmup=0.002): 15 | s = 1 if x <= warmup else 0 16 | return (s*(x/warmup) + (1-s))*(1-x) 17 | 18 | SCHEDULES = { 19 | 'warmup_cosine':warmup_cosine, 20 | 'warmup_constant':warmup_constant, 21 | 'warmup_linear':warmup_linear, 22 | } 23 | 24 | 25 | class OpenAIAdam(Optimizer): 26 | """Implements Open AI version of Adam algorithm with weight decay fix. 27 | """ 28 | def __init__(self, params, lr, schedule, warmup, t_total, 29 | b1=0.9, b2=0.999, e=1e-8, l2=0, 30 | vector_l2=False, max_grad_norm=-1, **kwargs): 31 | if not 0.0 <= lr: 32 | raise ValueError("Invalid learning rate: {}".format(lr)) 33 | if schedule not in SCHEDULES: 34 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 35 | if not 0 <= warmup: 36 | raise ValueError("Invalid warmup: {}".format(warmup)) 37 | if not 0.0 <= b1 < 1.0: 38 | raise ValueError("Invalid b1 parameter: {}".format(b1)) 39 | if not 0.0 <= b2 < 1.0: 40 | raise ValueError("Invalid b2 parameter: {}".format(b2)) 41 | if not 0.0 <= e: 42 | raise ValueError("Invalid epsilon value: {}".format(e)) 43 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 44 | b1=b1, b2=b2, e=e, l2=l2, vector_l2=vector_l2, 45 | max_grad_norm=max_grad_norm) 46 | super(OpenAIAdam, self).__init__(params, defaults) 47 | 48 | def step(self, closure=None): 49 | """Performs a single optimization step. 50 | 51 | Arguments: 52 | closure (callable, optional): A closure that reevaluates the model 53 | and returns the loss. 54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | for group in self.param_groups: 60 | for p in group['params']: 61 | if p.grad is None: 62 | continue 63 | grad = p.grad.data 64 | if grad.is_sparse: 65 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 66 | 67 | state = self.state[p] 68 | 69 | # State initialization 70 | if len(state) == 0: 71 | state['step'] = 0 72 | # Exponential moving average of gradient values 73 | state['exp_avg'] = torch.zeros_like(p.data) 74 | # Exponential moving average of squared gradient values 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | 77 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 78 | beta1, beta2 = group['b1'], group['b2'] 79 | 80 | state['step'] += 1 81 | 82 | # Add grad clipping 83 | if group['max_grad_norm'] > 0: 84 | clip_grad_norm_(p, group['max_grad_norm']) 85 | 86 | # Decay the first and second moment running average coefficient 87 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 88 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 89 | denom = exp_avg_sq.sqrt().add_(group['e']) 90 | 91 | bias_correction1 = 1 - beta1 ** state['step'] 92 | bias_correction2 = 1 - beta2 ** state['step'] 93 | 94 | schedule_fct = SCHEDULES[group['schedule']] 95 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 96 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 97 | 98 | p.data.addcdiv_(-step_size, exp_avg, denom) 99 | 100 | # Add weight decay at the end (fixed version) 101 | if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0: 102 | p.data.add_(-lr_scheduled * group['l2'], p.data) 103 | 104 | return loss 105 | -------------------------------------------------------------------------------- /model_py.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | import json 4 | import copy 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.parameter import Parameter 10 | 11 | def gelu(x): 12 | return 0.5*x*(1+torch.tanh(math.sqrt(2/math.pi)*(x+0.044715*torch.pow(x, 3)))) 13 | 14 | def swish(x): 15 | return x*torch.sigmoid(x) 16 | 17 | ACT_FNS = { 18 | 'relu': nn.ReLU, 19 | 'swish': swish, 20 | 'gelu': gelu 21 | } 22 | 23 | 24 | class LayerNorm(nn.Module): 25 | "Construct a layernorm module (See citation for details)." 26 | def __init__(self, n_state, e=1e-5): 27 | super(LayerNorm, self).__init__() 28 | self.g = nn.Parameter(torch.ones(n_state)) 29 | self.b = nn.Parameter(torch.zeros(n_state)) 30 | self.e = e 31 | 32 | def forward(self, x): 33 | u = x.mean(-1, keepdim=True) 34 | s = (x - u).pow(2).mean(-1, keepdim=True) 35 | x = (x - u) / torch.sqrt(s + self.e) 36 | return self.g * x + self.b 37 | 38 | 39 | class Conv1D(nn.Module): 40 | def __init__(self, nf, rf, nx): 41 | super(Conv1D, self).__init__() 42 | self.rf = rf 43 | self.nf = nf 44 | if rf == 1: #faster 1x1 conv 45 | w = torch.empty(nx, nf) 46 | nn.init.normal_(w, std=0.02) 47 | self.w = Parameter(w) 48 | self.b = Parameter(torch.zeros(nf)) 49 | else: #was used to train LM 50 | raise NotImplementedError 51 | 52 | def forward(self, x): 53 | if self.rf == 1: 54 | size_out = x.size()[:-1] + (self.nf,) 55 | x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w) 56 | x = x.view(*size_out) 57 | else: 58 | raise NotImplementedError 59 | return x 60 | 61 | 62 | class Attention(nn.Module): 63 | def __init__(self, nx, n_ctx, cfg, scale=False): 64 | super(Attention, self).__init__() 65 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 66 | #[switch nx => n_state from Block to Attention to keep identical to TF implem] 67 | assert n_state % cfg.n_head==0 68 | self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 69 | self.n_head = cfg.n_head 70 | self.split_size = n_state 71 | self.scale = scale 72 | self.c_attn = Conv1D(n_state * 3, 1, nx) 73 | self.c_proj = Conv1D(n_state, 1, nx) 74 | self.attn_dropout = nn.Dropout(cfg.attn_pdrop) 75 | self.resid_dropout = nn.Dropout(cfg.resid_pdrop) 76 | 77 | def _attn(self, q, k, v): 78 | w = torch.matmul(q, k) 79 | if self.scale: 80 | w = w / math.sqrt(v.size(-1)) 81 | w = w * self.b + -1e9*(1-self.b) # TF implem method: mask_attn_weights 82 | w = nn.Softmax(dim=-1)(w) 83 | w = self.attn_dropout(w) 84 | return torch.matmul(w, v) 85 | 86 | def merge_heads(self, x): 87 | x = x.permute(0, 2, 1, 3).contiguous() 88 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 89 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 90 | 91 | def split_heads(self, x, k=False): 92 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1)//self.n_head) 93 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 94 | if k: 95 | return x.permute(0, 2, 3, 1) 96 | else: 97 | return x.permute(0, 2, 1, 3) 98 | 99 | def forward(self, x): 100 | x = self.c_attn(x) 101 | query, key, value = x.split(self.split_size, dim=2) 102 | query = self.split_heads(query) 103 | key = self.split_heads(key, k=True) 104 | value = self.split_heads(value) 105 | a = self._attn(query, key, value) 106 | a = self.merge_heads(a) 107 | a = self.c_proj(a) 108 | a = self.resid_dropout(a) 109 | return a 110 | 111 | 112 | class MLP(nn.Module): 113 | def __init__(self, n_state, cfg): # in MLP: n_state=3072 (4 * n_embd) 114 | super(MLP, self).__init__() 115 | nx = cfg.n_embd 116 | self.c_fc = Conv1D(n_state, 1, nx) 117 | self.c_proj = Conv1D(nx, 1, n_state) 118 | self.act = ACT_FNS[cfg.afn] 119 | self.dropout = nn.Dropout(cfg.resid_pdrop) 120 | 121 | def forward(self, x): 122 | h = self.act(self.c_fc(x)) 123 | h2 = self.c_proj(h) 124 | return self.dropout(h2) 125 | 126 | 127 | class Block(nn.Module): 128 | def __init__(self, n_ctx, cfg, scale=False): 129 | super(Block, self).__init__() 130 | nx = cfg.n_embd 131 | self.attn = Attention(nx, n_ctx, cfg, scale) 132 | self.ln_1 = LayerNorm(nx) 133 | self.mlp = MLP(4*nx, cfg) 134 | self.ln_2 = LayerNorm(nx) 135 | 136 | def forward(self, x): 137 | a = self.attn(x) 138 | n = self.ln_1(x+a) 139 | m = self.mlp(n) 140 | h = self.ln_2(n+m) 141 | return h 142 | 143 | 144 | class Model(nn.Module): 145 | """ Transformer model """ 146 | def __init__(self, vocab, n_ctx, cfg): 147 | super(Model, self).__init__() 148 | self.vocab = vocab 149 | self.embed = nn.Embedding(vocab, cfg.n_embd) 150 | self.drop = nn.Dropout(cfg.embd_pdrop) 151 | block = Block(n_ctx, cfg, scale=True) 152 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)]) 153 | self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False) 154 | self.decoder.weight = self.embed.weight # Tied weights 155 | self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation 156 | 157 | nn.init.normal_(self.embed.weight, std=0.02) 158 | 159 | def forward(self, x): 160 | x = x.view(-1, x.size(2), x.size(3)) 161 | e = self.embed(x) 162 | h = e.sum(dim=2) 163 | for block in self.h: 164 | h = block(h) 165 | return h 166 | 167 | 168 | class LMHead(nn.Module): 169 | """ Language Model Head for the transformer """ 170 | def __init__(self, model, cfg): 171 | super(LMHead, self).__init__() 172 | self.n_embd = cfg.n_embd 173 | self.decoder = nn.Linear(cfg.n_embd, model.vocab, bias=False) 174 | self.decoder.weight = model.embed.weight # Tied weights 175 | 176 | def forward(self, h): 177 | # Truncated Language modeling logits 178 | h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) # Shape: 252, 768 179 | lm_logits = self.decoder(h_trunc) 180 | return lm_logits 181 | 182 | 183 | class ClfHead(nn.Module): 184 | """ Classifier Head for the transformer """ 185 | def __init__(self, clf_token, cfg): 186 | super(ClfHead, self).__init__() 187 | self.n_embd = cfg.n_embd 188 | self.clf_token = clf_token 189 | self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation 190 | self.linear = nn.Linear(cfg.n_embd, 1) 191 | nn.init.normal_(self.linear.weight, std=0.02) 192 | nn.init.normal_(self.linear.bias, 0) 193 | 194 | def forward(self, h, x): 195 | # Classification logits 196 | clf_h = h.view(-1, self.n_embd) 197 | flat = x[:, :, :, 0].contiguous().view(-1) 198 | #pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token) 199 | clf_h = clf_h[flat == self.clf_token, :] #.index_select(0, pool_idx) 200 | clf_h = clf_h.view(-1, 2, self.n_embd, 1) 201 | clf_h = self.dropout(clf_h) 202 | clf_h = clf_h.view(-1, self.n_embd) 203 | clf_logits = self.linear(clf_h) 204 | return clf_logits.view(-1, 2) 205 | 206 | 207 | def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', path_names='./'): 208 | # Load weights from TF model 209 | names = json.load(open(path_names + 'parameters_names.json')) 210 | shapes = json.load(open(path + 'params_shapes.json')) 211 | offsets = np.cumsum([np.prod(shape) for shape in shapes]) 212 | init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)] 213 | init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] 214 | init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] 215 | if n_ctx > 0: 216 | init_params[0] = init_params[0][:n_ctx] 217 | if n_special > 0: 218 | init_params[0] = np.concatenate([init_params[1], 219 | (np.random.randn(n_special, n_embd)*0.02).astype(np.float32), 220 | init_params[0] 221 | ], 0) 222 | else: 223 | init_params[0] = np.concatenate([init_params[1], 224 | init_params[0] 225 | ], 0) 226 | del init_params[1] 227 | if n_transfer == -1: 228 | n_transfer = 0 229 | else: 230 | n_transfer = 1+n_transfer*12 231 | init_params = [arr.squeeze() for arr in init_params] 232 | try: 233 | assert model.embed.weight.shape == init_params[0].shape 234 | except AssertionError as e: 235 | e.args += (model.embed.weight.shape, init_params[0].shape) 236 | raise 237 | model.embed.weight.data = torch.from_numpy(init_params[0]) 238 | for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]): 239 | name = name[6:] # skip "model/" 240 | assert name[-2:] == ":0" 241 | name = name[:-2] 242 | name = name.split('/') 243 | pointer = model 244 | for m_name in name: 245 | if re.fullmatch(r'[A-Za-z]+\d+', m_name): 246 | l = re.split(r'(\d+)', m_name) 247 | else: 248 | l = [m_name] 249 | pointer = getattr(pointer, l[0]) 250 | if len(l) >= 2: 251 | num = int(l[1]) 252 | pointer = pointer[num] 253 | try: 254 | assert pointer.shape == ip.shape 255 | except AssertionError as e: 256 | e.args += (pointer.shape, ip.shape) 257 | raise 258 | pointer.data = torch.from_numpy(ip) 259 | 260 | class dotdict(dict): 261 | """dot.notation access to dictionary attributes""" 262 | __getattr__ = dict.get 263 | __setattr__ = dict.__setitem__ 264 | __delattr__ = dict.__delitem__ 265 | 266 | DEFAULT_CONFIG = dotdict({ 267 | 'n_embd': 768, 268 | 'n_head': 12, 269 | 'n_layer': 12, 270 | 'embd_pdrop': 0.1, 271 | 'attn_pdrop': 0.1, 272 | 'resid_pdrop': 0.1, 273 | 'clf_pdrop': 0.1}) 274 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import time 4 | import math 5 | import json 6 | import random 7 | import argparse 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from tqdm import tqdm 15 | from functools import partial 16 | from sklearn.utils import shuffle 17 | from sklearn.metrics import accuracy_score 18 | 19 | from model_py import Model, LMHead, ClfHead, load_openai_pretrained_model 20 | from opt import OpenAIAdam 21 | from datasets import rocstories 22 | from analysis import rocstories as rocstories_analysis 23 | from text_utils import TextEncoder 24 | from utils import (encode_dataset, flatten, iter_data, 25 | ResultLogger, make_path) 26 | 27 | class LossCompute: 28 | "A Loss compute and train function." 29 | def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None): 30 | self.lm_criterion = lm_criterion 31 | self.clf_criterion = clf_criterion 32 | self.lm_coef = lm_coef 33 | self.opt = opt 34 | 35 | def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False): 36 | # Language modeling loss 37 | if lm_logits is not None: 38 | x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252 39 | M = M.view(-1, M.size(2)) 40 | lm_losses = self.lm_criterion(lm_logits, x_shifted) 41 | lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1) 42 | lm_losses = lm_losses * M[:, 1:] 43 | lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1) 44 | # Classification loss 45 | clf_losses = self.clf_criterion(clf_logits, Y) 46 | if only_return_losses: 47 | return (clf_losses, lm_losses) if lm_logits is not None else clf_losses 48 | 49 | if self.lm_coef > 0 and lm_logits is not None: 50 | train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum() 51 | else: 52 | train_loss = clf_losses.sum() 53 | train_loss.backward() 54 | if self.opt is not None: 55 | self.opt.step() 56 | self.opt.zero_grad() 57 | return train_loss.item() 58 | 59 | def transform_roc(X1, X2, X3): 60 | n_batch = len(X1) 61 | xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32) 62 | mmb = np.zeros((n_batch, 2, n_ctx), dtype=np.float32) 63 | start = encoder['_start_'] 64 | delimiter = encoder['_delimiter_'] 65 | for i, (x1, x2, x3), in enumerate(zip(X1, X2, X3)): 66 | x12 = [start]+x1[:max_len]+[delimiter]+x2[:max_len]+[clf_token] 67 | x13 = [start]+x1[:max_len]+[delimiter]+x3[:max_len]+[clf_token] 68 | l12 = len(x12) 69 | l13 = len(x13) 70 | xmb[i, 0, :l12, 0] = x12 71 | xmb[i, 1, :l13, 0] = x13 72 | mmb[i, 0, :l12] = 1 73 | mmb[i, 1, :l13] = 1 74 | xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx) 75 | return xmb, mmb 76 | 77 | def iter_apply(Xs, Ms, Ys): 78 | # fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))] 79 | logits = [] 80 | cost = 0 81 | with torch.no_grad(): 82 | model.eval() 83 | for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True): 84 | n = len(xmb) 85 | XMB = torch.tensor(xmb, dtype=torch.long).to(device) 86 | YMB = torch.tensor(ymb, dtype=torch.long).to(device) 87 | MMB = torch.tensor(mmb).to(device) 88 | h = model(XMB) 89 | clf_logits = clf_head(h, XMB) 90 | clf_logits *= n 91 | clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True) 92 | clf_losses *= n 93 | logits.append(clf_logits.to("cpu").numpy()) 94 | cost += clf_losses.sum().item() 95 | logits = np.concatenate(logits, 0) 96 | return logits, cost 97 | 98 | def iter_predict(Xs, Ms): 99 | logits = [] 100 | with torch.no_grad(): 101 | model.eval() 102 | for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True): 103 | n = len(xmb) 104 | XMB = torch.tensor(xmb, dtype=torch.long).to(device) 105 | MMB = torch.tensor(mmb).to(device) 106 | h = model(XMB) 107 | clf_logits = clf_head(h, XMB) 108 | logits.append(clf_logits.to("cpu").numpy()) 109 | logits = np.concatenate(logits, 0) 110 | return logits 111 | 112 | def log(): 113 | global best_score 114 | print("Logging") 115 | tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid]) 116 | va_logits, va_cost = iter_apply(vaX, vaM, vaY) 117 | tr_cost = tr_cost/len(trY[:n_valid]) 118 | va_cost = va_cost/n_valid 119 | tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100. 120 | va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100. 121 | logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc) 122 | print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc)) 123 | if submit: 124 | score = va_acc 125 | if score > best_score: 126 | best_score = score 127 | path = os.path.join(save_dir, desc, 'best_params') 128 | torch.save(model.state_dict(), make_path(path)) 129 | 130 | def predict(): 131 | filename = filenames[dataset] 132 | pred_fn = pred_fns[dataset] 133 | label_decoder = label_decoders[dataset] 134 | predictions = pred_fn(iter_predict(teX, teM)) 135 | if label_decoder is not None: 136 | predictions = [label_decoder[prediction] for prediction in predictions] 137 | path = os.path.join(submission_dir, filename) 138 | os.makedirs(os.path.dirname(path), exist_ok=True) 139 | with open(path, 'w') as f: 140 | f.write('{}\t{}\n'.format('index', 'prediction')) 141 | for i, prediction in enumerate(predictions): 142 | f.write('{}\t{}\n'.format(i, prediction)) 143 | 144 | def run_epoch(): 145 | for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random), 146 | n_batch=n_batch_train, truncate=True, verbose=True): 147 | global n_updates 148 | model.train() 149 | XMB = torch.tensor(xmb, dtype=torch.long).to(device) 150 | YMB = torch.tensor(ymb, dtype=torch.long).to(device) 151 | MMB = torch.tensor(mmb).to(device) 152 | h = model(XMB) 153 | lm_logits = lm_head(h) 154 | clf_logits = clf_head(h, XMB) 155 | compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits) 156 | n_updates += 1 157 | if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0: 158 | log() 159 | 160 | argmax = lambda x:np.argmax(x, 1) 161 | 162 | pred_fns = { 163 | 'rocstories':argmax, 164 | } 165 | 166 | filenames = { 167 | 'rocstories':'ROCStories.tsv', 168 | } 169 | 170 | label_decoders = { 171 | 'rocstories':None, 172 | } 173 | 174 | if __name__ == '__main__': 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument('--desc', type=str) 177 | parser.add_argument('--dataset', type=str) 178 | parser.add_argument('--log_dir', type=str, default='log/') 179 | parser.add_argument('--save_dir', type=str, default='save/') 180 | parser.add_argument('--data_dir', type=str, default='data/') 181 | parser.add_argument('--submission_dir', type=str, default='submission/') 182 | parser.add_argument('--submit', action='store_true') 183 | parser.add_argument('--analysis', action='store_true') 184 | parser.add_argument('--seed', type=int, default=42) 185 | parser.add_argument('--n_iter', type=int, default=3) 186 | parser.add_argument('--n_batch', type=int, default=8) 187 | parser.add_argument('--max_grad_norm', type=int, default=1) 188 | parser.add_argument('--lr', type=float, default=6.25e-5) 189 | parser.add_argument('--lr_warmup', type=float, default=0.002) 190 | parser.add_argument('--n_ctx', type=int, default=512) 191 | parser.add_argument('--n_embd', type=int, default=768) 192 | parser.add_argument('--n_head', type=int, default=12) 193 | parser.add_argument('--n_layer', type=int, default=12) 194 | parser.add_argument('--embd_pdrop', type=float, default=0.1) 195 | parser.add_argument('--attn_pdrop', type=float, default=0.1) 196 | parser.add_argument('--resid_pdrop', type=float, default=0.1) 197 | parser.add_argument('--clf_pdrop', type=float, default=0.1) 198 | parser.add_argument('--l2', type=float, default=0.01) 199 | parser.add_argument('--vector_l2', action='store_true') 200 | parser.add_argument('--n_gpu', type=int, default=1)#4) # TODO add mutli-gpu training logic 201 | parser.add_argument('--opt', type=str, default='adam') 202 | parser.add_argument('--afn', type=str, default='gelu') 203 | parser.add_argument('--lr_schedule', type=str, default='warmup_linear') 204 | parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json') 205 | parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe') 206 | parser.add_argument('--n_transfer', type=int, default=12) 207 | parser.add_argument('--lm_coef', type=float, default=0.5) 208 | parser.add_argument('--b1', type=float, default=0.9) 209 | parser.add_argument('--b2', type=float, default=0.999) 210 | parser.add_argument('--e', type=float, default=1e-8) 211 | parser.add_argument('--n_valid', type=int, default=374) 212 | 213 | args = parser.parse_args() 214 | print(args) 215 | globals().update(args.__dict__) #TODO maybe we want to remove these gobal variables to make it cleaner 216 | random.seed(seed) 217 | np.random.seed(seed) 218 | torch.manual_seed(seed) 219 | torch.cuda.manual_seed_all(seed) 220 | 221 | # torch.device object used throughout this script TODO add gpu setting 222 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 223 | 224 | logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__) 225 | text_encoder = TextEncoder(encoder_path, bpe_path) 226 | encoder = text_encoder.encoder 227 | n_vocab = len(text_encoder.encoder) 228 | 229 | (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir, n_valid=n_valid), encoder=text_encoder) 230 | n_y = 2 231 | encoder['_start_'] = len(encoder) 232 | encoder['_delimiter_'] = len(encoder) 233 | encoder['_classify_'] = len(encoder) 234 | clf_token = encoder['_classify_'] 235 | n_special = 3 236 | max_len = n_ctx//2-2 237 | n_ctx = min(max( 238 | [len(x1[:max_len]) + max(len(x2[:max_len]), 239 | len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)] 240 | +[len(x1[:max_len]) + max(len(x2[:max_len]), 241 | len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)] 242 | +[len(x1[:max_len]) + max(len(x2[:max_len]), 243 | len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)] 244 | ) + 3, 245 | n_ctx) 246 | vocab = n_vocab + n_special + n_ctx 247 | trX, trM = transform_roc(trX1, trX2, trX3) 248 | vaX, vaM = transform_roc(vaX1, vaX2, vaX3) 249 | if submit: 250 | teX, teM = transform_roc(teX1, teX2, teX3) 251 | 252 | n_train = len(trY) 253 | n_valid = len(vaY) 254 | n_batch_train = n_batch*n_gpu 255 | n_updates_total = (n_train//n_batch_train)*n_iter 256 | 257 | model = Model(vocab, n_ctx, args) 258 | lm_head = LMHead(model, args) 259 | clf_head = ClfHead(clf_token, args) 260 | 261 | criterion = nn.CrossEntropyLoss(reduce=False) # TODO check loss functions 262 | model_opt = OpenAIAdam(model.parameters(), lr=lr, schedule=lr_schedule, 263 | warmup=lr_warmup, t_total=n_updates_total, b1=b1, 264 | b2=b2, e=e, l2=l2, vector_l2=vector_l2, 265 | max_grad_norm=max_grad_norm) 266 | compute_loss_fct = LossCompute(criterion, criterion, lm_coef, model_opt) 267 | load_openai_pretrained_model(model, n_ctx=n_ctx, n_special=n_special) 268 | 269 | model.to(device) 270 | lm_head.to(device) 271 | clf_head.to(device) 272 | 273 | n_updates = 0 274 | n_epochs = 0 275 | if dataset != 'stsb': 276 | trYt = trY 277 | if submit: 278 | path = os.path.join(save_dir, desc, 'best_params') 279 | torch.save(model.state_dict(), make_path(path)) 280 | best_score = 0 281 | for i in range(n_iter): 282 | print("running epoch", i) 283 | run_epoch() 284 | n_epochs += 1 285 | log() 286 | if submit: 287 | path = os.path.join(save_dir, desc, 'best_params') 288 | model.load_state_dict(torch.load(path)) 289 | predict() 290 | if analysis: 291 | rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), 292 | os.path.join(log_dir, 'rocstories.jsonl')) 293 | --------------------------------------------------------------------------------