├── models
│   ├── finetune_large.sh
│   ├── README.md
│   ├── gpt2.py
│   └── run_language_modeling.py
├── get_interpolated_output.py
├── README.md
├── requirements.txt
└── get_matrices.py
/models/finetune_large.sh:
--------------------------------------------------------------------------------
1 | export TRAIN_FILE=./path/to/finetune-lm-train
2 | export TEST_FILE=./path/to/finetune-lm-eval
3 |
4 | python run_language_modeling.py \
5 | --output_dir=output_large \
6 | --model_type=gpt2-large \
7 | --model_name_or_path=gpt2-large \
8 | --do_train \
9 | --train_data_file=$TRAIN_FILE \
10 | --do_eval \
11 | --eval_data_file=$TEST_FILE
12 |
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
1 | # Model Directory
2 |
3 | We use GPT2-large as provided by [HF transformers](https://github.com/huggingface/transformers). This directory contains the code used to interface with GPT-2 and obtain word probabilities over a sequence; it is imported as a utility by the rest of the project.
4 |
5 | ## Fine tuning GPT2
6 |
7 | run\_language\_modeling.py is a standard fine-tuning script.
8 | You only need to modify the train/test file paths within finetune\_large.sh and execute it. The files we used are [here](https://drive.google.com/drive/folders/1XrlzvJqmvcK0IpYK-VwIN5tk2y6iIILi?usp=sharing); adapt them to your domain.
9 |
10 | ```
11 | sh finetune_large.sh
12 | ```
13 |
14 | ## Using the model
15 |
16 | gpt2.py is a utility used by other parts of our project. As a standalone script, it loads the checkpoint saved by the fine-tuning script and exposes a function that returns the probability of each word in a paragraph or sentence of text.
17 |
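For example, a minimal standalone usage (run from within this directory) looks like the sketch below. The checkpoint path is a placeholder; point `location` at one of the `checkpoint-<step>` directories written to `output_large` by the fine-tuning script, or omit it to load the stock `gpt2-large` weights.

```
from gpt2 import GPT2

# Placeholder path; use your own fine-tuned checkpoint directory
model = GPT2(device="cpu", location="./output_large/checkpoint-70000")

payload = model.get_probabilities(["the following is a test sentence"])
print(payload["context_strings"][0])  # tokens of the input, including the <|endoftext|> start token
print(payload["real_probs"][0])       # per-token probabilities given the preceding context
```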
--------------------------------------------------------------------------------
/get_interpolated_output.py:
--------------------------------------------------------------------------------
1 | import random
2 | import pickle
3 | import spacy
4 | import numpy as np
5 | nlp = spacy.load("en_core_web_sm")
6 |
7 | import argparse
8 |
9 | parser = argparse.ArgumentParser(description='Generate extractive summaries from sentence-wise PMI matrices.')
10 | parser.add_argument("--classifier")
11 | parser.add_argument("--output_file")
12 | parser.add_argument("--dataset_file")
13 | parser.add_argument("--matrix_file_pattern")
14 |
15 | args = parser.parse_args()
16 |
17 | clf = pickle.load(open(args.classifier, "rb"))
18 |
19 | data = pickle.load(open(args.dataset_file, "rb"))
20 |
21 | for file_idx in range(5): # one matrix file per index; the count of 5 is hard-coded here and should match your index file
22 | count = 0
23 | output = pickle.load(open(args.matrix_file_pattern+str(file_idx)+".pkl", "rb"))
24 | for key in output.keys():
25 | #if key not in ref.keys():
26 | # continue
27 | count+=1
28 | article, abstract = data[key]
29 | print("--------------------")
30 | print(key)
31 | selected = []
32 | doc = nlp(article)
33 | sentences = [sentence.text.strip() for sentence in doc.sents]
34 |
35 | matrix = output[key]['vanilla']
36 | matrix[matrix<0] = 0
37 | relevance = []
38 | surprise = output[key]['surprise']
39 | for idx in range(len(sentences)):
40 | relevance.append(sum(matrix[idx]))
41 |
42 |     penalty = [0 for i in range(len(sentences))]  # redundancy penalty: accumulated PMI with sentences already selected
43 | #print(surprise)
44 | #print(matrix)
45 | try:
46 | for j in range(1, 9):
47 | selected = []
48 | summary = ""
49 | for k in range(j):
50 | maxIdx = -1
51 | maxVal = -float('inf')
52 | #print(maxIdx, selected)
53 | for i in range(len(sentences)):
54 | temp = np.dot(clf.coef_[0], [penalty[i], relevance[i]])  # clf.coef_[0] holds the learned weights over [penalty, relevance]
55 | if temp > maxVal and i not in selected:
56 | maxIdx = i
57 | maxVal = temp
58 |
59 | #print(maxVal, maxIdx)
60 | for i in range(len(sentences)):
61 | penalty[i]+=matrix[i][maxIdx]
62 |
63 | selected.append(maxIdx)
64 | #print(selected)
65 | summary = ""
66 | for i in sorted(selected):
67 | summary+= sentences[i]+" "
68 |
69 | summary = ' '.join(summary.split())
70 |
71 | with open(args.output_file+str(j), "a") as f:
72 | f.write(summary+'\n')
73 |
74 | with open("./path/to/output/gold", "a") as f:
75 | f.write(' '.join(abstract.split())+'\n')
76 | except:
77 | print("Missed ", key)
78 |
79 | #print("SUMMARY ", summary)
80 |
81 | #if count == 10:
82 | # exit(0)
83 |
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Extractive Summarization using Mutual Information
2 |
3 | ## Steps to perform summarization
4 |
5 | 1. Fine tune the required language model
6 | 2. Create sentence-wise mutual information matrices for documents in the dataset
7 | 3. Use the created matrices to generate extractive summaries
8 | 4. Evaluation
9 |
10 | ### 1. Fine tune the required language model
11 |
12 | We use GPT2-large as provided by [HF transformers](https://github.com/huggingface/transformers), with a standard fine-tuning script as provided in the models directory.
13 |
14 | ### 2. Creating mutual information matrices
15 |
16 | For each document in the dataset with n sentences, we create the n×n pairwise mutual information matrix as described in the paper.
17 |
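Concretely (this is our reading of `get_npmi_matrix` in get\_matrices.py; the paper remains the authoritative definition), writing p(s) for the probability the fine-tuned language model assigns to a sentence s, entry (i, j) of the 'vanilla' matrix is log p(s_j | s_i) - log p(s_j), and the 'normalised' variant divides this value by -log p(s_i, s_j).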
18 | ```
19 | python get_matrices.py --index 0 --input_data ./path/to/dataset.pkl --index_file ./path/to/indices.pkl --output_format ./path/to/output/sent_train_set_
20 | ```
21 |
22 | The parameters indicate the following:
23 | 1. input\_data - Pickle file containing the document-summary pairs. Check that the format matches the following input scheme:
24 | ```
25 | data = pickle.load(open(args.input_data, "rb"))
26 | for i in range(len(data)):
27 | article, abstract = data[i]
28 | ```
29 | 2. index\_file - Since the process takes a while, we split the work into parts using a pickled index file. The index file is a list of (start index, end index) tuples over the dataset. For example, if the dataset has 100 documents, the index file might be [(0, 30), (30, 60), (60, 90), (90, 100)], and each execution of get\_matrices.py with a particular index creates the matrices for the datapoints in that range. The simplest option is a list with a single tuple, [(0, length of dataset)]; see the sketch after this list.
30 | 3. index - Index within the index file to execute. The index file is a list, so if index is 0, the matrices are created for the documents covered by the first tuple in the index file. You can therefore parallelize the process by running multiple instances at the same time, each with a different index.
31 | 4. output\_format - Path prefix/location where the output is stored. The output is written as a pickle file at this location with the index appended to the name.
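A minimal sketch for creating the index file (the paths are placeholders, and dataset.pkl is assumed to be the pickled list of (article, abstract) pairs described above):

```
import pickle

# Load the dataset only to find out its length
data = pickle.load(open("./path/to/dataset.pkl", "rb"))

# A single chunk covering the whole dataset; for parallel runs,
# split this into several (start index, end index) tuples instead
indices = [(0, len(data))]

pickle.dump(indices, open("./path/to/indices.pkl", "wb"))
```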
32 |
33 | #### Output Format
34 | For the given execution line, the output would look like:
35 | ```
36 | ./path/to/output/sent_train_set_0.pkl
37 | ```
38 | Where the location/prefix is given by the output\_format parameter above and the 0 is the index executed by the current script (in practice you would have 0, 1, 2, ... all running in parallel).
39 |
40 | The pickle file is indexed as
41 | ```
42 | data = pickle.load(open("./path/to/output/sent_train_set_0.pkl", "rb"))
43 | doc_number = 0
44 | print(data[doc_number]['vanilla'])
45 | ```
46 | Within the pickle file, the primary index is the index of the document within the dataset pickle file. We create the 'vanilla' matrix, which is an n×n sentence-wise PMI matrix as described in the paper. We also store a 'surprise' list (the sentence-wise surprisal of each sentence, unused downstream) and the 'normalised' PMI matrix.
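As a quick sanity check, the sketch below (file name taken from the example above) loads one matrix file and computes the per-sentence relevance score used in step 3, i.e. the row sums of the PMI matrix after clipping negative entries:

```
import pickle
import numpy as np

output = pickle.load(open("./path/to/output/sent_train_set_0.pkl", "rb"))
doc_number = 0

matrix = np.array(output[doc_number]["vanilla"])  # n x n sentence-wise PMI matrix
matrix[matrix < 0] = 0                            # clip negative PMI values (and the -1 diagonal)
relevance = matrix.sum(axis=1)                    # one relevance score per sentence
print(relevance)
```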
47 |
48 | ### 3. Generate Summaries
49 | From the generated PMI matrices, we generate summaries using the algorithm detailed in the paper.
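At a high level, get\_interpolated\_output.py builds each summary greedily: a sentence's relevance is the clipped row sum of the PMI matrix, a redundancy penalty accumulates each candidate's PMI with the sentences already selected, and at every step the sentence maximizing a weighted combination of the two is added until the target length is reached.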
50 | ```
51 | python get_interpolated_output.py --output_file ./path/to/output/interpolated_ --dataset_file ./path/to/dataset.pkl --matrix_file_pattern ./path/to/output/sent_train_set_
52 | ```
53 | This consumes the original pickled dataset and the matrix files generated in the previous step. The script iterates over the different matrix files (which is why the argument is a pattern; the number of files is currently hard-coded in the script) and generates extractive summaries of lengths 1 through 8, one output file per length, using the output file pattern specified here. For example, interpolated\_1 is a text file where each line is a one-sentence summary of the corresponding document in the dataset. The gold summaries are saved to a similar file. This file format is used because it is compatible with the ROUGE package in the next step. Change the interpolation coefficients here if needed: to interpolate between relevance and redundancy, in the paper we used a classifier to learn the weight assigned to each. A simpler alternative for inference only is to set the weights to +1 for relevance and -1 for redundancy [here](https://github.com/vishakhpk/mi-unsup-summ/blob/196d3b646460f03cfcf9e41e1db621868a7156d0/get_interpolated_output.py#L54) and comment out the [line](https://github.com/vishakhpk/mi-unsup-summ/blob/196d3b646460f03cfcf9e41e1db621868a7156d0/get_interpolated_output.py#L17) loading the classifier, as in the sketch below.
54 |
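A sketch of that simpler variant (a hypothetical edit to the two linked lines, not a change we ship in the script):

```
# 1) Comment out the classifier load near the top of get_interpolated_output.py:
# clf = pickle.load(open(args.classifier, "rb"))

# 2) In the selection loop, replace the learned weights with fixed ones,
#    -1 for the redundancy penalty and +1 for relevance:
temp = np.dot([-1.0, 1.0], [penalty[i], relevance[i]])
```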
55 | ### 4. Evaluation
56 | Standard ROUGE evaluation using the [rouge scorer](https://github.com/google-research/google-research/tree/master/rouge)
57 | ```
58 | python3 -m rouge_score.rouge --target_filepattern=./path/to/output/gold --prediction_filepattern=./path/to/output/interpolated_3 --output_filename=./path/to/output/results.csv --use_stemmer=true
59 | ```
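The target and prediction files are line-aligned, so line k of interpolated\_3 is scored against line k of gold, and the aggregate ROUGE scores are written to results.csv.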
60 |
61 | ## Data for the paper
62 | The following contains the preprocessed datasets, created PMI matrices, generated summaries and Rouge score reports used in our paper:
63 |
64 | [Google Drive Link](https://drive.google.com/drive/folders/1dBPd7trOOdKTNFDtUSGH9Z3zZ2PucDmL?usp=sharing)
65 |
66 | Please reach out if you would like to use our saved language models (vishakh@nyu.edu). We use GPT-2 large, fine-tuned on the document sentences from the various domains using the script in the models directory. The input files for the fine-tuning script are available in the [LM-Files](https://drive.google.com/drive/folders/1XrlzvJqmvcK0IpYK-VwIN5tk2y6iIILi?usp=sharing) folder at the above drive location.
67 |
68 | ## TODOs
69 | 1. Convert scripts to a reusable utility
70 | 2. Maybe remove indices files to make things easier
71 |
--------------------------------------------------------------------------------
/models/gpt2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import code
4 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
5 |
6 | class GPT2:
7 | """
8 | Citation: https://github.com/HendrikStrobelt/detecting-fake-text/blob/master/backend/api.py
9 | Model class for GPT-2. Primarily used to obtain word probabilities
10 | """
11 | def __init__(self, device = "cpu", location = ""):
12 | if location == "":
13 | self.enc = GPT2Tokenizer.from_pretrained("gpt2-large")
14 | self.model = GPT2LMHeadModel.from_pretrained("gpt2-large")
15 | else:
16 | self.enc = GPT2Tokenizer.from_pretrained(location)
17 | self.model = GPT2LMHeadModel.from_pretrained(location)
18 | self.device = torch.device(device)
19 | self.model.eval()
20 | self.start_tok = "<|endoftext|>"
21 | #SPECIAL_TOKENS = [""]
22 | #self.enc.add_special_tokens(SPECIAL_TOKENS)
23 | #self.model.set_num_special_tokens(len(SPECIAL_TOKENS))
24 | self.model.to(self.device)
25 |
26 | def pad(self, context):
27 | max_len = max([len(sentence) for sentence in context])
28 | #print("Maximum Length: ", max_len)
29 | for i in range(len(context)):
30 | #print(len(context[i]), max_len - len(context[i]))
31 | for j in range(max_len - len(context[i])):
32 | context[i].append(context[i][0])
33 | #print(max_len - len(sentences[i].split()))
34 | #print("i: ", sentences[i])
35 | #print([[self.enc.encode("") for idx in range(max_len - len(in_text[i].split()))] for i in range(len(in_text))])
36 | #print(sentences)
37 | #print([len(context[i]) for i in range(len(context))])
38 | return context
39 |
40 | def get_probabilities(self, in_text, topk = 40):
41 | """
42 | Take in a sequence of text tokens, make predictions on each word given past context and
43 | return topk
44 |
45 | Returns:
46 | Dictionary "payload" containing:
47 | real_probs
48 | - List of tuples, one for each token in sequence
49 | - Probability of the actual words in the sequence
50 | - Each tuple of the form (position of next word in prediction, predicted probability)
51 |
52 | context_strings:
53 | - Strings in the sequence along with start token
54 | """
55 | with torch.no_grad():
56 | start_tok = torch.full((1, 1), self.enc.encoder[self.start_tok],
57 | device=self.device, dtype=torch.long)
58 | context = [self.start_tok+" "+in_text[i] for i in range(len(in_text))]
59 | context = [self.enc.encode(context[i]) for i in range(len(context))]
60 | context = self.pad(context)
61 | context = torch.tensor(context, device=self.device, dtype=torch.long)
62 | logits, _ = self.model(context)
63 | yhat = torch.softmax(logits[:, :-1], dim=-1)
64 | y = context[:, 1:]
65 | real_topk_probs = [yhat[t][np.arange(0, y[t].shape[0], 1), y[t]].data.cpu().numpy().tolist() for t in range(yhat.shape[0])]
66 | real_topk_probs = [list(map(lambda x: round(x, 15), real_topk_probs[t])) for t in range(len(real_topk_probs))]
67 |
68 | real_topk = [list(real_topk_probs[t]) for t in range(len(real_topk_probs))]
69 |
70 | context_strings = [[self.enc.decoder[s.item()] for s in context[t]] for t in range(len(context))]
71 | context_strings = [[self.postprocess(s) for s in context_strings[t]] for t in range(len(context_strings))]
72 | del context, logits, y, yhat,
73 | torch.cuda.empty_cache()
74 | """
75 | pred_topk = [[list(zip([self.enc.decoder[p] for p in sorted_preds[t][i][:topk]],
76 | list(map(lambda x: round(x, 5),yhat[t][i][sorted_preds[t][i][
77 | :topk]].data.cpu().numpy().tolist()))))
78 | for i in range(y[t].shape[0])] for t in range(y.shape[0])]
79 | pred_topk = [[[(self.postprocess(t[0]), t[1]) for t in pred] for pred in pred_topk[t]] for t in range(len(pred_topk))]
80 | """
81 | payload = {'context_strings': context_strings,
82 | 'real_probs': real_topk}#, 'pred_topk': pred_topk}
83 |
84 | #del context, logits, y, yhat,
85 | #torch.cuda.empty_cache()
86 | #code.interact(local=locals())
87 | return payload
88 |
89 | def postprocess(self, token):
90 | with_space = False
91 | with_break = False
92 | #print(token, token[0], token[1:]),
93 | if token[0] == 'Ġ':
94 | with_space = True
95 | token = token[1:]
96 | elif token.startswith('â'):
97 | token = ' '
98 | elif token.startswith('Ċ'):
99 | token = ' '
100 | with_break = True
101 |
102 | if len(token)>0 and token[0] == "Â":
103 | token = token[1:]
104 | token = '-' if token.startswith('â') else token
105 | token = '“' if token.startswith('ľ') else token
106 | token = '”' if token.startswith('Ŀ') else token
107 | token = "'" if token.startswith('Ļ') else token
108 | #if with_space:
109 | # token = '\u0120' + token
110 | #if with_break:
111 | # token = '\u010A' + token
112 | #print(token)
113 | return token
114 |
115 | if __name__=="__main__":
116 | test = GPT2(location="../../Summarization/models/output_large/checkpoint-70000")
117 | payload = test.get_probabilities(["the following is", "what about this"], topk = 40)
118 | print(payload['real_probs'])
119 | #for t in range(len(payload["pred_topk"])):
120 | # for i in range(len(payload["pred_topk"][t])):
121 | # print(payload["real_probs"][t][i])
122 | #code.interact(local=locals())
123 | test = GPT2()
124 | payload = test.get_probabilities(["the following is", "what about this"], topk = 40)
125 | print("And without fine tuning: ")
126 | print(payload['real_probs'])
127 | #for t in range(len(payload["pred_topk"])):
128 | # for i in range(len(payload["pred_topk"][t])):
129 | # print(payload["real_probs"][t][i])
130 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.8.1
2 | alabaster==0.7.11
3 | alembic==1.0.5
4 | anaconda-client==1.7.2
5 | anaconda-navigator==1.9.2
6 | anaconda-project==0.8.2
7 | appdirs==1.4.3
8 | asn1crypto==0.24.0
9 | astor==0.8.0
10 | astroid==2.0.4
11 | astropy==3.0.4
12 | async-generator==1.10
13 | atomicwrites==1.2.1
14 | attrs==18.2.0
15 | Automat==0.7.0
16 | Babel==2.6.0
17 | backcall==0.1.0
18 | backports.shutil-get-terminal-size==1.0.0
19 | bash-kernel==0.7.1
20 | batchspawner==0.8.1
21 | bcolz==1.2.1
22 | beautifulsoup4==4.6.3
23 | bitarray==0.8.3
24 | bkcharts==0.2
25 | blaze==0.11.3
26 | bleach==3.0.2
27 | blis==0.4.1
28 | bokeh==0.13.0
29 | boto==2.49.0
30 | boto3==1.10.0
31 | botocore==1.13.0
32 | Bottleneck==1.2.1
33 | catalogue==0.0.8
34 | certifi==2018.8.24
35 | cffi==1.11.5
36 | chardet==3.0.4
37 | click==6.7
38 | cloudpickle==0.5.5
39 | clyent==1.2.2
40 | colorama==0.3.9
41 | conda==4.5.11
42 | conda-build==3.15.1
43 | constantly==15.1.0
44 | contextlib2==0.5.5
45 | coverage==3.7.1
46 | coveralls==0.5
47 | cryptography==2.3.1
48 | cycler==0.10.0
49 | cymem==2.0.2
50 | Cython==0.28.5
51 | cytoolz==0.9.0.1
52 | dask==0.19.1
53 | datascience==0.10.6
54 | datashape==0.5.4
55 | decorator==4.3.0
56 | defusedxml==0.5.0
57 | dill==0.3.2
58 | distributed==1.23.1
59 | docopt==0.6.2
60 | docutils==0.14
61 | en-core-web-sm==2.2.5
62 | entrypoints==0.2.3
63 | et-xmlfile==1.0.1
64 | fairseq==0.9.0
65 | fastai==1.0.57
66 | fastBPE==0.1.0
67 | fastcache==1.0.2
68 | fastprogress==0.1.21
69 | fasttext==0.9.1
70 | filelock==3.0.8
71 | Flask==1.0.2
72 | Flask-Cors==3.0.6
73 | folium==0.2.1
74 | future==0.18.2
75 | gast==0.2.2
76 | gensim==3.8.2
77 | gevent==1.3.6
78 | glob2==0.6
79 | gmpy2==2.0.8
80 | google-pasta==0.1.7
81 | googleapis-common-protos==1.52.0
82 | greenlet==0.4.15
83 | grpcio==1.24.1
84 | h5py==2.8.0
85 | heapdict==1.0.0
86 | html5lib==1.0.1
87 | hyperlink==18.0.0
88 | idna==2.7
89 | imageio==2.4.1
90 | imagesize==1.1.0
91 | importlib-metadata==1.3.0
92 | incremental==17.5.0
93 | ipykernel==5.1.0
94 | ipython==6.2.1
95 | ipython-genutils==0.2.0
96 | ipywidgets==7.4.2
97 | isort==4.3.4
98 | itsdangerous==0.24
99 | jdcal==1.4
100 | jedi==0.13.1
101 | jeepney==0.3.1
102 | Jinja2==2.10
103 | jmespath==0.9.4
104 | joblib==0.14.0
105 | jsonschema==2.6.0
106 | jupyter==1.0.0
107 | jupyter-client==5.2.4
108 | jupyter-console==5.2.0
109 | jupyter-core==4.4.0
110 | jupyterhub==0.9.4
111 | jupyterhub-ldapauthenticator==1.2.2
112 | jupyterlab==0.34.9
113 | jupyterlab-launcher==0.13.1
114 | Keras==2.3.1
115 | Keras-Applications==1.0.8
116 | Keras-Preprocessing==1.1.0
117 | keyring==13.2.1
118 | kiwisolver==1.0.1
119 | lazy-object-proxy==1.3.1
120 | ldap3==2.5.1
121 | llvmlite==0.24.0
122 | locket==0.2.0
123 | lxml==4.2.5
124 | Mako==1.0.7
125 | Markdown==3.1.1
126 | MarkupSafe==1.1.0
127 | matlab-kernel==0.16.1
128 | matplotlib==2.2.3
129 | mccabe==0.6.1
130 | mecab-python3==0.996.5
131 | metakernel==0.20.14
132 | mistune==0.8.4
133 | mkl-fft==1.0.4
134 | mkl-random==1.0.1
135 | more-itertools==4.3.0
136 | mpmath==1.0.0
137 | msgpack==0.5.6
138 | multipledispatch==0.6.0
139 | murmurhash==1.0.2
140 | navigator-updater==0.2.1
141 | nbconvert==5.4.0
142 | nbformat==4.4.0
143 | nbgitpuller==0.6.1
144 | networkx==2.1
145 | nltk==3.3
146 | nmslib==1.8.1
147 | nose==1.3.7
148 | notebook==5.7.2
149 | numba==0.39.0
150 | numexpr==2.6.8
151 | numpy==1.16.2
152 | numpydoc==0.8.0
153 | nvidia-ml-py3==7.352.0
154 | odo==0.5.1
155 | olefile==0.46
156 | openpyxl==2.5.6
157 | opt-einsum==3.1.0
158 | packaging==17.1
159 | pamela==0.3.0
160 | pandas==0.23.4
161 | pandocfilters==1.4.2
162 | parso==0.3.1
163 | partd==0.3.8
164 | path.py==11.1.0
165 | pathlib2==2.3.2
166 | patsy==0.5.0
167 | pep8==1.7.1
168 | pexpect==4.6.0
169 | pickleshare==0.7.5
170 | Pillow==5.2.0
171 | pkginfo==1.4.2
172 | plac==0.9.6
173 | pluggy==0.7.1
174 | ply==3.11
175 | portalocker==1.7.0
176 | preshed==3.0.2
177 | prometheus-client==0.5.0
178 | promise==2.3
179 | prompt-toolkit==1.0.15
180 | protobuf==3.10.0
181 | psutil==5.4.7
182 | ptyprocess==0.6.0
183 | py==1.6.0
184 | pyasn1==0.4.4
185 | pyasn1-modules==0.2.2
186 | pybind11==2.4.1
187 | pycodestyle==2.4.0
188 | pycosat==0.6.3
189 | pycparser==2.18
190 | pycrypto==2.6.1
191 | pycurl==7.43.0.2
192 | pyflakes==2.0.0
193 | Pygments==2.3.0
194 | pylint==2.1.1
195 | pymongo==3.9.0
196 | pyodbc==4.0.24
197 | pyOpenSSL==18.0.0
198 | pyparsing==2.2.0
199 | pyrouge==0.1.3
200 | pyrsistent==0.15.7
201 | PySocks==1.6.8
202 | pytest==3.8.0
203 | pytest-arraydiff==0.2
204 | pytest-astropy==0.4.0
205 | pytest-doctestplus==0.1.3
206 | pytest-openfiles==0.3.0
207 | pytest-remotedata==0.3.0
208 | python-dateutil==2.7.5
209 | python-editor==1.0.3
210 | python-oauth2==1.1.0
211 | pytorch-transformers==1.2.0
212 | pytz==2018.5
213 | PyWavelets==1.0.0
214 | PyYAML==3.13
215 | pyzmq==17.1.2
216 | QtAwesome==0.4.4
217 | qtconsole==4.4.1
218 | QtPy==1.5.0
219 | regex==2019.8.19
220 | requests==2.19.1
221 | rope==0.11.0
222 | rouge-score==0.0.3
223 | ruamel-yaml==0.15.46
224 | s3transfer==0.2.1
225 | sacrebleu==1.4.6
226 | sacremoses==0.0.35
227 | scikit-image==0.14.0
228 | scikit-learn==0.19.2
229 | scipy==1.1.0
230 | seaborn==0.9.0
231 | SecretStorage==3.1.0
232 | Send2Trash==1.5.0
233 | sentencepiece==0.1.83
234 | seqeval==0.0.12
235 | service-identity==17.0.0
236 | simplegeneric==0.8.1
237 | simpletransformers==0.22.1
238 | singledispatch==3.4.0.3
239 | six==1.12.0
240 | smappdragon==0.0.47
241 | smart-open==2.0.0
242 | snowballstemmer==1.2.1
243 | sortedcollections==1.0.1
244 | sortedcontainers==2.0.5
245 | spacy==2.2.3
246 | Sphinx==1.7.9
247 | sphinxcontrib-websupport==1.1.0
248 | spyder==3.3.1
249 | spyder-kernels==0.2.6
250 | SQLAlchemy==1.2.11
251 | srsly==0.1.0
252 | statsmodels==0.9.0
253 | sympy==1.2
254 | tables==3.4.4
255 | tblib==1.3.2
256 | tensorboard==2.0.0
257 | tensorboardX==2.0
258 | tensorflow==2.0.0
259 | tensorflow-datasets==3.2.1
260 | tensorflow-estimator==2.0.1
261 | tensorflow-metadata==0.22.2
262 | termcolor==1.1.0
263 | terminado==0.8.1
264 | testpath==0.3.1
265 | thinc==7.3.1
266 | tokenizers==0.5.2
267 | toolz==0.9.0
268 | torch==1.4.0
269 | torchvision==0.4.0
270 | tornado==5.1.1
271 | tqdm==4.43.0
272 | traitlets==4.3.2
273 | transformers==2.6.0
274 | Twisted==18.7.0
275 | typing==3.7.4.1
276 | unicodecsv==0.14.1
277 | uritools==3.0.0
278 | urlextract==0.14.0
279 | urllib3==1.23
280 | virtualenv==16.1.0
281 | wasabi==0.4.2
282 | wcwidth==0.1.7
283 | webencodings==0.5.1
284 | Werkzeug==0.14.1
285 | widgetsnbextension==3.4.2
286 | wrapt==1.11.2
287 | wurlitzer==1.0.2
288 | xlrd==1.1.0
289 | XlsxWriter==1.1.0
290 | xlwt==1.3.0
291 | zict==0.1.3
292 | zipp==0.6.0
293 | zope.interface==4.5.0
294 |
--------------------------------------------------------------------------------
/get_matrices.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, "./models/")
4 | sys.path.insert(0, "./read-data/")
5 | from read import get_text
6 | from gpt2 import GPT2
7 | from tqdm import tqdm
8 | import preprocess_subsequence
9 | import nltk
10 | import math
11 | from nltk import tokenize
12 | import numpy as np
13 | #from sklearn.metrics import precision_score, recall_score
14 | import pickle
15 | import spacy
16 | import argparse
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--index")
19 | parser.add_argument("--input_data")
20 | parser.add_argument("--index_file")
21 | parser.add_argument("--output_format")
22 | args = parser.parse_args()
23 | print(args.index)
24 | print(args.index_file)
25 | print(args.input_data)
26 | print(args.output_format)
27 |
28 | data = pickle.load(open(args.input_data, "rb"))#
29 | print("Length ", len(data))
30 | indices = pickle.load(open(args.index_file, "rb"))
31 | print(indices)
32 | lower, upper = indices[int(args.index)]
33 | print(lower, upper)
34 |
35 | nlp = spacy.load("en_core_web_sm")
36 |
37 | model = GPT2(device="cuda", location="./path/to/saved/model/")
38 |
39 | if os.path.exists(args.output_format+args.index+".pkl"):
40 | output = pickle.load(open(args.output_format+args.index+".pkl","rb"))
41 | else:
42 | output = {}
43 |
44 | def get_probabilities(articles):
45 | """
46 | Given a batch of articles (these can be any strings), run a forward pass on GPT-2 and obtain word probabilities for each
47 | """
48 | article_splits = [article.split(" ") for article in articles]
49 | payload = model.get_probabilities(articles, topk = 20)
50 | res = [[] for i in range(len(articles))]
51 | for t, article in enumerate(articles):
52 | context = ""
53 | idx = 0
54 | chain = False
55 | next_word = ""
56 | article_words = article_splits[t]
57 | #print(article, article_words)
58 | word_probability = 1.0
59 | gt_count = 0
60 | idx+=1
61 | found_words = []
62 | for i, word in enumerate(payload["context_strings"][t][:-1]):
63 | context = context+" "+word
64 | probability = payload['real_probs'][t][i]#[1]
65 | next_word_fragment = payload["context_strings"][t][i+1]
66 |
67 | next_word += next_word_fragment
68 | #print(next_word, article_words[gt_count])
69 | if next_word == article_words[gt_count]:
70 | chain = False
71 | gt_count+=1
72 | else:
73 | chain = True
74 |
75 | word_probability *= probability
76 | assert word_probability <= 1.0, print(word_probability, context)
77 | if chain == False:
78 | #print("Word Probability: ", word_probability, next_word)
79 | res[t].append(word_probability)
80 | word_probability = 1.0
81 | next_word = ""
82 | #print(gt_count, len(article_words))
83 | if gt_count == len(article_words):
84 | break
85 | return res
86 |
87 |
88 | def get_npmi_matrix(sentences, method = 1, batch_size = 1):
89 | """
90 | Accepts a list of sentences of length n and returns 3 objects:
91 | - Normalised PMI nxn matrix - temp
92 | - PMI nxn matrix - temp2
93 | - List of length n indicating sentence-wise surprisal i.e. p(sentence) - p
94 |
95 | To optimize performance, we do the forward pass batchwise by assembling the batch and maintaining batch indices
96 | For each batch we call get_probabilities
97 | """
98 | temp = np.zeros((len(sentences), len(sentences)))
99 | temp2 = np.zeros((len(sentences), len(sentences)))
100 | batch_indices = {}
101 | batch = []
102 | batchCount = 0
103 | batchSize = batch_size
104 | #print(len(sentences))
105 | c = 0
106 | p = []
107 | for i in range(len(sentences)):
108 | result = get_probabilities([sentences[i]])
109 | try:
110 | p.append(sum([math.log(i) for i in result[0]]))
111 | except:
112 | print("Math domain error surprise", i)
113 | return temp, temp2, p
114 | for i in range(len(sentences)):
115 | for j in range(len(sentences)):
116 | if i==j:
117 | temp[i][j] = -1
118 | temp2[i][j] = -1
119 | continue
120 | article = sentences[i] + " "+ sentences[j]
121 | #print(article)
122 | batch_indices[str(i)+"-"+str(j)+"-"+str(len(sentences[i].split()))] = batchCount
123 | batch.append(article)
124 | batchCount+=1
125 |
126 | if batchCount == batchSize or (i == len(sentences)-1 and j == len(sentences)-1):
127 | #print(batch)
128 | c+=1
129 | result = get_probabilities(batch)
130 | for key in batch_indices.keys():
131 | #print(key)
132 | #print(key.split("-"))
133 | idx_i, idx_j, idx_l = [int(idx) for idx in key.split("-")]
134 | try:
135 | pxy = sum([math.log(q) for q in result[batch_indices[key]][idx_l:]])
136 | py = p[idx_j]
137 | px = p[idx_i]
138 |
139 | temp[idx_i][idx_j] = (pxy - py)/(-1*(pxy+px))
140 | temp2[idx_i][idx_j] = (pxy - py)
141 | except ZeroDivisionError:
142 | print("Zero division error ", idx_i, idx_j)
143 | temp[idx_i][idx_j] = -1
144 | temp2[idx_i][idx_j] = -1
145 | except:
146 | print("Math Domain Error", i, j)
147 | if temp[idx_i][idx_j] > 1 or temp[idx_i][idx_j] < -1:
148 | print("Normalise assert ", temp[idx_i][idx_j], idx_i, idx_j)
149 | batchCount = 0
150 | batch = []
151 | batch_indices = {}
152 | return temp, temp2, p
153 |
154 | def remove_unicode(text):
155 | return ''.join([i if ord(i) < 128 else ' ' for i in text])
156 |
157 | def get_article(idx):
158 | """
159 | For each document in the dataset, split it into sentences and call get_npmi_matrix to create the matrices
160 | """
161 | print(idx)
162 | article, abstract = data[idx]
163 | #sentences = tokenize.sent_tokenize(article)
164 | doc = nlp(article)
165 | sentences = [remove_unicode(sentence.text) for sentence in doc.sents]
166 | normalised, vanilla, surprise = get_npmi_matrix(sentences, batch_size = 10)
167 | #avg = get_pmi_matrix(sentences, method = 1)
168 | output[idx] = {}
169 | output[idx]["vanilla"] = vanilla
170 | output[idx]["normalised"] = normalised
171 | output[idx]["surprise"] = surprise
172 | #output[idx]["averaging"] = avg
173 | #pickle.dump(output, open("full_set_1.pkl", "wb"))
174 | return
175 |
176 | """
177 | Main iteration loop, creates matrices for each document in the dataset
178 | """
179 | c = 0
180 | for idx in range(len(data)):
181 | if idx>=lower and idx<upper:
--------------------------------------------------------------------------------
/models/run_language_modeling.py:
--------------------------------------------------------------------------------
116 | lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
117 |
118 | self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"]
119 |
120 | def __len__(self):
121 | return len(self.examples)
122 |
123 | def __getitem__(self, i):
124 | return torch.tensor(self.examples[i], dtype=torch.long)
125 |
126 |
127 | def load_and_cache_examples(args, tokenizer, evaluate=False):
128 | file_path = args.eval_data_file if evaluate else args.train_data_file
129 | if args.line_by_line:
130 | return LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
131 | else:
132 | return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
133 |
134 |
135 | def set_seed(args):
136 | random.seed(args.seed)
137 | np.random.seed(args.seed)
138 | torch.manual_seed(args.seed)
139 | if args.n_gpu > 0:
140 | torch.cuda.manual_seed_all(args.seed)
141 |
142 |
143 | def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
144 | ordering_and_checkpoint_path = []
145 |
146 | glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
147 |
148 | for path in glob_checkpoints:
149 | if use_mtime:
150 | ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
151 | else:
152 | regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
153 | if regex_match and regex_match.groups():
154 | ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
155 |
156 | checkpoints_sorted = sorted(ordering_and_checkpoint_path)
157 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
158 | return checkpoints_sorted
159 |
160 |
161 | def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
162 | if not args.save_total_limit:
163 | return
164 | if args.save_total_limit <= 0:
165 | return
166 |
167 | # Check if we should delete older checkpoint(s)
168 | checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
169 | if len(checkpoints_sorted) <= args.save_total_limit:
170 | return
171 |
172 | number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
173 | checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
174 | for checkpoint in checkpoints_to_be_deleted:
175 | logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
176 | shutil.rmtree(checkpoint)
177 |
178 |
179 | def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
180 | """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
181 |
182 | if tokenizer.mask_token is None:
183 | raise ValueError(
184 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
185 | )
186 |
187 | labels = inputs.clone()
188 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
189 | probability_matrix = torch.full(labels.shape, args.mlm_probability)
190 | special_tokens_mask = [
191 | tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
192 | ]
193 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
194 | if tokenizer._pad_token is not None:
195 | padding_mask = labels.eq(tokenizer.pad_token_id)
196 | probability_matrix.masked_fill_(padding_mask, value=0.0)
197 | masked_indices = torch.bernoulli(probability_matrix).bool()
198 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
199 |
200 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
201 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
202 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
203 |
204 | # 10% of the time, we replace masked input tokens with random word
205 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
206 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
207 | inputs[indices_random] = random_words[indices_random]
208 |
209 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
210 | return inputs, labels
211 |
212 |
213 | def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
214 | """ Train the model """
215 | if args.local_rank in [-1, 0]:
216 | tb_writer = SummaryWriter()
217 |
218 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
219 |
220 | def collate(examples: List[torch.Tensor]):
221 | if tokenizer._pad_token is None:
222 | return pad_sequence(examples, batch_first=True)
223 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
224 |
225 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
226 | train_dataloader = DataLoader(
227 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
228 | )
229 |
230 | if args.max_steps > 0:
231 | t_total = args.max_steps
232 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
233 | else:
234 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
235 |
236 | model = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
237 | model.resize_token_embeddings(len(tokenizer))
238 |
239 | # Prepare optimizer and schedule (linear warmup and decay)
240 | no_decay = ["bias", "LayerNorm.weight"]
241 | optimizer_grouped_parameters = [
242 | {
243 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
244 | "weight_decay": args.weight_decay,
245 | },
246 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
247 | ]
248 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
249 | scheduler = get_linear_schedule_with_warmup(
250 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
251 | )
252 |
253 | # Check if saved optimizer or scheduler states exist
254 | if (
255 | args.model_name_or_path
256 | and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
257 | and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
258 | ):
259 | # Load in optimizer and scheduler states
260 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
261 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
262 |
263 | if args.fp16:
264 | try:
265 | from apex import amp
266 | except ImportError:
267 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
268 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
269 |
270 | # multi-gpu training (should be after apex fp16 initialization)
271 | if args.n_gpu > 1:
272 | model = torch.nn.DataParallel(model)
273 |
274 | # Distributed training (should be after apex fp16 initialization)
275 | if args.local_rank != -1:
276 | model = torch.nn.parallel.DistributedDataParallel(
277 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
278 | )
279 |
280 | # Train!
281 | logger.info("***** Running training *****")
282 | logger.info(" Num examples = %d", len(train_dataset))
283 | logger.info(" Num Epochs = %d", args.num_train_epochs)
284 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
285 | logger.info(
286 | " Total train batch size (w. parallel, distributed & accumulation) = %d",
287 | args.train_batch_size
288 | * args.gradient_accumulation_steps
289 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
290 | )
291 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
292 | logger.info(" Total optimization steps = %d", t_total)
293 |
294 | global_step = 0
295 | epochs_trained = 0
296 | steps_trained_in_current_epoch = 0
297 | # Check if continuing training from a checkpoint
298 | if args.model_name_or_path and os.path.exists(args.model_name_or_path):
299 | try:
300 | # set global_step to the global_step of the last saved checkpoint from the model path
301 | checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
302 | global_step = int(checkpoint_suffix)
303 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
304 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
305 |
306 | logger.info(" Continuing training from checkpoint, will skip to saved global_step")
307 | logger.info(" Continuing training from epoch %d", epochs_trained)
308 | logger.info(" Continuing training from global step %d", global_step)
309 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
310 | except ValueError:
311 | logger.info(" Starting fine-tuning.")
312 |
313 | tr_loss, logging_loss = 0.0, 0.0
314 |
315 | model.zero_grad()
316 | train_iterator = trange(
317 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
318 | )
319 | set_seed(args) # Added here for reproducibility
320 | for epoch in train_iterator:
321 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
322 |
323 | if args.local_rank != -1:
324 | train_sampler.set_epoch(epoch)
325 |
326 | for step, batch in enumerate(epoch_iterator):
327 |
328 | # Skip past any already trained steps if resuming training
329 | if steps_trained_in_current_epoch > 0:
330 | steps_trained_in_current_epoch -= 1
331 | continue
332 |
333 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
334 | inputs = inputs.to(args.device)
335 | labels = labels.to(args.device)
336 | model.train()
337 | outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
338 | loss = outputs[0] # model outputs are always tuple in transformers (see doc)
339 |
340 | if args.n_gpu > 1:
341 | loss = loss.mean() # mean() to average on multi-gpu parallel training
342 | if args.gradient_accumulation_steps > 1:
343 | loss = loss / args.gradient_accumulation_steps
344 |
345 | if args.fp16:
346 | with amp.scale_loss(loss, optimizer) as scaled_loss:
347 | scaled_loss.backward()
348 | else:
349 | loss.backward()
350 |
351 | tr_loss += loss.item()
352 | if (step + 1) % args.gradient_accumulation_steps == 0:
353 | if args.fp16:
354 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
355 | else:
356 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
357 | optimizer.step()
358 | scheduler.step() # Update learning rate schedule
359 | model.zero_grad()
360 | global_step += 1
361 |
362 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
363 | # Log metrics
364 | if (
365 | args.local_rank == -1 and args.evaluate_during_training
366 | ): # Only evaluate when single GPU otherwise metrics may not average well
367 | results = evaluate(args, model, tokenizer)
368 | for key, value in results.items():
369 | tb_writer.add_scalar("eval_{}".format(key), value, global_step)
370 | tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
371 | tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
372 | logging_loss = tr_loss
373 |
374 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
375 | checkpoint_prefix = "checkpoint"
376 | # Save model checkpoint
377 | output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
378 | os.makedirs(output_dir, exist_ok=True)
379 | model_to_save = (
380 | model.module if hasattr(model, "module") else model
381 | ) # Take care of distributed/parallel training
382 | model_to_save.save_pretrained(output_dir)
383 | tokenizer.save_pretrained(output_dir)
384 |
385 | torch.save(args, os.path.join(output_dir, "training_args.bin"))
386 | logger.info("Saving model checkpoint to %s", output_dir)
387 |
388 | _rotate_checkpoints(args, checkpoint_prefix)
389 |
390 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
391 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
392 | logger.info("Saving optimizer and scheduler states to %s", output_dir)
393 |
394 | if args.max_steps > 0 and global_step > args.max_steps:
395 | epoch_iterator.close()
396 | break
397 | if args.max_steps > 0 and global_step > args.max_steps:
398 | train_iterator.close()
399 | break
400 |
401 | if args.local_rank in [-1, 0]:
402 | tb_writer.close()
403 |
404 | return global_step, tr_loss / global_step
405 |
406 |
407 | def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
408 | # Loop to handle MNLI double evaluation (matched, mis-matched)
409 | eval_output_dir = args.output_dir
410 |
411 | eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
412 |
413 | if args.local_rank in [-1, 0]:
414 | os.makedirs(eval_output_dir, exist_ok=True)
415 |
416 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
417 | # Note that DistributedSampler samples randomly
418 |
419 | def collate(examples: List[torch.Tensor]):
420 | if tokenizer._pad_token is None:
421 | return pad_sequence(examples, batch_first=True)
422 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
423 |
424 | eval_sampler = SequentialSampler(eval_dataset)
425 | eval_dataloader = DataLoader(
426 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
427 | )
428 |
429 | # multi-gpu evaluate
430 | if args.n_gpu > 1:
431 | model = torch.nn.DataParallel(model)
432 |
433 | # Eval!
434 | logger.info("***** Running evaluation {} *****".format(prefix))
435 | logger.info(" Num examples = %d", len(eval_dataset))
436 | logger.info(" Batch size = %d", args.eval_batch_size)
437 | eval_loss = 0.0
438 | nb_eval_steps = 0
439 | model.eval()
440 |
441 | for batch in tqdm(eval_dataloader, desc="Evaluating"):
442 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
443 | inputs = inputs.to(args.device)
444 | labels = labels.to(args.device)
445 |
446 | with torch.no_grad():
447 | outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
448 | lm_loss = outputs[0]
449 | eval_loss += lm_loss.mean().item()
450 | nb_eval_steps += 1
451 |
452 | eval_loss = eval_loss / nb_eval_steps
453 | perplexity = torch.exp(torch.tensor(eval_loss))
454 |
455 | result = {"perplexity": perplexity}
456 |
457 | output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
458 | with open(output_eval_file, "w") as writer:
459 | logger.info("***** Eval results {} *****".format(prefix))
460 | for key in sorted(result.keys()):
461 | logger.info(" %s = %s", key, str(result[key]))
462 | writer.write("%s = %s\n" % (key, str(result[key])))
463 |
464 | return result
465 |
466 |
467 | def main():
468 | parser = argparse.ArgumentParser()
469 |
470 | # Required parameters
471 | parser.add_argument(
472 | "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
473 | )
474 | parser.add_argument(
475 | "--output_dir",
476 | type=str,
477 | required=True,
478 | help="The output directory where the model predictions and checkpoints will be written.",
479 | )
480 | parser.add_argument(
481 | "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
482 | )
483 |
484 | # Other parameters
485 | parser.add_argument(
486 | "--eval_data_file",
487 | default=None,
488 | type=str,
489 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
490 | )
491 | parser.add_argument(
492 | "--line_by_line",
493 | action="store_true",
494 | help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
495 | )
496 | parser.add_argument(
497 | "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
498 | )
499 | parser.add_argument(
500 | "--model_name_or_path",
501 | default=None,
502 | type=str,
503 | help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
504 | )
505 |
506 | parser.add_argument(
507 | "--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling."
508 | )
509 | parser.add_argument(
510 | "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
511 | )
512 |
513 | parser.add_argument(
514 | "--config_name",
515 | default=None,
516 | type=str,
517 | help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
518 | )
519 | parser.add_argument(
520 | "--tokenizer_name",
521 | default=None,
522 | type=str,
523 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
524 | )
525 | parser.add_argument(
526 | "--cache_dir",
527 | default=None,
528 | type=str,
529 | help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
530 | )
531 | parser.add_argument(
532 | "--block_size",
533 | default=-1,
534 | type=int,
535 | help="Optional input sequence length after tokenization."
536 | "The training dataset will be truncated in block of this size for training."
537 | "Default to the model max input length for single sentence inputs (take into account special tokens).",
538 | )
539 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
540 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
541 | parser.add_argument(
542 | "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
543 | )
544 |
545 | parser.add_argument("--per_gpu_train_batch_size", default=1, type=int, help="Batch size per GPU/CPU for training.")
546 | parser.add_argument(
547 | "--per_gpu_eval_batch_size", default=1, type=int, help="Batch size per GPU/CPU for evaluation."
548 | )
549 | parser.add_argument(
550 | "--gradient_accumulation_steps",
551 | type=int,
552 | default=1,
553 | help="Number of updates steps to accumulate before performing a backward/update pass.",
554 | )
555 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
556 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
557 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
558 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
559 | parser.add_argument(
560 | "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
561 | )
562 | parser.add_argument(
563 | "--max_steps",
564 | default=-1,
565 | type=int,
566 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
567 | )
568 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
569 |
570 | parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
571 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
572 | parser.add_argument(
573 | "--save_total_limit",
574 | type=int,
575 | default=None,
576 | help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
577 | )
578 | parser.add_argument(
579 | "--eval_all_checkpoints",
580 | action="store_true",
581 | help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
582 | )
583 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
584 | parser.add_argument(
585 | "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
586 | )
587 | parser.add_argument(
588 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
589 | )
590 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
591 |
592 | parser.add_argument(
593 | "--fp16",
594 | action="store_true",
595 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
596 | )
597 | parser.add_argument(
598 | "--fp16_opt_level",
599 | type=str,
600 | default="O1",
601 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
602 | "See details at https://nvidia.github.io/apex/amp.html",
603 | )
604 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
605 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
606 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
607 | args = parser.parse_args()
608 |
609 | if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
610 | raise ValueError(
611 | "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
612 | "flag (masked language modeling)."
613 | )
614 | if args.eval_data_file is None and args.do_eval:
615 | raise ValueError(
616 | "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
617 | "or remove the --do_eval argument."
618 | )
619 | if args.should_continue:
620 | sorted_checkpoints = _sorted_checkpoints(args)
621 | if len(sorted_checkpoints) == 0:
622 | raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
623 | else:
624 | args.model_name_or_path = sorted_checkpoints[-1]
625 |
626 | if (
627 | os.path.exists(args.output_dir)
628 | and os.listdir(args.output_dir)
629 | and args.do_train
630 | and not args.overwrite_output_dir
631 | and not args.should_continue
632 | ):
633 | raise ValueError(
634 | "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
635 | args.output_dir
636 | )
637 | )
638 |
639 | # Setup distant debugging if needed
640 | if args.server_ip and args.server_port:
641 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
642 | import ptvsd
643 |
644 | print("Waiting for debugger attach")
645 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
646 | ptvsd.wait_for_attach()
647 |
648 | # Setup CUDA, GPU & distributed training
649 | if args.local_rank == -1 or args.no_cuda:
650 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
651 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
652 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
653 | torch.cuda.set_device(args.local_rank)
654 | device = torch.device("cuda", args.local_rank)
655 | torch.distributed.init_process_group(backend="nccl")
656 | args.n_gpu = 1
657 | args.device = device
658 |
659 | # Setup logging
660 | logging.basicConfig(
661 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
662 | datefmt="%m/%d/%Y %H:%M:%S",
663 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
664 | )
665 | logger.warning(
666 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
667 | args.local_rank,
668 | device,
669 | args.n_gpu,
670 | bool(args.local_rank != -1),
671 | args.fp16,
672 | )
673 |
674 | # Set seed
675 | set_seed(args)
676 |
677 | # Load pretrained model and tokenizer
678 | if args.local_rank not in [-1, 0]:
679 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
680 |
681 | if args.config_name:
682 | config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
683 | elif args.model_name_or_path:
684 | config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
685 | else:
686 | # When we release a pip version exposing CONFIG_MAPPING,
687 | # we can do `config = CONFIG_MAPPING[args.model_type]()`.
688 | raise ValueError(
689 | "You are instantiating a new config instance from scratch. This is not supported, but you can do it from another script, save it,"
690 | "and load it from here, using --config_name"
691 | )
692 |
693 | if args.tokenizer_name:
694 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
695 | elif args.model_name_or_path:
696 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
697 | else:
698 | raise ValueError(
699 | "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
700 | "and load it from here, using --tokenizer_name"
701 | )
702 |
703 | if args.block_size <= 0:
704 | args.block_size = 512#tokenizer.max_len
705 | # Our input block size will be the max possible for the model
706 | else:
707 | args.block_size = min(args.block_size, tokenizer.max_len)
708 | #print("Block_size", args.block_size)
709 | if args.model_name_or_path:
710 | model = AutoModelWithLMHead.from_pretrained(
711 | args.model_name_or_path,
712 | from_tf=bool(".ckpt" in args.model_name_or_path),
713 | config=config,
714 | cache_dir=args.cache_dir,
715 | )
716 | else:
717 | logger.info("Training new model from scratch")
718 | model = AutoModelWithLMHead.from_config(config)
719 |
720 | model.to(args.device)
721 |
722 | if args.local_rank == 0:
723 | torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
724 |
725 | logger.info("Training/evaluation parameters %s", args)
726 |
727 | # Training
728 | if args.do_train:
729 | if args.local_rank not in [-1, 0]:
730 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
731 |
732 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
733 |
734 | if args.local_rank == 0:
735 | torch.distributed.barrier()
736 |
737 | global_step, tr_loss = train(args, train_dataset, model, tokenizer)
738 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
739 |
740 | # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
741 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
742 | # Create output directory if needed
743 | if args.local_rank in [-1, 0]:
744 | os.makedirs(args.output_dir, exist_ok=True)
745 |
746 | logger.info("Saving model checkpoint to %s", args.output_dir)
747 | # Save a trained model, configuration and tokenizer using `save_pretrained()`.
748 | # They can then be reloaded using `from_pretrained()`
749 | model_to_save = (
750 | model.module if hasattr(model, "module") else model
751 | ) # Take care of distributed/parallel training
752 | model_to_save.save_pretrained(args.output_dir)
753 | tokenizer.save_pretrained(args.output_dir)
754 |
755 | # Good practice: save your training arguments together with the trained model
756 | torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
757 |
758 | # Load a trained model and vocabulary that you have fine-tuned
759 | model = AutoModelWithLMHead.from_pretrained(args.output_dir)
760 | tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
761 | model.to(args.device)
762 |
763 | # Evaluation
764 | results = {}
765 | if args.do_eval and args.local_rank in [-1, 0]:
766 | checkpoints = [args.output_dir]
767 | if args.eval_all_checkpoints:
768 | checkpoints = list(
769 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
770 | )
771 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
772 | logger.info("Evaluate the following checkpoints: %s", checkpoints)
773 | for checkpoint in checkpoints:
774 | global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
775 | prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
776 |
777 | model = AutoModelWithLMHead.from_pretrained(checkpoint)
778 | model.to(args.device)
779 | result = evaluate(args, model, tokenizer, prefix=prefix)
780 | result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
781 | results.update(result)
782 |
783 | return results
784 |
785 |
786 | if __name__ == "__main__":
787 | main()
788 |
--------------------------------------------------------------------------------