├── models
│   ├── finetune_large.sh
│   ├── README.md
│   ├── gpt2.py
│   └── run_language_modeling.py
├── get_interpolated_output.py
├── README.md
├── requirements.txt
└── get_matrices.py
/models/finetune_large.sh:
--------------------------------------------------------------------------------
 1 | export TRAIN_FILE=./path/to/finetune-lm-train
 2 | export TEST_FILE=./path/to/finetune-lm-eval
 3 | 
 4 | python run_language_modeling.py \
 5 |     --output_dir=output_large \
 6 |     --model_type=gpt2-large \
 7 |     --model_name_or_path=gpt2-large \
 8 |     --do_train \
 9 |     --train_data_file=$TRAIN_FILE \
10 |     --do_eval \
11 |     --eval_data_file=$TEST_FILE
12 | 
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
1 | # Model Directory
2 | 
3 | We use GPT2-large as provided by [HF transformers](https://github.com/huggingface/transformers). This directory contains the code used to interface with GPT-2 and obtain word probabilities in a sequence; it is imported as a utility by the rest of the project.
4 | 
5 | ## Fine tuning GPT2
6 | 
7 | run\_language\_modeling.py is a standard fine-tuning Python script.
8 | You only need to set the train/test file paths in finetune\_large.sh and execute it. The files we used are [here](https://drive.google.com/drive/folders/1XrlzvJqmvcK0IpYK-VwIN5tk2y6iIILi?usp=sharing); adapt them to your domain.
9 | 
10 | ```
11 | sh finetune_large.sh
12 | ```
13 | 
14 | ## Using the model
15 | 
16 | gpt2.py is a utility used by other parts of our project. As a standalone module, it loads the checkpoint saved by the fine-tuning script and provides a function to obtain the probability of each word in a paragraph or sentence of text.
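As a rough sketch of that standalone usage (mirroring the demo at the bottom of gpt2.py; the checkpoint path is a placeholder):

```
from gpt2 import GPT2

# Minimal sketch: load a fine-tuned checkpoint (placeholder path) and score two texts.
# The returned payload contains 'context_strings' (the tokens, including the start token)
# and 'real_probs' (the probability the model assigned to each actual token).
model = GPT2(device="cpu", location="./output_large/checkpoint-XXXX")
payload = model.get_probabilities(["the following is", "what about this"], topk=40)
print(payload["context_strings"][0])
print(payload["real_probs"][0])
```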
17 | -------------------------------------------------------------------------------- /get_interpolated_output.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle 3 | import spacy 4 | import numpy as np 5 | nlp = spacy.load("en_core_web_sm") 6 | 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Process some integers.') 10 | parser.add_argument("--classifier") 11 | parser.add_argument("--output_file") 12 | parser.add_argument("--dataset_file") 13 | parser.add_argument("--matrix_file_pattern") 14 | 15 | args = parser.parse_args() 16 | 17 | clf = pickle.load(open(args.classifier, "rb")) 18 | 19 | data = pickle.load(open(args.dataset_file, "rb")) 20 | 21 | for file_idx in range(5): 22 | count = 0 23 | output = pickle.load(open(args.matrix_file_pattern+str(file_idx)+".pkl", "rb")) 24 | for key in output.keys(): 25 | #if key not in ref.keys(): 26 | # continue 27 | count+=1 28 | article, abstract = data[key] 29 | print("--------------------") 30 | print(key) 31 | selected = [] 32 | doc = nlp(article) 33 | sentences = [sentence.text.strip() for sentence in doc.sents] 34 | 35 | matrix = output[key]['vanilla'] 36 | matrix[matrix<0] = 0 37 | relevance = [] 38 | surprise = output[key]['surprise'] 39 | for idx in range(len(sentences)): 40 | relevance.append(sum(matrix[idx])) 41 | 42 | penalty = [0 for i in range(len(sentences))] 43 | #print(surprise) 44 | #print(matrix) 45 | try: 46 | for j in range(1, 9): 47 | selected = [] 48 | summary = "" 49 | for k in range(j): 50 | maxIdx = -1 51 | maxVal = -float('inf') 52 | #print(maxIdx, selected) 53 | for i in range(len(sentences)): 54 | temp = np.dot(clf.coef_[0], [penalty[i], relevance[i]]) 55 | if temp > maxVal and i not in selected: 56 | maxIdx = i 57 | maxVal = temp 58 | 59 | #print(maxVal, maxIdx) 60 | for i in range(len(sentences)): 61 | penalty[i]+=matrix[i][maxIdx] 62 | 63 | selected.append(maxIdx) 64 | #print(selected) 65 | summary = "" 66 | for i in sorted(selected): 67 | summary+= sentences[i]+" " 68 | 69 | summary = ' '.join(summary.split()) 70 | 71 | with open(args.output_file+str(j), "a") as f: 72 | f.write(summary+'\n') 73 | 74 | with open("./path/to/output/gold", "a") as f: 75 | f.write(' '.join(abstract.split())+'\n') 76 | except: 77 | print("Missed ", key) 78 | 79 | #print("SUMMARY ", summary) 80 | 81 | #if count == 10: 82 | # exit(0) 83 | 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Extractive Summarization using Mutual Information 2 | 3 | ## Steps to perform summarization 4 | 5 | 1. Fine tune the required language model 6 | 2. Create sentence-wise mutual information matrices for documents in the dataset 7 | 3. Use the created matrices to generate extractive summaries 8 | 4. Evaluation 9 | 10 | ### 1. Fine tune the required language model 11 | 12 | We use GPT2-large as provided by [HF transformers](https://github.com/huggingface/transformers) with a standard fine tuning script as provided in the model directory. 13 | 14 | ### 2. Creating mutual information matrices 15 | 16 | For each document in the dataset, if there are n sentences, we need to create the nxn pairwise matrix as described in the paper.
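The script expects two pickle files whose formats are described under the parameters below. A minimal sketch of preparing them (the contents and paths here are placeholders):

```
import pickle

# Minimal sketch (placeholder contents): the dataset pickle is a list of
# (article, abstract) string tuples, and the index pickle is a list of
# (start, end) tuples over the dataset -- a single tuple covering the whole
# dataset is the simplest option.
data = [
    ("First document text ...", "Its reference summary ..."),
    ("Second document text ...", "Its reference summary ..."),
]
pickle.dump(data, open("./path/to/dataset.pkl", "wb"))
pickle.dump([(0, len(data))], open("./path/to/indices.pkl", "wb"))
```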
17 | 
18 | ```
19 | python get_matrices.py --index 0 --input_data ./path/to/dataset.pkl --index_file ./path/to/indices.pkl --output_format ./path/to/output/sent_train_set_
20 | ```
21 | 
22 | The parameters indicate the following:
23 | 1. input\_data - Pickle file containing the document-summary pairs. Check that the format matches the following input scheme:
24 | ```
25 | data = pickle.load(open(args.input_data, "rb"))
26 | for i in range(len(data)):
27 |     article, abstract = data[i]
28 | ```
29 | 2. index\_file - Since the process takes a while, we split it into parts by means of a pickled index file. The index file is a list of (start index, end index) tuples over the indices of the dataset. So if the dataset has length 100, the index file might be [(0, 30), (30, 60), (60, 90), (90, 100)], and each execution of get\_matrices with a particular index creates the matrices for the datapoints covered by that tuple. The simplest option is a list with just one tuple, [(0, length of dataset)], as in the sketch above.
30 | 3. index - Index within the index file to execute. The index file is a list, so if index is 0, the matrices are created for the documents covered by the first tuple in the index file. This lets you parallelize the process by running several executions at the same time.
31 | 4. output\_format - File format/location where the output is stored. The output is a pickle file at the output format location with the index appended at the end.
32 | 
33 | #### Output Format
34 | For the given execution line, the output would look like:
35 | ```
36 | ./path/to/output/sent_train_set_0.pkl
37 | ```
38 | Where the location/directory is given by the output\_format parameter and the 0 by the index executed by the current script (in practice you would have 0, 1, 2, ... all running in parallel).
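Each per-document entry in this pickle holds three objects — the 'vanilla' PMI matrix, the 'normalised' PMI matrix and the 'surprise' list — described further below. As a rough sketch of the quantities involved (notation only, mirroring the computation in get\_matrices.py, with log p(.) denoting log-probabilities under the fine-tuned GPT-2 and s_i the i-th sentence of the document):

```
vanilla[i][j]    = log p(s_j | s_i) - log p(s_j)                      # sentence-wise PMI of sentences i and j
normalised[i][j] = vanilla[i][j] / -( log p(s_j | s_i) + log p(s_i) )
surprise[i]      = log p(s_i)                                         # sentence log-probability
```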
39 | 
40 | The pickle file is indexed as
41 | ```
42 | data = pickle.load(open("./path/to/output/sent_train_set_0.pkl", "rb"))
43 | doc_number = 0
44 | print(data[doc_number]['vanilla'])
45 | ```
46 | Within the pickle file, the primary index is the index of the document within the dataset pickle file. We create the 'vanilla' matrix, which is an nxn sentence-wise PMI matrix as described in the paper. We additionally create the 'normalised' PMI matrix and the (unused) 'surprise' list, which holds the sentence-wise surprisal of each sentence.
47 | 
48 | ### 3. Generate Summaries
49 | From the generated PMI matrices, we generate summaries using the algorithm detailed in the paper.
50 | ```
51 | python get_interpolated_output.py --output_file ./path/to/output/interpolated_ --dataset_file ./path/to/dataset.pkl --matrix_file_pattern ./path/to/output/sent_train_set_
52 | ```
53 | The script consumes the original pickled dataset file and the matrix files generated in the previous step. It iterates over all the output files created (which is why the argument is accepted as a pattern) and generates extractive summaries from length 1 to length 9 in 9 separate files using the format specified here. So interpolated\_1 is a text file where each line is a one-sentence summary of the corresponding document in the dataset, and there are 9 such files. We also save all the gold summaries in a file of the same format. This format is used because it is compatible with the Rouge package in the next step. Change the interpolation coefficients here if needed: to interpolate between relevance and redundancy, in the paper we used a classifier to learn the weight assigned to each. A simpler alternative for just running inference is to set the weights to +1 for relevance and -1 for redundancy [here](https://github.com/vishakhpk/mi-unsup-summ/blob/196d3b646460f03cfcf9e41e1db621868a7156d0/get_interpolated_output.py#L54) and comment out the [line](https://github.com/vishakhpk/mi-unsup-summ/blob/196d3b646460f03cfcf9e41e1db621868a7156d0/get_interpolated_output.py#L17) loading the classifier.
54 | 
55 | ### 4. Evaluation
56 | Standard Rouge evaluation using the [rouge scorer](https://github.com/google-research/google-research/tree/master/rouge)
57 | ```
58 | python3 -m rouge_score.rouge --target_filepattern=./path/to/output/gold --prediction_filepattern=./path/to/output/interpolated_3 --output_filename=./path/to/output/results.csv --use_stemmer=true
59 | ```
60 | 
61 | ## Data for the paper
62 | The following contains the preprocessed datasets, created PMI matrices, generated summaries and Rouge score reports used in our paper:
63 | 64 | [Google Drive Link](https://drive.google.com/drive/folders/1dBPd7trOOdKTNFDtUSGH9Z3zZ2PucDmL?usp=sharing)
65 | 66 | Please reach out if you would like to use our saved language models (vishakh@nyu.edu). We use GPT2 large, fine tuned on the document sentences from the various domains, using the script in the models directory. The input files for the fine tuning script are available in folder [LM-Files](https://drive.google.com/drive/folders/1XrlzvJqmvcK0IpYK-VwIN5tk2y6iIILi?usp=sharing) at the above drive location. 67 | 68 | ## TODOs 69 | 1. Convert scripts to a reusable utility 70 | 2. Maybe remove indices files to make things easier 71 | -------------------------------------------------------------------------------- /models/gpt2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import code 4 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 5 | 6 | class GPT2: 7 | """ 8 | Citation: https://github.com/HendrikStrobelt/detecting-fake-text/blob/master/backend/api.py 9 | Model class for GPT-2. Primarily used to obtain word probabilities 10 | """ 11 | def __init__(self, device = "cpu", location = ""): 12 | if location == "": 13 | self.enc = GPT2Tokenizer.from_pretrained("gpt2-large") 14 | self.model = GPT2LMHeadModel.from_pretrained("gpt2-large") 15 | else: 16 | self.enc = GPT2Tokenizer.from_pretrained(location) 17 | self.model = GPT2LMHeadModel.from_pretrained(location) 18 | self.device = torch.device(device) 19 | self.model.eval() 20 | self.start_tok = "<|endoftext|>" 21 | #SPECIAL_TOKENS = [""] 22 | #self.enc.add_special_tokens(SPECIAL_TOKENS) 23 | #self.model.set_num_special_tokens(len(SPECIAL_TOKENS)) 24 | self.model.to(self.device) 25 | 26 | def pad(self, context): 27 | max_len = max([len(sentence) for sentence in context]) 28 | #print("Maximum Length: ", max_len) 29 | for i in range(len(context)): 30 | #print(len(context[i]), max_len - len(context[i])) 31 | for j in range(max_len - len(context[i])): 32 | context[i].append(context[i][0]) 33 | #print(max_len - len(sentences[i].split())) 34 | #print("i: ", sentences[i]) 35 | #print([[self.enc.encode("") for idx in range(max_len - len(in_text[i].split()))] for i in range(len(in_text))]) 36 | #print(sentences) 37 | #print([len(context[i]) for i in range(len(context))]) 38 | return context 39 | 40 | def get_probabilities(self, in_text, topk = 40): 41 | """ 42 | Take in a sequence of text tokens, make predictions on each word given past context and 43 | return topk 44 | 45 | Returns: 46 | Dictionary "payload" containing: 47 | real_probs 48 | - List of tuples, one for each token in sequence 49 | - Probability of the actual words in the sequence 50 | - Each tuple of the form (position of next word in prediction, predicted probability) 51 | 52 | context_strings: 53 | - Strings in the sequence along with start token 54 | """ 55 | with torch.no_grad(): 56 | start_tok = torch.full((1, 1), self.enc.encoder[self.start_tok], 57 | device=self.device, dtype=torch.long) 58 | context = [self.start_tok+" "+in_text[i] for i in range(len(in_text))] 59 | context = [self.enc.encode(context[i]) for i in range(len(context))] 60 | context = self.pad(context) 61 | context = torch.tensor(context, device=self.device, dtype=torch.long) 62 | logits, _ = self.model(context) 63 | yhat = torch.softmax(logits[:, :-1], dim=-1) 64 | y = context[:, 1:] 65 | real_topk_probs = [yhat[t][np.arange(0, y[t].shape[0], 1), y[t]].data.cpu().numpy().tolist() for t in range(yhat.shape[0])] 66 | real_topk_probs = [list(map(lambda x: round(x, 15), real_topk_probs[t])) for t in range(len(real_topk_probs))] 
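            # The two list comprehensions above gather, for each sequence t, the probability
            # that the model assigned to each actual next token y[t][i] from the softmax
            # output yhat[t][i], rounded to 15 decimal places.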
67 | 68 | real_topk = [list(real_topk_probs[t]) for t in range(len(real_topk_probs))] 69 | 70 | context_strings = [[self.enc.decoder[s.item()] for s in context[t]] for t in range(len(context))] 71 | context_strings = [[self.postprocess(s) for s in context_strings[t]] for t in range(len(context_strings))] 72 | del context, logits, y, yhat, 73 | torch.cuda.empty_cache() 74 | """ 75 | pred_topk = [[list(zip([self.enc.decoder[p] for p in sorted_preds[t][i][:topk]], 76 | list(map(lambda x: round(x, 5),yhat[t][i][sorted_preds[t][i][ 77 | :topk]].data.cpu().numpy().tolist())))) 78 | for i in range(y[t].shape[0])] for t in range(y.shape[0])] 79 | pred_topk = [[[(self.postprocess(t[0]), t[1]) for t in pred] for pred in pred_topk[t]] for t in range(len(pred_topk))] 80 | """ 81 | payload = {'context_strings': context_strings, 82 | 'real_probs': real_topk}#, 'pred_topk': pred_topk} 83 | 84 | #del context, logits, y, yhat, 85 | #torch.cuda.empty_cache() 86 | #code.interact(local=locals()) 87 | return payload 88 | 89 | def postprocess(self, token): 90 | with_space = False 91 | with_break = False 92 | #print(token, token[0], token[1:]), 93 | if token[0] == 'Ġ': 94 | with_space = True 95 | token = token[1:] 96 | elif token.startswith('â'): 97 | token = ' ' 98 | elif token.startswith('Ċ'): 99 | token = ' ' 100 | with_break = True 101 | 102 | if len(token)>0 and token[0] == "Â": 103 | token = token[1:] 104 | token = '-' if token.startswith('â') else token 105 | token = '“' if token.startswith('ľ') else token 106 | token = '”' if token.startswith('Ŀ') else token 107 | token = "'" if token.startswith('Ļ') else token 108 | #if with_space: 109 | # token = '\u0120' + token 110 | #if with_break: 111 | # token = '\u010A' + token 112 | #print(token) 113 | return token 114 | 115 | if __name__=="__main__": 116 | test = GPT2(location="../../Summarization/models/output_large/checkpoint-70000") 117 | payload = test.get_probabilities(["the following is", "what about this"], topk = 40) 118 | print(payload['real_probs']) 119 | #for t in range(len(payload["pred_topk"])): 120 | # for i in range(len(payload["pred_topk"][t])): 121 | # print(payload["real_probs"][t][i]) 122 | #code.interact(local=locals()) 123 | test = GPT2() 124 | payload = test.get_probabilities(["the following is", "what about this"], topk = 40) 125 | print("And without fine tuning: ") 126 | print(payload['real_probs']) 127 | #for t in range(len(payload["pred_topk"])): 128 | # for i in range(len(payload["pred_topk"][t])): 129 | # print(payload["real_probs"][t][i]) 130 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.1 2 | alabaster==0.7.11 3 | alembic==1.0.5 4 | anaconda-client==1.7.2 5 | anaconda-navigator==1.9.2 6 | anaconda-project==0.8.2 7 | appdirs==1.4.3 8 | asn1crypto==0.24.0 9 | astor==0.8.0 10 | astroid==2.0.4 11 | astropy==3.0.4 12 | async-generator==1.10 13 | atomicwrites==1.2.1 14 | attrs==18.2.0 15 | Automat==0.7.0 16 | Babel==2.6.0 17 | backcall==0.1.0 18 | backports.shutil-get-terminal-size==1.0.0 19 | bash-kernel==0.7.1 20 | batchspawner==0.8.1 21 | bcolz==1.2.1 22 | beautifulsoup4==4.6.3 23 | bitarray==0.8.3 24 | bkcharts==0.2 25 | blaze==0.11.3 26 | bleach==3.0.2 27 | blis==0.4.1 28 | bokeh==0.13.0 29 | boto==2.49.0 30 | boto3==1.10.0 31 | botocore==1.13.0 32 | Bottleneck==1.2.1 33 | catalogue==0.0.8 34 | certifi==2018.8.24 35 | cffi==1.11.5 36 | chardet==3.0.4 37 | click==6.7 38 | 
cloudpickle==0.5.5 39 | clyent==1.2.2 40 | colorama==0.3.9 41 | conda==4.5.11 42 | conda-build==3.15.1 43 | constantly==15.1.0 44 | contextlib2==0.5.5 45 | coverage==3.7.1 46 | coveralls==0.5 47 | cryptography==2.3.1 48 | cycler==0.10.0 49 | cymem==2.0.2 50 | Cython==0.28.5 51 | cytoolz==0.9.0.1 52 | dask==0.19.1 53 | datascience==0.10.6 54 | datashape==0.5.4 55 | decorator==4.3.0 56 | defusedxml==0.5.0 57 | dill==0.3.2 58 | distributed==1.23.1 59 | docopt==0.6.2 60 | docutils==0.14 61 | en-core-web-sm==2.2.5 62 | entrypoints==0.2.3 63 | et-xmlfile==1.0.1 64 | fairseq==0.9.0 65 | fastai==1.0.57 66 | fastBPE==0.1.0 67 | fastcache==1.0.2 68 | fastprogress==0.1.21 69 | fasttext==0.9.1 70 | filelock==3.0.8 71 | Flask==1.0.2 72 | Flask-Cors==3.0.6 73 | folium==0.2.1 74 | future==0.18.2 75 | gast==0.2.2 76 | gensim==3.8.2 77 | gevent==1.3.6 78 | glob2==0.6 79 | gmpy2==2.0.8 80 | google-pasta==0.1.7 81 | googleapis-common-protos==1.52.0 82 | greenlet==0.4.15 83 | grpcio==1.24.1 84 | h5py==2.8.0 85 | heapdict==1.0.0 86 | html5lib==1.0.1 87 | hyperlink==18.0.0 88 | idna==2.7 89 | imageio==2.4.1 90 | imagesize==1.1.0 91 | importlib-metadata==1.3.0 92 | incremental==17.5.0 93 | ipykernel==5.1.0 94 | ipython==6.2.1 95 | ipython-genutils==0.2.0 96 | ipywidgets==7.4.2 97 | isort==4.3.4 98 | itsdangerous==0.24 99 | jdcal==1.4 100 | jedi==0.13.1 101 | jeepney==0.3.1 102 | Jinja2==2.10 103 | jmespath==0.9.4 104 | joblib==0.14.0 105 | jsonschema==2.6.0 106 | jupyter==1.0.0 107 | jupyter-client==5.2.4 108 | jupyter-console==5.2.0 109 | jupyter-core==4.4.0 110 | jupyterhub==0.9.4 111 | jupyterhub-ldapauthenticator==1.2.2 112 | jupyterlab==0.34.9 113 | jupyterlab-launcher==0.13.1 114 | Keras==2.3.1 115 | Keras-Applications==1.0.8 116 | Keras-Preprocessing==1.1.0 117 | keyring==13.2.1 118 | kiwisolver==1.0.1 119 | lazy-object-proxy==1.3.1 120 | ldap3==2.5.1 121 | llvmlite==0.24.0 122 | locket==0.2.0 123 | lxml==4.2.5 124 | Mako==1.0.7 125 | Markdown==3.1.1 126 | MarkupSafe==1.1.0 127 | matlab-kernel==0.16.1 128 | matplotlib==2.2.3 129 | mccabe==0.6.1 130 | mecab-python3==0.996.5 131 | metakernel==0.20.14 132 | mistune==0.8.4 133 | mkl-fft==1.0.4 134 | mkl-random==1.0.1 135 | more-itertools==4.3.0 136 | mpmath==1.0.0 137 | msgpack==0.5.6 138 | multipledispatch==0.6.0 139 | murmurhash==1.0.2 140 | navigator-updater==0.2.1 141 | nbconvert==5.4.0 142 | nbformat==4.4.0 143 | nbgitpuller==0.6.1 144 | networkx==2.1 145 | nltk==3.3 146 | nmslib==1.8.1 147 | nose==1.3.7 148 | notebook==5.7.2 149 | numba==0.39.0 150 | numexpr==2.6.8 151 | numpy==1.16.2 152 | numpydoc==0.8.0 153 | nvidia-ml-py3==7.352.0 154 | odo==0.5.1 155 | olefile==0.46 156 | openpyxl==2.5.6 157 | opt-einsum==3.1.0 158 | packaging==17.1 159 | pamela==0.3.0 160 | pandas==0.23.4 161 | pandocfilters==1.4.2 162 | parso==0.3.1 163 | partd==0.3.8 164 | path.py==11.1.0 165 | pathlib2==2.3.2 166 | patsy==0.5.0 167 | pep8==1.7.1 168 | pexpect==4.6.0 169 | pickleshare==0.7.5 170 | Pillow==5.2.0 171 | pkginfo==1.4.2 172 | plac==0.9.6 173 | pluggy==0.7.1 174 | ply==3.11 175 | portalocker==1.7.0 176 | preshed==3.0.2 177 | prometheus-client==0.5.0 178 | promise==2.3 179 | prompt-toolkit==1.0.15 180 | protobuf==3.10.0 181 | psutil==5.4.7 182 | ptyprocess==0.6.0 183 | py==1.6.0 184 | pyasn1==0.4.4 185 | pyasn1-modules==0.2.2 186 | pybind11==2.4.1 187 | pycodestyle==2.4.0 188 | pycosat==0.6.3 189 | pycparser==2.18 190 | pycrypto==2.6.1 191 | pycurl==7.43.0.2 192 | pyflakes==2.0.0 193 | Pygments==2.3.0 194 | pylint==2.1.1 195 | pymongo==3.9.0 196 | pyodbc==4.0.24 197 | 
pyOpenSSL==18.0.0 198 | pyparsing==2.2.0 199 | pyrouge==0.1.3 200 | pyrsistent==0.15.7 201 | PySocks==1.6.8 202 | pytest==3.8.0 203 | pytest-arraydiff==0.2 204 | pytest-astropy==0.4.0 205 | pytest-doctestplus==0.1.3 206 | pytest-openfiles==0.3.0 207 | pytest-remotedata==0.3.0 208 | python-dateutil==2.7.5 209 | python-editor==1.0.3 210 | python-oauth2==1.1.0 211 | pytorch-transformers==1.2.0 212 | pytz==2018.5 213 | PyWavelets==1.0.0 214 | PyYAML==3.13 215 | pyzmq==17.1.2 216 | QtAwesome==0.4.4 217 | qtconsole==4.4.1 218 | QtPy==1.5.0 219 | regex==2019.8.19 220 | requests==2.19.1 221 | rope==0.11.0 222 | rouge-score==0.0.3 223 | ruamel-yaml==0.15.46 224 | s3transfer==0.2.1 225 | sacrebleu==1.4.6 226 | sacremoses==0.0.35 227 | scikit-image==0.14.0 228 | scikit-learn==0.19.2 229 | scipy==1.1.0 230 | seaborn==0.9.0 231 | SecretStorage==3.1.0 232 | Send2Trash==1.5.0 233 | sentencepiece==0.1.83 234 | seqeval==0.0.12 235 | service-identity==17.0.0 236 | simplegeneric==0.8.1 237 | simpletransformers==0.22.1 238 | singledispatch==3.4.0.3 239 | six==1.12.0 240 | smappdragon==0.0.47 241 | smart-open==2.0.0 242 | snowballstemmer==1.2.1 243 | sortedcollections==1.0.1 244 | sortedcontainers==2.0.5 245 | spacy==2.2.3 246 | Sphinx==1.7.9 247 | sphinxcontrib-websupport==1.1.0 248 | spyder==3.3.1 249 | spyder-kernels==0.2.6 250 | SQLAlchemy==1.2.11 251 | srsly==0.1.0 252 | statsmodels==0.9.0 253 | sympy==1.2 254 | tables==3.4.4 255 | tblib==1.3.2 256 | tensorboard==2.0.0 257 | tensorboardX==2.0 258 | tensorflow==2.0.0 259 | tensorflow-datasets==3.2.1 260 | tensorflow-estimator==2.0.1 261 | tensorflow-metadata==0.22.2 262 | termcolor==1.1.0 263 | terminado==0.8.1 264 | testpath==0.3.1 265 | thinc==7.3.1 266 | tokenizers==0.5.2 267 | toolz==0.9.0 268 | torch==1.4.0 269 | torchvision==0.4.0 270 | tornado==5.1.1 271 | tqdm==4.43.0 272 | traitlets==4.3.2 273 | transformers==2.6.0 274 | Twisted==18.7.0 275 | typing==3.7.4.1 276 | unicodecsv==0.14.1 277 | uritools==3.0.0 278 | urlextract==0.14.0 279 | urllib3==1.23 280 | virtualenv==16.1.0 281 | wasabi==0.4.2 282 | wcwidth==0.1.7 283 | webencodings==0.5.1 284 | Werkzeug==0.14.1 285 | widgetsnbextension==3.4.2 286 | wrapt==1.11.2 287 | wurlitzer==1.0.2 288 | xlrd==1.1.0 289 | XlsxWriter==1.1.0 290 | xlwt==1.3.0 291 | zict==0.1.3 292 | zipp==0.6.0 293 | zope.interface==4.5.0 294 | -------------------------------------------------------------------------------- /get_matrices.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, "./models/") 4 | sys.path.insert(0, "./read-data/") 5 | from read import get_text 6 | from gpt2 import GPT2 7 | from tqdm import tqdm 8 | import preprocess_subsequence 9 | import nltk 10 | import math 11 | from nltk import tokenize 12 | import numpy as np 13 | #from sklearn.metrics import precision_score, recall_score 14 | import pickle 15 | import spacy 16 | import argparse 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--index") 19 | parser.add_argument("--input_data") 20 | parser.add_argument("--index_file") 21 | parser.add_argument("--output_format") 22 | args = parser.parse_args() 23 | print(args.index) 24 | print(args.index_file) 25 | print(args.input_data) 26 | print(args.output_format) 27 | 28 | data = pickle.load(open(args.input_data, "rb"))# 29 | print("Length ", len(data)) 30 | indices = pickle.load(open(args.index_file, "rb")) 31 | print(indices) 32 | lower, upper = indices[int(args.index)] 33 | print(lower, upper) 34 | 35 | nlp = 
spacy.load("en_core_web_sm") 36 | 37 | model = GPT2(device="cuda", location="./path/to/saved/model/") 38 | 39 | if os.path.exists(args.output_format+args.index+".pkl"): 40 | output = pickle.load(open(args.output_format+args.index+".pkl","rb")) 41 | else: 42 | output = {} 43 | 44 | def get_probabilities(articles): 45 | """ 46 | Given a batch of articles (can be any strings) run a forward pass on GPT2 and obtain word probabilities for the same 47 | """ 48 | article_splits = [article.split(" ") for article in articles] 49 | payload = model.get_probabilities(articles, topk = 20) 50 | res = [[] for i in range(len(articles))] 51 | for t, article in enumerate(articles): 52 | context = "" 53 | idx = 0 54 | chain = False 55 | next_word = "" 56 | article_words = article_splits[t] 57 | #print(article, article_words) 58 | word_probability = 1.0 59 | gt_count = 0 60 | idx+=1 61 | found_words = [] 62 | for i, word in enumerate(payload["context_strings"][t][:-1]): 63 | context = context+" "+word 64 | probability = payload['real_probs'][t][i]#[1] 65 | next_word_fragment = payload["context_strings"][t][i+1] 66 | 67 | next_word += next_word_fragment 68 | #print(next_word, article_words[gt_count]) 69 | if next_word == article_words[gt_count]: 70 | chain = False 71 | gt_count+=1 72 | else: 73 | chain = True 74 | 75 | word_probability *= probability 76 | assert word_probability <= 1.0, print(word_probability, context) 77 | if chain == False: 78 | #print("Word Probability: ", word_probability, next_word) 79 | res[t].append(word_probability) 80 | word_probability = 1.0 81 | next_word = "" 82 | #print(gt_count, len(article_words)) 83 | if gt_count == len(article_words): 84 | break 85 | return res 86 | 87 | 88 | def get_npmi_matrix(sentences, method = 1, batch_size = 1): 89 | """ 90 | Accepts a list of sentences of length n and returns 3 objects: 91 | - Normalised PMI nxn matrix - temp 92 | - PMI nxn matrix - temp2 93 | - List of length n indicating sentence-wise surprisal i.e. 
p(sentence) - p 94 | 95 | To optimize performance, we do the forward pass batchwise by assembling the batch and maintaining batch indices 96 | For each batch we call get_probabilities 97 | """ 98 | temp = np.zeros((len(sentences), len(sentences))) 99 | temp2 = np.zeros((len(sentences), len(sentences))) 100 | batch_indices = {} 101 | batch = [] 102 | batchCount = 0 103 | batchSize = batch_size 104 | #print(len(sentences)) 105 | c = 0 106 | p = [] 107 | for i in range(len(sentences)): 108 | result = get_probabilities([sentences[i]]) 109 | try: 110 | p.append(sum([math.log(i) for i in result[0]])) 111 | except: 112 | print("Math domain error surprise", i) 113 | return temp, temp2, p 114 | for i in range(len(sentences)): 115 | for j in range(len(sentences)): 116 | if i==j: 117 | temp[i][j] = -1 118 | temp2[i][j] = -1 119 | continue 120 | article = sentences[i] + " "+ sentences[j] 121 | #print(article) 122 | batch_indices[str(i)+"-"+str(j)+"-"+str(len(sentences[i].split()))] = batchCount 123 | batch.append(article) 124 | batchCount+=1 125 | 126 | if batchCount == batchSize or (i == len(sentences)-1 and j == len(sentences)-1): 127 | #print(batch) 128 | c+=1 129 | result = get_probabilities(batch) 130 | for key in batch_indices.keys(): 131 | #print(key) 132 | #print(key.split("-")) 133 | idx_i, idx_j, idx_l = [int(idx) for idx in key.split("-")] 134 | try: 135 | pxy = sum([math.log(q) for q in result[batch_indices[key]][idx_l:]]) 136 | py = p[idx_j] 137 | px = p[idx_i] 138 | 139 | temp[idx_i][idx_j] = (pxy - py)/(-1*(pxy+px)) 140 | temp2[idx_i][idx_j] = (pxy - py) 141 | except ZeroDivisionError: 142 | print("Zero division error ", idx_i, idx_j) 143 | temp[idx_i][idx_j] = -1 144 | temp2[idx_i][idx_j] = -1 145 | except: 146 | print("Math Domain Error", i, j) 147 | if temp[idx_i][idx_j] > 1 or temp[idx_i][idx_j] < -1: 148 | print("Normalise assert ", temp[idx_i][idx_j], idx_i, idx_j) 149 | batchCount = 0 150 | batch = [] 151 | batch_indices = {} 152 | return temp, temp2, p 153 | 154 | def remove_unicode(text): 155 | return ''.join([i if ord(i) < 128 else ' ' for i in text]) 156 | 157 | def get_article(idx): 158 | """ 159 | For each document in the dataset, split it into sentences and call get_npmi_matrix to create the matrices 160 | """ 161 | print(idx) 162 | article, abstract = data[idx] 163 | #sentences = tokenize.sent_tokenize(article) 164 | doc = nlp(article) 165 | sentences = [remove_unicode(sentence.text) for sentence in doc.sents] 166 | normalised, vanilla, surprise = get_npmi_matrix(sentences, batch_size = 10) 167 | #avg = get_pmi_matrix(sentences, method = 1) 168 | output[idx] = {} 169 | output[idx]["vanilla"] = vanilla 170 | output[idx]["normalised"] = normalised 171 | output[idx]["surprise"] = surprise 172 | #output[idx]["averaging"] = avg 173 | #pickle.dump(output, open("full_set_1.pkl", "wb")) 174 | return 175 | 176 | """ 177 | Main iteration loop, creates matrices for each document in the dataset 178 | """ 179 | c = 0 180 | for idx in range(len(data)): 181 | if idx>=lower and idx 0 and not line.isspace())] 117 | 118 | self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"] 119 | 120 | def __len__(self): 121 | return len(self.examples) 122 | 123 | def __getitem__(self, i): 124 | return torch.tensor(self.examples[i], dtype=torch.long) 125 | 126 | 127 | def load_and_cache_examples(args, tokenizer, evaluate=False): 128 | file_path = args.eval_data_file if evaluate else args.train_data_file 129 | if args.line_by_line: 130 | return 
LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size) 131 | else: 132 | return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size) 133 | 134 | 135 | def set_seed(args): 136 | random.seed(args.seed) 137 | np.random.seed(args.seed) 138 | torch.manual_seed(args.seed) 139 | if args.n_gpu > 0: 140 | torch.cuda.manual_seed_all(args.seed) 141 | 142 | 143 | def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]: 144 | ordering_and_checkpoint_path = [] 145 | 146 | glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix))) 147 | 148 | for path in glob_checkpoints: 149 | if use_mtime: 150 | ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) 151 | else: 152 | regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path) 153 | if regex_match and regex_match.groups(): 154 | ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) 155 | 156 | checkpoints_sorted = sorted(ordering_and_checkpoint_path) 157 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] 158 | return checkpoints_sorted 159 | 160 | 161 | def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None: 162 | if not args.save_total_limit: 163 | return 164 | if args.save_total_limit <= 0: 165 | return 166 | 167 | # Check if we should delete older checkpoint(s) 168 | checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime) 169 | if len(checkpoints_sorted) <= args.save_total_limit: 170 | return 171 | 172 | number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit) 173 | checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] 174 | for checkpoint in checkpoints_to_be_deleted: 175 | logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) 176 | shutil.rmtree(checkpoint) 177 | 178 | 179 | def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]: 180 | """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ 181 | 182 | if tokenizer.mask_token is None: 183 | raise ValueError( 184 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." 
185 | ) 186 | 187 | labels = inputs.clone() 188 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) 189 | probability_matrix = torch.full(labels.shape, args.mlm_probability) 190 | special_tokens_mask = [ 191 | tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 192 | ] 193 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) 194 | if tokenizer._pad_token is not None: 195 | padding_mask = labels.eq(tokenizer.pad_token_id) 196 | probability_matrix.masked_fill_(padding_mask, value=0.0) 197 | masked_indices = torch.bernoulli(probability_matrix).bool() 198 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 199 | 200 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 201 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 202 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) 203 | 204 | # 10% of the time, we replace masked input tokens with random word 205 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 206 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) 207 | inputs[indices_random] = random_words[indices_random] 208 | 209 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 210 | return inputs, labels 211 | 212 | 213 | def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]: 214 | """ Train the model """ 215 | if args.local_rank in [-1, 0]: 216 | tb_writer = SummaryWriter() 217 | 218 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 219 | 220 | def collate(examples: List[torch.Tensor]): 221 | if tokenizer._pad_token is None: 222 | return pad_sequence(examples, batch_first=True) 223 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) 224 | 225 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 226 | train_dataloader = DataLoader( 227 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate 228 | ) 229 | 230 | if args.max_steps > 0: 231 | t_total = args.max_steps 232 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 233 | else: 234 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 235 | 236 | model = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training 237 | model.resize_token_embeddings(len(tokenizer)) 238 | 239 | # Prepare optimizer and schedule (linear warmup and decay) 240 | no_decay = ["bias", "LayerNorm.weight"] 241 | optimizer_grouped_parameters = [ 242 | { 243 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 244 | "weight_decay": args.weight_decay, 245 | }, 246 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 247 | ] 248 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 249 | scheduler = get_linear_schedule_with_warmup( 250 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 251 | ) 252 | 253 | # Check if saved optimizer or scheduler 
states exist 254 | if ( 255 | args.model_name_or_path 256 | and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) 257 | and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) 258 | ): 259 | # Load in optimizer and scheduler states 260 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 261 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 262 | 263 | if args.fp16: 264 | try: 265 | from apex import amp 266 | except ImportError: 267 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 268 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 269 | 270 | # multi-gpu training (should be after apex fp16 initialization) 271 | if args.n_gpu > 1: 272 | model = torch.nn.DataParallel(model) 273 | 274 | # Distributed training (should be after apex fp16 initialization) 275 | if args.local_rank != -1: 276 | model = torch.nn.parallel.DistributedDataParallel( 277 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True 278 | ) 279 | 280 | # Train! 281 | logger.info("***** Running training *****") 282 | logger.info(" Num examples = %d", len(train_dataset)) 283 | logger.info(" Num Epochs = %d", args.num_train_epochs) 284 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 285 | logger.info( 286 | " Total train batch size (w. parallel, distributed & accumulation) = %d", 287 | args.train_batch_size 288 | * args.gradient_accumulation_steps 289 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), 290 | ) 291 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 292 | logger.info(" Total optimization steps = %d", t_total) 293 | 294 | global_step = 0 295 | epochs_trained = 0 296 | steps_trained_in_current_epoch = 0 297 | # Check if continuing training from a checkpoint 298 | if args.model_name_or_path and os.path.exists(args.model_name_or_path): 299 | try: 300 | # set global_step to gobal_step of last saved checkpoint from model path 301 | checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0] 302 | global_step = int(checkpoint_suffix) 303 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 304 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 305 | 306 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 307 | logger.info(" Continuing training from epoch %d", epochs_trained) 308 | logger.info(" Continuing training from global step %d", global_step) 309 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 310 | except ValueError: 311 | logger.info(" Starting fine-tuning.") 312 | 313 | tr_loss, logging_loss = 0.0, 0.0 314 | 315 | model.zero_grad() 316 | train_iterator = trange( 317 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] 318 | ) 319 | set_seed(args) # Added here for reproducibility 320 | for epoch in train_iterator: 321 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 322 | 323 | if args.local_rank != -1: 324 | train_sampler.set_epoch(epoch) 325 | 326 | for step, batch in enumerate(epoch_iterator): 327 | 328 | # Skip past any already trained steps if resuming training 329 | if 
steps_trained_in_current_epoch > 0: 330 | steps_trained_in_current_epoch -= 1 331 | continue 332 | 333 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) 334 | inputs = inputs.to(args.device) 335 | labels = labels.to(args.device) 336 | model.train() 337 | outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) 338 | loss = outputs[0] # model outputs are always tuple in transformers (see doc) 339 | 340 | if args.n_gpu > 1: 341 | loss = loss.mean() # mean() to average on multi-gpu parallel training 342 | if args.gradient_accumulation_steps > 1: 343 | loss = loss / args.gradient_accumulation_steps 344 | 345 | if args.fp16: 346 | with amp.scale_loss(loss, optimizer) as scaled_loss: 347 | scaled_loss.backward() 348 | else: 349 | loss.backward() 350 | 351 | tr_loss += loss.item() 352 | if (step + 1) % args.gradient_accumulation_steps == 0: 353 | if args.fp16: 354 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 355 | else: 356 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 357 | optimizer.step() 358 | scheduler.step() # Update learning rate schedule 359 | model.zero_grad() 360 | global_step += 1 361 | 362 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 363 | # Log metrics 364 | if ( 365 | args.local_rank == -1 and args.evaluate_during_training 366 | ): # Only evaluate when single GPU otherwise metrics may not average well 367 | results = evaluate(args, model, tokenizer) 368 | for key, value in results.items(): 369 | tb_writer.add_scalar("eval_{}".format(key), value, global_step) 370 | tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) 371 | tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) 372 | logging_loss = tr_loss 373 | 374 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 375 | checkpoint_prefix = "checkpoint" 376 | # Save model checkpoint 377 | output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step)) 378 | os.makedirs(output_dir, exist_ok=True) 379 | model_to_save = ( 380 | model.module if hasattr(model, "module") else model 381 | ) # Take care of distributed/parallel training 382 | model_to_save.save_pretrained(output_dir) 383 | tokenizer.save_pretrained(output_dir) 384 | 385 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 386 | logger.info("Saving model checkpoint to %s", output_dir) 387 | 388 | _rotate_checkpoints(args, checkpoint_prefix) 389 | 390 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 391 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 392 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 393 | 394 | if args.max_steps > 0 and global_step > args.max_steps: 395 | epoch_iterator.close() 396 | break 397 | if args.max_steps > 0 and global_step > args.max_steps: 398 | train_iterator.close() 399 | break 400 | 401 | if args.local_rank in [-1, 0]: 402 | tb_writer.close() 403 | 404 | return global_step, tr_loss / global_step 405 | 406 | 407 | def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict: 408 | # Loop to handle MNLI double evaluation (matched, mis-matched) 409 | eval_output_dir = args.output_dir 410 | 411 | eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True) 412 | 413 | if args.local_rank in [-1, 0]: 414 
| os.makedirs(eval_output_dir, exist_ok=True) 415 | 416 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 417 | # Note that DistributedSampler samples randomly 418 | 419 | def collate(examples: List[torch.Tensor]): 420 | if tokenizer._pad_token is None: 421 | return pad_sequence(examples, batch_first=True) 422 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) 423 | 424 | eval_sampler = SequentialSampler(eval_dataset) 425 | eval_dataloader = DataLoader( 426 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate 427 | ) 428 | 429 | # multi-gpu evaluate 430 | if args.n_gpu > 1: 431 | model = torch.nn.DataParallel(model) 432 | 433 | # Eval! 434 | logger.info("***** Running evaluation {} *****".format(prefix)) 435 | logger.info(" Num examples = %d", len(eval_dataset)) 436 | logger.info(" Batch size = %d", args.eval_batch_size) 437 | eval_loss = 0.0 438 | nb_eval_steps = 0 439 | model.eval() 440 | 441 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 442 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) 443 | inputs = inputs.to(args.device) 444 | labels = labels.to(args.device) 445 | 446 | with torch.no_grad(): 447 | outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) 448 | lm_loss = outputs[0] 449 | eval_loss += lm_loss.mean().item() 450 | nb_eval_steps += 1 451 | 452 | eval_loss = eval_loss / nb_eval_steps 453 | perplexity = torch.exp(torch.tensor(eval_loss)) 454 | 455 | result = {"perplexity": perplexity} 456 | 457 | output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") 458 | with open(output_eval_file, "w") as writer: 459 | logger.info("***** Eval results {} *****".format(prefix)) 460 | for key in sorted(result.keys()): 461 | logger.info(" %s = %s", key, str(result[key])) 462 | writer.write("%s = %s\n" % (key, str(result[key]))) 463 | 464 | return result 465 | 466 | 467 | def main(): 468 | parser = argparse.ArgumentParser() 469 | 470 | # Required parameters 471 | parser.add_argument( 472 | "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)." 473 | ) 474 | parser.add_argument( 475 | "--output_dir", 476 | type=str, 477 | required=True, 478 | help="The output directory where the model predictions and checkpoints will be written.", 479 | ) 480 | parser.add_argument( 481 | "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.", 482 | ) 483 | 484 | # Other parameters 485 | parser.add_argument( 486 | "--eval_data_file", 487 | default=None, 488 | type=str, 489 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).", 490 | ) 491 | parser.add_argument( 492 | "--line_by_line", 493 | action="store_true", 494 | help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", 495 | ) 496 | parser.add_argument( 497 | "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir" 498 | ) 499 | parser.add_argument( 500 | "--model_name_or_path", 501 | default=None, 502 | type=str, 503 | help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.", 504 | ) 505 | 506 | parser.add_argument( 507 | "--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling." 
508 | ) 509 | parser.add_argument( 510 | "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" 511 | ) 512 | 513 | parser.add_argument( 514 | "--config_name", 515 | default=None, 516 | type=str, 517 | help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.", 518 | ) 519 | parser.add_argument( 520 | "--tokenizer_name", 521 | default=None, 522 | type=str, 523 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.", 524 | ) 525 | parser.add_argument( 526 | "--cache_dir", 527 | default=None, 528 | type=str, 529 | help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)", 530 | ) 531 | parser.add_argument( 532 | "--block_size", 533 | default=-1, 534 | type=int, 535 | help="Optional input sequence length after tokenization." 536 | "The training dataset will be truncated in block of this size for training." 537 | "Default to the model max input length for single sentence inputs (take into account special tokens).", 538 | ) 539 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 540 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") 541 | parser.add_argument( 542 | "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step." 543 | ) 544 | 545 | parser.add_argument("--per_gpu_train_batch_size", default=1, type=int, help="Batch size per GPU/CPU for training.") 546 | parser.add_argument( 547 | "--per_gpu_eval_batch_size", default=1, type=int, help="Batch size per GPU/CPU for evaluation." 548 | ) 549 | parser.add_argument( 550 | "--gradient_accumulation_steps", 551 | type=int, 552 | default=1, 553 | help="Number of updates steps to accumulate before performing a backward/update pass.", 554 | ) 555 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 556 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 557 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 558 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 559 | parser.add_argument( 560 | "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform." 561 | ) 562 | parser.add_argument( 563 | "--max_steps", 564 | default=-1, 565 | type=int, 566 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.", 567 | ) 568 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 569 | 570 | parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") 571 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") 572 | parser.add_argument( 573 | "--save_total_limit", 574 | type=int, 575 | default=None, 576 | help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default", 577 | ) 578 | parser.add_argument( 579 | "--eval_all_checkpoints", 580 | action="store_true", 581 | help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number", 582 | ) 583 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 584 | parser.add_argument( 585 | "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory" 586 | ) 587 | parser.add_argument( 588 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" 589 | ) 590 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 591 | 592 | parser.add_argument( 593 | "--fp16", 594 | action="store_true", 595 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 596 | ) 597 | parser.add_argument( 598 | "--fp16_opt_level", 599 | type=str, 600 | default="O1", 601 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 602 | "See details at https://nvidia.github.io/apex/amp.html", 603 | ) 604 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 605 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 606 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 607 | args = parser.parse_args() 608 | 609 | if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm: 610 | raise ValueError( 611 | "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm " 612 | "flag (masked language modeling)." 613 | ) 614 | if args.eval_data_file is None and args.do_eval: 615 | raise ValueError( 616 | "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " 617 | "or remove the --do_eval argument." 618 | ) 619 | if args.should_continue: 620 | sorted_checkpoints = _sorted_checkpoints(args) 621 | if len(sorted_checkpoints) == 0: 622 | raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.") 623 | else: 624 | args.model_name_or_path = sorted_checkpoints[-1] 625 | 626 | if ( 627 | os.path.exists(args.output_dir) 628 | and os.listdir(args.output_dir) 629 | and args.do_train 630 | and not args.overwrite_output_dir 631 | and not args.should_continue 632 | ): 633 | raise ValueError( 634 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 635 | args.output_dir 636 | ) 637 | ) 638 | 639 | # Setup distant debugging if needed 640 | if args.server_ip and args.server_port: 641 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 642 | import ptvsd 643 | 644 | print("Waiting for debugger attach") 645 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 646 | ptvsd.wait_for_attach() 647 | 648 | # Setup CUDA, GPU & distributed training 649 | if args.local_rank == -1 or args.no_cuda: 650 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 651 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 652 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 653 | torch.cuda.set_device(args.local_rank) 654 | device = torch.device("cuda", args.local_rank) 655 | torch.distributed.init_process_group(backend="nccl") 656 | args.n_gpu = 1 657 | args.device = device 658 | 659 | # Setup logging 660 | logging.basicConfig( 661 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 662 | datefmt="%m/%d/%Y %H:%M:%S", 663 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 664 | ) 665 | logger.warning( 666 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 667 | args.local_rank, 668 | device, 669 | args.n_gpu, 670 | bool(args.local_rank != -1), 671 | args.fp16, 672 | ) 673 | 674 | # Set seed 675 | set_seed(args) 676 | 677 | # Load pretrained model and tokenizer 678 | if args.local_rank not in [-1, 0]: 679 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab 680 | 681 | if args.config_name: 682 | config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir) 683 | elif args.model_name_or_path: 684 | config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) 685 | else: 686 | # When we release a pip version exposing CONFIG_MAPPING, 687 | # we can do `config = CONFIG_MAPPING[args.model_type]()`. 688 | raise ValueError( 689 | "You are instantiating a new config instance from scratch. This is not supported, but you can do it from another script, save it," 690 | "and load it from here, using --config_name" 691 | ) 692 | 693 | if args.tokenizer_name: 694 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir) 695 | elif args.model_name_or_path: 696 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) 697 | else: 698 | raise ValueError( 699 | "You are instantiating a new tokenizer from scratch. 
This is not supported, but you can do it from another script, save it," 700 | "and load it from here, using --tokenizer_name" 701 | ) 702 | 703 | if args.block_size <= 0: 704 | args.block_size = 512#tokenizer.max_len 705 | # Our input block size will be the max possible for the model 706 | else: 707 | args.block_size = min(args.block_size, tokenizer.max_len) 708 | #print("Block_size", args.block_size) 709 | if args.model_name_or_path: 710 | model = AutoModelWithLMHead.from_pretrained( 711 | args.model_name_or_path, 712 | from_tf=bool(".ckpt" in args.model_name_or_path), 713 | config=config, 714 | cache_dir=args.cache_dir, 715 | ) 716 | else: 717 | logger.info("Training new model from scratch") 718 | model = AutoModelWithLMHead.from_config(config) 719 | 720 | model.to(args.device) 721 | 722 | if args.local_rank == 0: 723 | torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab 724 | 725 | logger.info("Training/evaluation parameters %s", args) 726 | 727 | # Training 728 | if args.do_train: 729 | if args.local_rank not in [-1, 0]: 730 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache 731 | 732 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False) 733 | 734 | if args.local_rank == 0: 735 | torch.distributed.barrier() 736 | 737 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 738 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 739 | 740 | # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained() 741 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 742 | # Create output directory if needed 743 | if args.local_rank in [-1, 0]: 744 | os.makedirs(args.output_dir, exist_ok=True) 745 | 746 | logger.info("Saving model checkpoint to %s", args.output_dir) 747 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 
748 | # They can then be reloaded using `from_pretrained()` 749 | model_to_save = ( 750 | model.module if hasattr(model, "module") else model 751 | ) # Take care of distributed/parallel training 752 | model_to_save.save_pretrained(args.output_dir) 753 | tokenizer.save_pretrained(args.output_dir) 754 | 755 | # Good practice: save your training arguments together with the trained model 756 | torch.save(args, os.path.join(args.output_dir, "training_args.bin")) 757 | 758 | # Load a trained model and vocabulary that you have fine-tuned 759 | model = AutoModelWithLMHead.from_pretrained(args.output_dir) 760 | tokenizer = AutoTokenizer.from_pretrained(args.output_dir) 761 | model.to(args.device) 762 | 763 | # Evaluation 764 | results = {} 765 | if args.do_eval and args.local_rank in [-1, 0]: 766 | checkpoints = [args.output_dir] 767 | if args.eval_all_checkpoints: 768 | checkpoints = list( 769 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)) 770 | ) 771 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 772 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 773 | for checkpoint in checkpoints: 774 | global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" 775 | prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else "" 776 | 777 | model = AutoModelWithLMHead.from_pretrained(checkpoint) 778 | model.to(args.device) 779 | result = evaluate(args, model, tokenizer, prefix=prefix) 780 | result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) 781 | results.update(result) 782 | 783 | return results 784 | 785 | 786 | if __name__ == "__main__": 787 | main() 788 | --------------------------------------------------------------------------------