├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── data.sh
├── doc
│   ├── DialogRPT-EMNLP.pdf
│   ├── demo.PNG
│   ├── icon.png
│   ├── toy.tsv
│   ├── toy.tsv.ensemble.jsonl
│   └── toy.tsv.updown.jsonl
├── requirements.txt
├── restore
│   └── ensemble.yml
└── src
    ├── attic
    │   └── data.py
    ├── data_new.py
    ├── feeder.py
    ├── generation.py
    ├── main.py
    ├── master.py
    ├── model.py
    ├── score.py
    ├── shared.py
    └── transformers19
        ├── __init__.py
        ├── configuration_gpt2.py
        ├── configuration_utils.py
        ├── file_utils.py
        ├── modeling_gpt2.py
        ├── modeling_utils.py
        ├── tokenization_gpt2.py
        └── tokenization_utils.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | *.pyc
3 | *.pth
4 | out/
5 | data/
6 | .vscode/
7 | test/
8 | doc/arxiv/
9 | model_card/
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Sean Xiang Gao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 |
4 |
5 |
6 |
7 | # DialogRPT: Dialog Ranking Pretrained Transformers
8 |
9 | [DialogRPT](https://arxiv.org/abs/2009.06978/) predicts human feedback (upvotes👍 or replies💬) of dialogue responses.
10 |
11 | It is a set of dialog response ranking models proposed by the [Microsoft Research NLP Group](https://www.microsoft.com/en-us/research/group/natural-language-processing/), trained on more than 100 million human feedback examples, and accepted to appear at [EMNLP'20](https://2020.emnlp.org/).
12 | It can be used to improve existing dialog generation models (e.g., [DialoGPT](https://github.com/microsoft/DialoGPT)) by re-ranking the generated response candidates.
13 | This repo provides a PyTorch implementation and pretrained models.
14 |
15 | Quick links:
16 | * [Paper](https://arxiv.org/abs/2009.06978/)
17 | * [Intro Talk](https://slideslive.com/38938970/dialogue-response-ranking-training-with-largescale-human-feedback-data) and [Slides](https://github.com/golsun/DialogRPT/blob/master/doc/DialogRPT-EMNLP.pdf)
18 | * Demo: [original](https://colab.research.google.com/drive/1jQXzTYsgdZIQjJKrX4g3CP0_PGCeVU3C?usp=sharing) or [HuggingFace](https://colab.research.google.com/drive/1cAtfkbhqsRsT59y3imjR1APw3MHDMkuV?usp=sharing)
19 | * [Dataset](https://dialogfeedback.github.io/data.html)
20 |
21 | We considered the following tasks and provide the corresponding pretrained models.
22 | (Click 💾 to download the original PyTorch checkpoint for this repo, or click 🤗 to use the HuggingFace model card; a usage sketch follows the table below.)
23 |
24 |
25 | | Task | Description | Pretrained model |
26 | | :------------- | :----------- | :-----------: |
27 | | **Human feedback** | | |
28 | | `updown` | How likely is the response to get the most upvotes? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/updown.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-updown?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
29 | | `width` | How likely is the response to get the most direct replies? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/width.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-width?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
30 | | `depth` | How likely is the response to get the longest follow-up thread? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/depth.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-depth?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
31 | | **Human-like** (human vs fake) | | |
32 | | `human_vs_rand` | How relevant is the response to the given context? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/human_vs_rand.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-human-vs-rand?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
33 | | `human_vs_machine` | How likely is the response to be human-written rather than machine-generated? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/human_vs_machine.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-human-vs-machine?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
34 |
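For the 🤗 model cards, below is a minimal usage sketch with the HuggingFace `transformers` library. The `context <|endoftext|> response` input format is the same one used in the model-card links above; mapping the raw logit to a [0, 1] score with a sigmoid is an assumption consistent with the scores shown elsewhere in this README.
```python
# Minimal sketch: score a (context, response) pair with a DialogRPT model card.
# The sigmoid on the logit is an assumption, chosen to match the [0, 1] scores
# reported in this README.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "microsoft/DialogRPT-updown"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

def score(context, response):
    inputs = tokenizer.encode(context + " <|endoftext|> " + response, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs, return_dict=True).logits
    return torch.sigmoid(logits).item()

print(score("I love NLP!", "Me too!"))
```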
35 |
36 | ## Contents:
37 |
38 | * [Quick Start](#Quick-Start)
39 | * [Install](#Install), try [this demo](https://colab.research.google.com/drive/1jQXzTYsgdZIQjJKrX4g3CP0_PGCeVU3C?usp=sharing), or use the Hugging Face model card with this [demo](https://colab.research.google.com/drive/1cAtfkbhqsRsT59y3imjR1APw3MHDMkuV?usp=sharing)
40 | * [Use rankers only](#Use-rankers-only): use DialogRPT as an evaluation metric
41 | * [Use generator + ranker](#Use-generator-+-ranker): improve a generator by reranking hypotheses with DialogRPT
42 | * [Data](#Data)
43 | * [Training](#Training)
44 | * [Evaluation](#Evaluation)
45 | * [Human feedback prediction](#Human-feedback-prediction)
46 | * [Human-like classification](#Human-like-classification)
47 | * [Citation](#Citation)
48 |
49 |
50 |
51 |
52 | ## Quick Start
53 |
54 |
55 | ### Install
56 |
57 | **Option 1**: run locally
58 | ```
59 | git clone https://github.com/golsun/DialogRPT
60 | cd DialogRPT
61 | conda create -n dialogrpt python=3.6
62 | conda activate dialogrpt
63 | pip install -r requirements.txt
64 | ```
65 |
66 | **Option 2**: run on Colab Notebook. You can either use [Demo (original)](https://colab.research.google.com/drive/1jQXzTYsgdZIQjJKrX4g3CP0_PGCeVU3C?usp=sharing) or [Demo (HuggingFace)](https://colab.research.google.com/drive/1cAtfkbhqsRsT59y3imjR1APw3MHDMkuV?usp=sharing)
67 |
68 |
69 |
70 | ### Use rankers only
71 | In the following example, the model predicts that, given the same context *"I love NLP!"*, the response *"Here’s a free textbook (URL) in case anyone needs it."* is more likely to get upvotes than the response *"Me too!"*.
72 | ```bash
73 | python src/score.py play -p=restore/updown.pth
74 | #
75 | # Context: I love NLP!
76 | # Response: Here’s a free textbook (URL) in case anyone needs it.
77 | # score = 0.613
78 |
79 | # Context: I love NLP!
80 | # Response: Me too!
81 | # score = 0.111
82 | ```
83 | You can also play with the ensemble model, which combines multiple models as defined in its [config file](restore/ensemble.yml) (see that file for details).
84 | ```bash
85 | python src/main.py play -p=restore/ensemble.yml
86 | ```
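Per the comments in [restore/ensemble.yml](restore/ensemble.yml), the final ensemble score is the product of a `prior` score (weighted average of the `human_vs_rand` and `human_vs_machine` predictions) and a `cond` score (weighted average of the `updown`, `depth`, and `width` predictions). A minimal sketch, assuming each weighted average is normalized by the sum of its weights (this reproduces the numbers in [doc/toy.tsv.ensemble.jsonl](https://github.com/golsun/DialogRPT/blob/master/doc/toy.tsv.ensemble.jsonl)):
```python
# Sketch of the ensemble rule described in restore/ensemble.yml (see
# model.py/JointScorer for the real implementation). Normalizing each weighted
# average by the sum of its weights is an assumption inferred from the example
# outputs in doc/toy.tsv.ensemble.jsonl.
def ensemble_score(scores, prior_wt, cond_wt):
    weighted_avg = lambda wt: sum(w * scores[k] for k, w in wt.items()) / sum(wt.values())
    return weighted_avg(prior_wt) * weighted_avg(cond_wt)

scores = {"human_vs_rand": 0.876, "human_vs_machine": 0.687,
          "updown": 0.613, "depth": 0.497, "width": 0.368}
prior_wt = {"human_vs_rand": 0.5, "human_vs_machine": 0.5}
cond_wt = {"updown": 1.0, "depth": 0.48, "width": -0.5}
print(ensemble_score(scores, prior_wt, cond_wt))  # ~0.532, as in the toy example
```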
87 | To score a list of (context, response) pairs, please provide an input file (`--data`), which is tab-separated in the format `context \t response0 \t response1 ...`. See the example [input file](https://github.com/golsun/DialogRPT/blob/master/doc/toy.tsv). A sketch for reading the output files follows the examples below.
88 | * Using a single ranker (see [expected output](https://github.com/golsun/DialogRPT/blob/master/doc/toy.tsv.updown.jsonl))
89 | ```bash
90 | python src/score.py test --data=doc/toy.tsv -p=restore/updown.pth
91 | # downloading pretrained model to restore/updown.pth
92 | # 100% [....................] 1520029114 / 1520029114
93 | # loading from restore/updown.pth
94 | # ranking doc/toy.tsv
95 | # totally processed 2 line, avg_hyp_score 0.264, top_hyp_score 0.409
96 | # results saved to doc/toy.tsv.ranked.jsonl
97 | ```
98 | * Using an ensemble model (see [expected output](https://github.com/golsun/DialogRPT/blob/master/doc/toy.tsv.ensemble.jsonl))
99 | ```bash
100 | python src/score.py test --data=doc/toy.tsv -p=restore/ensemble.yml
101 | ```
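Each line of the output `.jsonl` file is a JSON object with the context (`cxt`) and the scored hypotheses (`hyps`), best first; the ensemble output additionally stores the per-model scores (see the expected-output files linked above). A minimal sketch for reading it:
```python
import json

# Print the top-scored hypothesis for each context in a DialogRPT output file.
# Field names follow doc/toy.tsv.updown.jsonl and doc/toy.tsv.ensemble.jsonl.
with open("doc/toy.tsv.updown.jsonl") as f:
    for line in f:
        d = json.loads(line)
        best = d["hyps"][0]             # hypotheses are listed best-first
        score, hyp = best[0], best[-1]  # ensemble files insert per-model scores in between
        print("%.3f  %s  ->  %s" % (score, d["cxt"], hyp))
```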
102 | Statistics of the scoring results can be shown with the following command, e.g. for `doc/toy.tsv.ensemble.jsonl`
103 | ```bash
104 | python src/score.py stats --data=doc/toy.tsv.ensemble.jsonl
105 | # |best |avg
106 | # ----------------------------------------
107 | # _score |0.339 |0.206
108 | # human_vs_rand |0.928 |0.861
109 | # human_vs_machine |0.575 |0.525
110 | # updown |0.409 |0.264
111 | # depth |0.304 |0.153
112 | # width |0.225 |0.114
113 | # final |0.339 |0.206
114 | # ----------------------------------------
115 | # n_cxt: 2
116 | # avg n_hyp per cxt: 2.50
117 | ```
118 |
119 |
120 | ### Use generator + ranker
121 | Dialog generation models can be improved by integrating them with the response ranking models.
122 | For example, given the context *"Can we restart 2020?"*, DialoGPT may return the following responses with sampling decoding (or you can try beam search without `--sampling`). Some of them, e.g., *"Yes, we can."*, have a high generation probability (`gen 0.496`) but are less interesting (`ranker 0.302`). So the rankers will rank them lower than responses that are more likely to be upvoted, e.g., *"I think we should go back to the beginning, and start from the beginning."*, which is relatively less likely to be generated (`gen 0.383`) but seems more interesting (`ranker 0.431`). A minimal re-ranking sketch follows the example output below.
123 | ```bash
124 | python src/generation.py play -pg=restore/medium_ft.pkl -pr=restore/updown.pth --sampling
125 | #
126 | # Context: Can we restart 2020?
127 | # 0.431 gen 0.383 ranker 0.431 I think we should go back to the beginning, and start from the beginning.
128 | # 0.429 gen 0.227 ranker 0.429 I think I'll just sit here and wait for 2020
129 | # 0.377 gen 0.249 ranker 0.377 Yeah, let's just start from the beginning
130 | # 0.323 gen 0.195 ranker 0.323 I think we should just give up and let the year just pass.
131 | # 0.304 gen 0.395 ranker 0.304 Yes. We can.
132 | # 0.302 gen 0.496 ranker 0.302 Yes, we can.
133 | # 0.283 gen 0.351 ranker 0.283 It's been a while since we've seen a good reboot.
134 | # 0.174 gen 0.306 ranker 0.174 I'm up for it
135 | # 0.168 gen 0.463 ranker 0.168 I'm down
136 | # 0.153 gen 0.328 ranker 0.153 I think so, yes.
137 | # ...
138 | ```
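The re-ranking step itself is simple: generate candidates, score each one with a ranker, and sort by the ranker score (as in the output above). A minimal sketch with hypothetical `generate_candidates` and `ranker_score` helpers (not this repo's API; `src/generation.py` implements the real pipeline):
```python
# Hypothetical helpers for illustration only; src/generation.py implements the
# real pipeline (DialoGPT generation + DialogRPT ranking).
def rerank(context, generate_candidates, ranker_score, n_hyp=10):
    hyps = generate_candidates(context, n_hyp)                 # [(gen_prob, hyp), ...]
    scored = [(ranker_score(context, h), p, h) for p, h in hyps]
    return sorted(scored, reverse=True)                        # best ranker score first
```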
139 | Similarly, you can use the [ensemble model](restore/ensemble.yml).
140 | ```
141 | python src/generation.py play -pg=restore/medium_ft.pkl -pr=restore/ensemble.yml
142 | ```
143 | To generate from a list of contexts stored in a line-separated file, provide it with `--path_test` and use the command below:
144 | ```
145 | python src/generation.py test --path_test=path/to/list/of/contexts -pg=restore/medium_ft.pkl -pr=restore/ensemble.yml
146 | ```
147 |
148 |
149 | ## Data
150 |
151 | As the Pushshift Reddit dataset was deleted from [this server](https://files.pushshift.io/reddit), the data extraction pipeline of this release no longer works. As an alternative, you may want to use the Pushshift [API](https://github.com/pushshift/api).
152 |
153 | ## Training
154 | We use [DialoGPT](https://github.com/microsoft/DialoGPT) to initialize the model. Please download with
155 | ```
156 | wget https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl -P restore
157 | ```
158 | For the human feedback prediction tasks, we specify `min_score_gap` and `min_rank_gap` to validate only on less-noisy samples (these filters are not applied to training). A sketch of the filter follows the commands below.
159 | ```
160 | python src/main.py train --data=data/out/updown -p=restore/medium_ft.pkl --min_score_gap=20 --min_rank_gap=0.5
161 | python src/main.py train --data=data/out/depth -p=restore/medium_ft.pkl --min_score_gap=4 --min_rank_gap=0.5
162 | python src/main.py train --data=data/out/width -p=restore/medium_ft.pkl --min_score_gap=4 --min_rank_gap=0.5
163 | ```
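Concretely, a (positive, negative) training pair is kept for validation only if the feedback gaps are large enough and the two replies were posted close enough in time; the sketch below mirrors the checks in `src/feeder.py`.
```python
# Sketch of the validation-pair filter (mirrors the checks in src/feeder.py).
def keep_pair(score_pos, score_neg, rank_pos, rank_neg, hr_gap,
              min_score_gap, min_rank_gap, max_hr_gap=1):
    if score_pos - score_neg < min_score_gap:    # e.g. >= 20 net upvotes apart for `updown`
        return False
    if rank_pos - rank_neg < min_rank_gap:       # normalized-rank gap, e.g. >= 0.5
        return False
    if max_hr_gap > 0 and hr_gap > max_hr_gap:   # replies posted too far apart in time
        return False
    return True
```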
164 | For the `human_vs_rand` task, use the `--mismatch` flag to feed random human responses as negative examples. We can reuse a previous dataset (e.g., `data/out/updown`).
165 | ```
166 | python src/main.py train --data=data/out/updown -p=restore/medium_ft.pkl --mismatch
167 | ```
168 | For the `human_vs_machine` task, we build the dataset by pairing each human response with a response generated by [DialoGPT](https://github.com/microsoft/DialoGPT) using top-k decoding.
169 | ```
170 | python src/main.py train --data=data/out/human_vs_machine -p=restore/medium_ft.pkl
171 | ```
172 |
173 | We trained all models on 4 Nvidia V100 GPUs (32 GB memory each) with the following hyperparameters; the checkpoint with the best validation accuracy is used as the final model. A sketch of the `max_seq_len` truncation rule follows the table.
174 | | Argument | Value | Description |
175 | | :------------- | :-----------: |:------------- |
176 | | `batch` | 256 | total batch size for all GPUs. |
177 | | `vali_size` | 1024 | number of samples used for validation (i.e. dev set size). |
178 | | `lr` | 3e-05 | learning rate |
179 | | `max_seq_len` | 50 | max allowed sequence length; if longer, leading tokens will be truncated |
180 | | `max_hr_gap` | 1 | max allowed hour difference between positive and negative samples; if larger, the pair is discarded for train/vali |
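The `max_seq_len` truncation keeps the two candidates comparable: if either sequence is too long, leading tokens are dropped from both candidates and from the context, so the pair still shares the same (truncated) context. A sketch mirroring `src/feeder.py`:
```python
# Sketch of the left-truncation rule (mirrors src/feeder.py): token ids of the
# context and of the two candidate sequences (context + reply) are cut from the
# left so that neither candidate exceeds max_seq_len.
def truncate_left(cxt, pos, neg, max_seq_len=50):
    n_del = max(len(pos), len(neg)) - max_seq_len
    if n_del > 0:
        pos, neg, cxt = pos[n_del:], neg[n_del:], cxt[n_del:]
    return cxt, pos, neg
```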
181 |
182 |
183 | ## Evaluation
184 |
185 | ### Human feedback prediction
186 |
187 | The performance on `updown`, `depth`, and `width` can be measured with the following commands, respectively.
188 | The `--min_score_gap` and `--min_rank_gap` arguments are consistent with the values used to measure validation loss during training.
189 | ```
190 | python src/score.py eval_human_feedback -p=restore/updown.pth --data=test/human_feedback/updown.tsv --min_score_gap=20 --min_rank_gap=0.5
191 | python src/score.py eval_human_feedback -p=restore/depth.pth --data=test/human_feedback/depth.tsv --min_score_gap=4 --min_rank_gap=0.5
192 | python src/score.py eval_human_feedback -p=restore/width.pth --data=test/human_feedback/width.tsv --min_score_gap=4 --min_rank_gap=0.5
193 | ```
194 |
195 | The expected pairwise accuracy on 5000 test samples is listed in the table below (from Table 5 of the [paper](https://arxiv.org/abs/2009.06978)). Note that even a random guess achieves an accuracy of 0.500. A sketch of the pairwise-accuracy metric follows the table.
196 | | human feedback | `updown` | `depth` | `width` |
197 | | :------------- | :------: |:------------: |:--------: |
198 | | Dialog ppl. | 0.488 | 0.508 | 0.513 |
199 | | Reverse dialog ppl. | 0.560 | 0.557 | 0.571 |
200 | | **DialogRPT** (ours)| **0.683** | **0.695** | **0.752** |
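Here pairwise accuracy simply counts how often the model scores the response with better human feedback above its counterpart. A minimal sketch (with a hypothetical `score_fn`; not the repo's own evaluation code in `src/score.py`):
```python
# Pairwise accuracy: a pair counts as correct when the feedback-preferred
# (positive) response gets a higher score than the negative one.
def pairwise_accuracy(pairs, score_fn):
    # pairs: list of (context, positive_reply, negative_reply)
    correct = sum(score_fn(cxt, pos) > score_fn(cxt, neg) for cxt, pos, neg in pairs)
    return correct / len(pairs)
```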
201 |
202 | ### Human-like classification
203 |
204 | * `human_vs_rand` task: Although the model is trained on the `reddit` corpus only, we measured its **zero-shot** performance on several unseen corpora (`twitter`, `dailydialog`, and `personachat`).
205 | ```bash
206 | python src/score.py eval_human_vs_rand -p=restore/human_vs_rand.pth --data=test/human_vs_fake/reddit
207 | python src/score.py eval_human_vs_rand -p=restore/human_vs_rand.pth --data=test/human_vs_fake/dailydialog
208 | python src/score.py eval_human_vs_rand -p=restore/human_vs_rand.pth --data=test/human_vs_fake/twitter
209 | python src/score.py eval_human_vs_rand -p=restore/human_vs_rand.pth --data=test/human_vs_fake/personachat
210 | ```
211 | The expected `hits@k` metric on 5000 test samples is listed in the table below (from Table 7 of the [paper](https://arxiv.org/abs/2009.06978)).
212 | `hits@k` measures, for the same context, given `k` positive responses and `n` negative responses, how many of the positive responses appear in the top-`k` of the ranked responses. A sketch of this metric follows the table.
213 | | `human_vs_rand` | `reddit` | `dailydialog` | `twitter` | `personachat` |
214 | | :------------- | :------: |:------------: |:--------: |:------------: |
215 | | BM25 | 0.309 | 0.182 | 0.178 | 0.117 |
216 | | Dialog ppl. | 0.560 | 0.176 | 0.107 | 0.108 |
217 | | Reverse dialog ppl. | 0.775 | 0.457 | 0.440 | 0.449 |
218 | | [ConveRT](https://arxiv.org/abs/1911.03688) | 0.760 | 0.380 | 0.439 | 0.197 |
219 | | **DialogRPT** (ours)| **0.886** | **0.621** | **0.548** | **0.479** |
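A minimal sketch of the `hits@k` computation described above (illustration only, not the repo's evaluation code):
```python
# hits@k for one context: rank all k positive and n negative responses by score,
# then report the fraction of positives that appear in the top-k.
def hits_at_k(scored):
    # scored: list of (score, is_positive) pairs for a single context
    k = sum(1 for _, is_pos in scored if is_pos)
    topk = sorted(scored, key=lambda x: x[0], reverse=True)[:k]
    return sum(1 for _, is_pos in topk if is_pos) / k
```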
220 |
221 | * `human_vs_machine` task: its performance is evaluated on the `reddit` corpus only.
222 | ```bash
223 | python src/score.py --task=eval_human_vs_machine -p=restore/human_vs_machine.pth --data=test/human_vs_fake/reddit
224 | # expecting accuracy ~0.98
225 | ```
226 |
227 |
228 | ## Citation
229 | If you use our dataset or model, please cite our [paper](https://arxiv.org/abs/2009.06978)
230 |
231 | ```
232 | @inproceedings{gao2020dialogrpt,
233 | title={Dialogue Response Ranking Training with Large-Scale Human Feedback Data},
234 | author={Xiang Gao and Yizhe Zhang and Michel Galley and Chris Brockett and Bill Dolan},
235 | year={2020},
236 | booktitle={EMNLP}
237 | }
238 | ```
239 |
--------------------------------------------------------------------------------
/data.sh:
--------------------------------------------------------------------------------
1 | # step 0. create the data folder
2 |
3 | mkdir -p "data/bz2"
4 |
5 | # Step 1. Download raw data from a third party dump: https://files.pushshift.io/reddit
6 |
7 | # download comments for year 2011
8 | wget https://files.pushshift.io/reddit/comments/RC_2011-01.bz2 -P data/bz2
9 | wget https://files.pushshift.io/reddit/comments/RC_2011-02.bz2 -P data/bz2
10 | wget https://files.pushshift.io/reddit/comments/RC_2011-03.bz2 -P data/bz2
11 | wget https://files.pushshift.io/reddit/comments/RC_2011-04.bz2 -P data/bz2
12 | wget https://files.pushshift.io/reddit/comments/RC_2011-05.bz2 -P data/bz2
13 | wget https://files.pushshift.io/reddit/comments/RC_2011-06.bz2 -P data/bz2
14 | wget https://files.pushshift.io/reddit/comments/RC_2011-07.bz2 -P data/bz2
15 | wget https://files.pushshift.io/reddit/comments/RC_2011-08.bz2 -P data/bz2
16 | wget https://files.pushshift.io/reddit/comments/RC_2011-09.bz2 -P data/bz2
17 | wget https://files.pushshift.io/reddit/comments/RC_2011-10.bz2 -P data/bz2
18 | wget https://files.pushshift.io/reddit/comments/RC_2011-11.bz2 -P data/bz2
19 | wget https://files.pushshift.io/reddit/comments/RC_2011-12.bz2 -P data/bz2
20 |
21 | # download comments for year 2012
22 | wget https://files.pushshift.io/reddit/comments/RC_2012-01.bz2 -P data/bz2
23 | wget https://files.pushshift.io/reddit/comments/RC_2012-02.bz2 -P data/bz2
24 | wget https://files.pushshift.io/reddit/comments/RC_2012-03.bz2 -P data/bz2
25 | wget https://files.pushshift.io/reddit/comments/RC_2012-04.bz2 -P data/bz2
26 | wget https://files.pushshift.io/reddit/comments/RC_2012-05.bz2 -P data/bz2
27 | wget https://files.pushshift.io/reddit/comments/RC_2012-06.bz2 -P data/bz2
28 | wget https://files.pushshift.io/reddit/comments/RC_2012-07.bz2 -P data/bz2
29 | wget https://files.pushshift.io/reddit/comments/RC_2012-08.bz2 -P data/bz2
30 | wget https://files.pushshift.io/reddit/comments/RC_2012-09.bz2 -P data/bz2
31 | wget https://files.pushshift.io/reddit/comments/RC_2012-10.bz2 -P data/bz2
32 | wget https://files.pushshift.io/reddit/comments/RC_2012-11.bz2 -P data/bz2
33 | wget https://files.pushshift.io/reddit/comments/RC_2012-12.bz2 -P data/bz2
34 |
35 | # download submissions for year 2011
36 | wget https://files.pushshift.io/reddit/submissions/RS_2011-01.bz2 -P data/bz2
37 | wget https://files.pushshift.io/reddit/submissions/RS_2011-02.bz2 -P data/bz2
38 | wget https://files.pushshift.io/reddit/submissions/RS_2011-03.bz2 -P data/bz2
39 | wget https://files.pushshift.io/reddit/submissions/RS_2011-04.bz2 -P data/bz2
40 | wget https://files.pushshift.io/reddit/submissions/RS_2011-05.bz2 -P data/bz2
41 | wget https://files.pushshift.io/reddit/submissions/RS_2011-06.bz2 -P data/bz2
42 | wget https://files.pushshift.io/reddit/submissions/RS_2011-07.bz2 -P data/bz2
43 | wget https://files.pushshift.io/reddit/submissions/RS_2011-08.bz2 -P data/bz2
44 | wget https://files.pushshift.io/reddit/submissions/RS_2011-09.bz2 -P data/bz2
45 | wget https://files.pushshift.io/reddit/submissions/RS_2011-10.bz2 -P data/bz2
46 | wget https://files.pushshift.io/reddit/submissions/RS_2011-11.bz2 -P data/bz2
47 | wget https://files.pushshift.io/reddit/submissions/RS_2011-12.bz2 -P data/bz2
48 |
49 | # download submissions for year 2012
50 | wget https://files.pushshift.io/reddit/submissions/RS_2012-01.bz2 -P data/bz2
51 | wget https://files.pushshift.io/reddit/submissions/RS_2012-02.bz2 -P data/bz2
52 | wget https://files.pushshift.io/reddit/submissions/RS_2012-03.bz2 -P data/bz2
53 | wget https://files.pushshift.io/reddit/submissions/RS_2012-04.bz2 -P data/bz2
54 | wget https://files.pushshift.io/reddit/submissions/RS_2012-05.bz2 -P data/bz2
55 | wget https://files.pushshift.io/reddit/submissions/RS_2012-06.bz2 -P data/bz2
56 | wget https://files.pushshift.io/reddit/submissions/RS_2012-07.bz2 -P data/bz2
57 | wget https://files.pushshift.io/reddit/submissions/RS_2012-08.bz2 -P data/bz2
58 | wget https://files.pushshift.io/reddit/submissions/RS_2012-09.bz2 -P data/bz2
59 | wget https://files.pushshift.io/reddit/submissions/RS_2012-10.bz2 -P data/bz2
60 | wget https://files.pushshift.io/reddit/submissions/RS_2012-11.bz2 -P data/bz2
61 | wget https://files.pushshift.io/reddit/submissions/RS_2012-12.bz2 -P data/bz2
62 |
63 | # Step 2. Read the `.bz2` files and group items from the same subreddit
64 |
65 | python src/data.py bz2 2011
66 | python src/data.py bz2 2012
67 |
68 | # Step 3. extract basic attributes and dialog trees.
69 |
70 | python src/data.py basic 2011
71 | python src/data.py basic 2012
72 |
73 | # Step 4. Build training and testing data for different feedback signals.
74 |
75 | python src/data.py updown 2011 --year_to=2012
76 | python src/data.py depth 2011 --year_to=2012
77 | python src/data.py width 2011 --year_to=2012
--------------------------------------------------------------------------------
/doc/DialogRPT-EMNLP.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/golsun/DialogRPT/ac9954a784cc88071e5fa309c2afdace2e9b38d7/doc/DialogRPT-EMNLP.pdf
--------------------------------------------------------------------------------
/doc/demo.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/golsun/DialogRPT/ac9954a784cc88071e5fa309c2afdace2e9b38d7/doc/demo.PNG
--------------------------------------------------------------------------------
/doc/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/golsun/DialogRPT/ac9954a784cc88071e5fa309c2afdace2e9b38d7/doc/icon.png
--------------------------------------------------------------------------------
/doc/toy.tsv:
--------------------------------------------------------------------------------
1 | I love NLP! Here’s a free textbook (URL) in case anyone needs it. Me too! It’s super useful and more and more powerful!
2 | How are you? Not bad. pretty good! :)
--------------------------------------------------------------------------------
/doc/toy.tsv.ensemble.jsonl:
--------------------------------------------------------------------------------
1 | {"line_id": 0, "cxt": "I love NLP!", "hyps": [[0.5319748520851135, {"human_vs_rand": 0.875923216342926, "human_vs_machine": 0.6866590976715088, "updown": 0.6128572225570679, "depth": 0.4968094825744629, "width": 0.3681032657623291, "final": 0.5319748520851135}, "Here\u2019s a free textbook (URL) in case anyone needs it."], [0.22798165678977966, {"human_vs_rand": 0.7438808679580688, "human_vs_machine": 0.6686411499977112, "updown": 0.29050877690315247, "depth": 0.10862986743450165, "width": 0.052612606436014175, "final": 0.22798165678977966}, "It\u2019s super useful and more and more powerful!"], [0.0703396126627922, {"human_vs_rand": 0.630566418170929, "human_vs_machine": 0.6046908497810364, "updown": 0.11055243015289307, "depth": 0.031935129314661026, "width": 0.028544878587126732, "final": 0.0703396126627922}, "Me too!"]]}
2 | {"line_id": 1, "cxt": "How are you?", "hyps": [[0.14587044715881348, {"human_vs_rand": 0.9802649021148682, "human_vs_machine": 0.3306310772895813, "updown": 0.20513156056404114, "depth": 0.111829474568367, "width": 0.08141987770795822, "final": 0.14587044715881348}, "pretty good! :)"], [0.12652477622032166, {"human_vs_rand": 0.962352991104126, "human_vs_machine": 0.462787926197052, "updown": 0.17496052384376526, "depth": 0.0771029144525528, "width": 0.07592014223337173, "final": 0.12652477622032166}, "Not bad."]]}
--------------------------------------------------------------------------------
/doc/toy.tsv.updown.jsonl:
--------------------------------------------------------------------------------
1 | {"line_id": 0, "cxt": "I love NLP!", "hyps": [[0.6128574013710022, "Here\u2019s a free textbook (URL) in case anyone needs it."], [0.2905086874961853, "It\u2019s super useful and more and more powerful!"], [0.11055240780115128, "Me too!"]]}
2 | {"line_id": 1, "cxt": "How are you?", "hyps": [[0.20513157546520233, "pretty good! :)"], [0.17496058344841003, "Not bad."]]}
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3==1.14.59
2 | botocore==1.17.59
3 | certifi==2020.6.20
4 | chardet==3.0.4
5 | cycler==0.10.0
6 | docker==3.2.1
7 | docutils==0.15.2
8 | future==0.18.2
9 | idna==2.10
10 | iterable-queue==1.2.2
11 | jmespath==0.10.0
12 | kiwisolver==1.2.0
13 | matplotlib==3.3.1
14 | ntlm-auth==1.2.0
15 | numpy==1.19.2
16 | Pillow==7.2.0
17 | pyparsing==2.4.7
18 | python-dateutil==2.8.1
19 | PyYAML==5.3.1
20 | regex==2020.7.14
21 | requests==2.24.0
22 | requests-ntlm==1.1.0
23 | s3transfer==0.3.3
24 | scipy==1.5.2
25 | six==1.15.0
26 | smmap2==2.0.3
27 | torch==1.6.0
28 | tqdm==4.48.2
29 | urllib3==1.25.10
30 | wincertstore==0.2
31 |
--------------------------------------------------------------------------------
/restore/ensemble.yml:
--------------------------------------------------------------------------------
1 | # see model.py/JointScorer.core for details
2 | # the `prior` score is the weighted average of `human_vs_rand` and `human_vs_machine` predictions,
3 | # and `cond` is the weighted average of the `updown`, `depth` and `width` predictions.
4 | # The final score is the product of `prior` score and `cond` score
5 |
6 | prior:
7 |
8 | - name: human_vs_rand
9 | wt: 0.5
10 | path: restore/human_vs_rand.pth
11 |
12 | - name: human_vs_machine
13 | wt: 0.5
14 | path: restore/human_vs_machine.pth
15 |
16 | cond:
17 |
18 | - name: updown
19 | wt: 1
20 | path: restore/updown.pth
21 |
22 | - name: depth
23 | wt: 0.48
24 | path: restore/depth.pth
25 |
26 | - name: width
27 | wt: -0.5
28 | path: restore/width.pth
--------------------------------------------------------------------------------
/src/attic/data.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import bz2, json, os, pickle, pdb, time, random
5 | import urllib.request
6 | import numpy as np
7 | from shared import _cat_
8 |
9 |
10 | def valid_sub(sub):
11 | if sub.upper() in [
12 | 'CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7',
13 | 'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9']:
14 | # not allowed by Windows system
15 | return False
16 | if ':' in sub:
17 | return False
18 | return True
19 |
20 |
21 | def get_dates(year_from, year_to=None):
22 | if year_to is None:
23 | year_to = year_from
24 | dates = []
25 | for year in range(year_from, year_to + 1):
26 | for _mo in range(1, 12 + 1):
27 | mo = str(_mo)
28 | if len(mo) == 1:
29 | mo = '0' + mo
30 | dates.append(str(year) + '-' + mo)
31 | return dates
32 |
33 |
34 | def extract_rc(date):
35 | path_bz2 = '%s/RC_%s.bz2'%(fld_bz2, date)
36 | nodes = dict()
37 | edges = dict()
38 | subs = set()
39 | n = 0
40 | m = 0
41 | kk = ['body', 'link_id', 'name', 'parent_id', 'subreddit']
42 |
43 | def save(nodes, edges):
44 | for sub in nodes:
45 | fld = fld_jsonl + '/' + sub
46 | try:
47 | os.makedirs(fld, exist_ok=True)
48 | except NotADirectoryError as e:
49 | print(e)
50 | continue
51 | if sub not in subs:
52 | open(fld + '/%s_nodes.jsonl'%date, 'w', encoding="utf-8")
53 | open(fld + '/%s_edges.tsv'%date, 'w', encoding="utf-8")
54 | subs.add(sub)
55 | with open(fld + '/%s_nodes.jsonl'%date, 'a', encoding="utf-8") as f:
56 | f.write('\n'.join(nodes[sub]) + '\n')
57 | with open(fld + '/%s_edges.tsv'%date, 'a', encoding="utf-8") as f:
58 | f.write('\n'.join(edges[sub]) + '\n')
59 |
60 | for line in bz2.open(path_bz2, 'rt', encoding="utf-8"):
61 | n += 1
62 | line = line.strip('\n')
63 | try:
64 | node = json.loads(line)
65 | except Exception:
66 | continue
67 |
68 | ok = True
69 | for k in kk:
70 | if k not in node:
71 | ok = False
72 | break
73 | if not ok:
74 | continue  # skip this comment if a required field is missing
75 |
76 | if not valid_sub(node['subreddit']):
77 | continue
78 |
79 | if node['subreddit'] not in nodes:
80 | nodes[node['subreddit']] = []
81 | edges[node['subreddit']] = []
82 | nodes[node['subreddit']].append(line)
83 | edges[node['subreddit']].append('%s\t%s\t%s'%(node['link_id'], node['parent_id'], node['name']))
84 |
85 | m += 1
86 | if m % 1e5 == 0:
87 | save(nodes, edges)
88 | print('[RC_%s] saved %.2f/%.2f M, %i subreddits'%(date, m/1e6, n/1e6, len(subs)))
89 | nodes = dict()
90 | edges = dict()
91 |
92 | save(nodes, edges)
93 | print('[RC_%s] FINAL %.2f/%.2f M, %i subreddits ================'%(date, m/1e6, n/1e6, len(subs)))
94 | with open(fld_jsonl + '/readme.txt', 'a', encoding='utf-8') as f:
95 | f.write('[%s] saved %i/%i\n'%(date, m, n))
96 |
97 |
98 | def extract_rs(date):
99 | path_bz2 = '%s/RS_%s.bz2'%(fld_bz2, date)
100 | roots = dict()
101 | subs = set()
102 | n = 0
103 | m = 0
104 | kk = ['selftext', 'id', 'title', 'subreddit']
105 |
106 | def save(roots):
107 | for sub in roots:
108 | fld = fld_jsonl + '/' + sub
109 | try:
110 | os.makedirs(fld, exist_ok=True)
111 | except NotADirectoryError as e:
112 | print(e)
113 | continue
114 | if sub not in subs:
115 | open(fld + '/%s_roots.jsonl'%date, 'w', encoding="utf-8")
116 | subs.add(sub)
117 | with open(fld + '/%s_roots.jsonl'%date, 'a', encoding="utf-8") as f:
118 | f.write('\n'.join(roots[sub]) + '\n')
119 |
120 | for line in bz2.open(path_bz2, 'rt', encoding="utf-8"):
121 | n += 1
122 | line = line.strip('\n')
123 | try:
124 | root = json.loads(line)
125 | except Exception:
126 | continue
127 |
128 | ok = True
129 | for k in kk:
130 | if k not in root:
131 | ok = False
132 | break
133 | if not ok:
134 | continue  # skip this submission if a required field is missing
135 | if not valid_sub(root['subreddit']):
136 | continue
137 |
138 | # some bz2, e.g. 2012-09, doesn't have the `name` entry
139 | if 'name' not in root:
140 | root['name'] = 't3_' + root['id']
141 |
142 | if root['subreddit'] not in roots:
143 | roots[root['subreddit']] = []
144 | roots[root['subreddit']].append(line)
145 |
146 | m += 1
147 | if m % 1e4 == 0:
148 | save(roots)
149 | print('[RS_%s] saved %.2f/%.2f M, %i subreddits'%(date, m/1e6, n/1e6, len(subs)))
150 | roots = dict()
151 |
152 | save(roots)
153 | print('[RS_%s] FINAL %.2f/%.2f M, %i subreddits ================'%(
154 | date, m/1e6, n/1e6, len(subs)))
155 | with open(fld_jsonl + '/readme_roots.txt', 'a', encoding='utf-8') as f:
156 | f.write('[%s] saved %i/%i\n'%(date, m, n))
157 |
158 |
159 |
160 | def extract_txt(sub, year, tokenizer, overwrite=False, max_subword=3):
161 | fld = '%s/%s'%(fld_subs, sub)
162 | os.makedirs(fld, exist_ok=True)
163 | path_out = '%s/%i_txt.tsv'%(fld, year)
164 | path_done = path_out + '.done'
165 | if not overwrite and os.path.exists(path_done):
166 | return
167 |
168 | dates = get_dates(year)
169 | open(path_out, 'w', encoding='utf-8')
170 |
171 | def clean(txt):
172 | if txt.strip() in ['[deleted]', '[removed]']:
173 | return None
174 | if '>' in txt or '&gt;' in txt: # no quoted comment in line ('&gt;' means '>')
175 | return None
176 |
177 | # deal with URL
178 | txt = txt.replace('](','] (')
179 | ww = []
180 | for w in txt.split():
181 | if len(w) == 0:
182 | continue
183 | if '://' in w.lower() or 'http' in w.lower():
184 | ww.append('(URL)')
185 | else:
186 | ww.append(w)
187 | if not ww:
188 | return None
189 | if len(ww) > 30: # focus on dialog, so ignore long txt
190 | return None
191 | if len(ww) < 1:
192 | return None
193 | txt = ' '.join(ww)
194 | for c in ['\t', '\n', '\r']: # delimiter or newline
195 | txt = txt.replace(c, ' ')
196 |
197 | ids = tokenizer.encode(txt)
198 | if len(ids) / len(ww) > max_subword: # usually < 1.5. too large means too many unknown words
199 | return None
200 |
201 | ids = ' '.join([str(x) for x in ids])
202 | return txt, ids
203 |
204 | lines = []
205 | m = 0
206 | n = 0
207 | name_set = set()
208 | for date in dates:
209 | path = '%s/%s/%s_nodes.jsonl'%(fld_jsonl, sub, date)
210 | if not os.path.exists(path):
211 | continue
212 | for line in open(path, encoding='utf-8'):
213 | n += 1
214 | d = json.loads(line.strip('\n'))
215 | if d['name'] in name_set:
216 | continue
217 | name_set.add(d['name'])
218 | txt_ids = clean(d['body'])
219 | if txt_ids is not None:
220 | txt, ids = txt_ids
221 | lines.append('%s\t%s\t%s'%(d['name'], txt, ids))
222 | m += 1
223 | if m % 1e4 == 0:
224 | with open(path_out, 'a', encoding='utf-8') as f:
225 | f.write('\n'.join(lines) + '\n')
226 | lines = []
227 |
228 | for date in dates:
229 | path = '%s/%s/%s_roots.jsonl'%(fld_jsonl, sub, date)
230 | if not os.path.exists(path):
231 | continue
232 | for line in open(path, encoding='utf-8'):
233 | n += 1
234 | d = json.loads(line.strip('\n'))
235 | if 'name' not in d:
236 | d['name'] = 't3_' + d['id']
237 | if d['name'] in name_set:
238 | continue
239 | name_set.add(d['name'])
240 | txt_ids = clean(d['title'] + ' ' + d['selftext'])
241 | if txt_ids is not None:
242 | txt, ids = txt_ids
243 | lines.append('%s\t%s\t%s'%(d['name'], txt, ids))
244 | m += 1
245 | if m % 1e4 == 0:
246 | with open(path_out, 'a', encoding='utf-8') as f:
247 | f.write('\n'.join(lines) + '\n')
248 | lines = []
249 | if lines:
250 | with open(path_out, 'a', encoding='utf-8') as f:
251 | f.write('\n'.join(lines))
252 |
253 | s = '[%s %s] txt kept %i/%i'%(sub, year, m, n)
254 | with open(path_done, 'w') as f:
255 | f.write(s)
256 | print(s)
257 |
258 |
259 | def extract_trees(sub, year):
260 | fld = '%s/%s'%(fld_subs, sub)
261 | os.makedirs(fld, exist_ok=True)
262 | path_out = '%s/%i_trees.pkl'%(fld, year)
263 | if os.path.exists(path_out):
264 | return
265 |
266 | trees = dict()
267 | n = 0
268 | for date in get_dates(year):
269 | path = '%s/%s/%s_edges.tsv'%(fld_jsonl, sub, date)
270 | if not os.path.exists(path):
271 | #print('no such file: '+path)
272 | continue
273 | for line in open(path, encoding='utf-8'):
274 | n += 1
275 | link, parent, child = line.strip('\n').split('\t')
276 | if link not in trees:
277 | trees[link] = dict()
278 | trees[link][(parent, child)] = date
279 |
280 | if not trees:
281 | return
282 |
283 | print('[%s %i] %i trees %.1f nodes/tree'%(sub, year, len(trees), n/len(trees)))
284 | os.makedirs(fld, exist_ok=True)
285 | pickle.dump(trees, open(path_out, 'wb'))
286 |
287 |
288 | def extract_time(sub, year, overwrite=False):
289 | fld = '%s/%s'%(fld_subs, sub)
290 | os.makedirs(fld, exist_ok=True)
291 | path_out = '%s/%i_time.tsv'%(fld, year)
292 | path_done = path_out + '.done'
293 | if not overwrite and os.path.exists(path_done):
294 | return
295 | dates = get_dates(year)
296 | suffix = 'nodes'
297 | os.makedirs(fld, exist_ok=True)
298 | open(path_out, 'w', encoding='utf-8')
299 |
300 | lines = []
301 | m = 0
302 | n = 0
303 | name_set = set()
304 | for date in dates:
305 | path = '%s/%s/%s_%s.jsonl'%(fld_jsonl, sub, date, suffix)
306 | if not os.path.exists(path):
307 | continue
308 | for line in open(path, encoding='utf-8'):
309 | n += 1
310 | d = json.loads(line.strip('\n'))
311 | if 'name' not in d:
312 | d['name'] = 't3_' + d['id']
313 | if d['name'] in name_set:
314 | continue
315 | name_set.add(d['name'])
316 | t = d['created_utc']
317 | lines.append('%s\t%s'%(d['name'], t))
318 | m += 1
319 | if m % 1e4 == 0:
320 | with open(path_out, 'a', encoding='utf-8') as f:
321 | f.write('\n'.join(lines) + '\n')
322 | lines = []
323 | with open(path_out, 'a', encoding='utf-8') as f:
324 | f.write('\n'.join(lines))
325 |
326 | s = '[%s %s] time kept %i/%i'%(sub, year, m, n)
327 | with open(path_done, 'w') as f:
328 | f.write(s)
329 | print(s)
330 |
331 |
332 |
333 | def calc_feedback(sub, year, overwrite=False):
334 | fld = '%s/%s'%(fld_subs, sub)
335 | path_out = '%s/%i_feedback.tsv'%(fld, year)
336 | path_done = path_out + '.done'
337 | if not overwrite and os.path.exists(path_done):
338 | return
339 |
340 | path_pkl = '%s/%i_trees.pkl'%(fld, year)
341 | if not os.path.exists(path_pkl):
342 | return
343 | trees = pickle.load(open(path_pkl, 'rb'))
344 | if not trees:
345 | return
346 |
347 | dates = get_dates(year)
348 | updown = dict()
349 | for date in dates:
350 | path = '%s/%s/%s_nodes.jsonl'%(fld_jsonl, sub, date)
351 | if not os.path.exists(path):
352 | continue
353 | for line in open(path, encoding='utf-8'):
354 | d = json.loads(line.strip('\n'))
355 | updown[d['name']] = d['ups'] - d['downs']
356 |
357 | if not updown:
358 | print('empty updown:')
359 | return
360 |
361 | with open(path_out, 'w', encoding='utf-8') as f:
362 | f.write('\t'.join(['#path', 'vol', 'width', 'depth', 'updown']) + '\n')
363 |
364 | print('[%s %s] calculating scores for %i trees'%(sub, year, len(trees)))
365 |
366 | n_tree = 0
367 | n_node = 0
368 | for root in trees:
369 | tree = trees[root]
370 | children = dict()
371 | for parent, child in tree:
372 | if parent not in children:
373 | children[parent] = []
374 | children[parent].append(child)
375 | if root not in children:
376 | continue
377 |
378 | # BFS to get all paths from root to leaf
379 | q = [[root]]
380 | paths = []
381 | while q:
382 | qsize = len(q)
383 | for _ in range(qsize):
384 | path = q.pop(0)
385 | head = path[-1]
386 | if head not in children: # then head is a leaf
387 | paths.append(path)
388 | continue
389 | for child in children[head]:
390 | q.append(path + [child])
391 |
392 | prev = dict()
393 | for path in paths:
394 | for i in range(1, len(path)):
395 | prev[path[i]] = ' '.join(path[:i + 1])
396 |
397 | descendant = dict()
398 | longest_subpath = dict()
399 | while paths:
400 | path = paths.pop(0)
401 | node = path[0]
402 | if node not in descendant:
403 | descendant[node] = set()
404 | longest_subpath[node] = 0
405 | descendant[node] |= set(path[1:])
406 | longest_subpath[node] = max(longest_subpath[node], len(path) - 1)
407 | if len(path) > 1:
408 | paths.append(path[1:])
409 |
410 | sorted_nodes = sorted([(len(prev[node].split()), prev[node], node) for node in prev])
411 | if not sorted_nodes:
412 | continue
413 |
414 | n_tree += 1
415 | lines = []
416 | for _, _, node in sorted_nodes:
417 | if node == root:
418 | continue
419 | if node not in updown:
420 | continue
421 | n_node += 1
422 | lines.append('%s\t%i\t%i\t%i\t%i'%(
423 | prev[node], # turns: path from its root to this node
424 | len(descendant[node]), # vol: num of descendants of this node
425 | len(children.get(node, [])), # width: num of direct children of this node
426 | longest_subpath[node], # depth: num of longest subpath of this node
427 | updown[node], # updown: `upvotes - downvotes` of this node
428 | ))
429 | with open(path_out, 'a', encoding='utf-8') as f:
430 | f.write('\n'.join(lines) + '\n')
431 |
432 | if n_tree:
433 | s = '[%s %s] %i tree %i nodes'%(sub, year, n_tree, n_node)
434 | else:
435 | s = '[%s %s] trees are empty'%(sub, year)
436 | with open(path_done, 'w') as f:
437 | f.write(s)
438 | print(s)
439 |
440 |
441 | def create_pairs(year, sub, feedback, overwrite=False):
442 | fld = '%s/%s'%(fld_subs, sub)
443 | path_out = '%s/%i_%s.tsv'%(fld, year, feedback)
444 | path_done = path_out + '.done'
445 | if not overwrite and os.path.exists(path_done):
446 | return
447 |
448 | ix_feedback = ['vol', 'width', 'depth', 'updown'].index(feedback) + 1
449 | path_in = '%s/%i_feedback.tsv'%(fld, year)
450 | if not os.path.exists(path_in):
451 | return
452 |
453 | time = dict()
454 | path_time = '%s/%i_time.tsv'%(fld, year)
455 | if not os.path.exists(path_time):
456 | return
457 | for line in open(path_time):
458 | ss = line.strip('\n').split('\t')
459 | if len(ss) == 2:
460 | name, t = ss
461 | time[name] = int(t)
462 |
463 | open(path_out, 'w', encoding='utf-8')
464 | print('[%s %s] creating pairs...'%(sub, year))
465 |
466 | def match_time(replies, cxt):
467 | scores = sorted(set([score for score, _ in replies]))
468 | m = len(scores)
469 | if m < 2:
470 | return 0 # can't create pairs if m < 2
471 | cand = []
472 | for score, reply in replies:
473 | if reply not in time:
474 | continue
475 | cand.append((time[reply], score, reply))
476 | cand = sorted(cand)
477 | rank = [scores.index(score) / (m - 1) for _, score, _ in cand]
478 | lines = []
479 | for i in range(len(cand) - 1):
480 | t_a, score_a, a = cand[i]
481 | t_b, score_b, b = cand[i + 1]
482 | rank_a = rank[i]
483 | rank_b = rank[i + 1]
484 | if score_a == score_b:
485 | continue
486 | hr = (t_b - t_a)/3600
487 | if score_b > score_a:
488 | score_a, score_b = score_b, score_a
489 | a, b = b, a
490 | rank_a, rank_b = rank_b, rank_a
491 | lines.append('\t'.join([
492 | cxt,
493 | a,
494 | b,
495 | '%.2f'%hr,
496 | '%i'%score_a,
497 | '%i'%score_b,
498 | '%.4f'%rank_a,
499 | '%.4f'%rank_b,
500 | ]))
501 | #pdb.set_trace()
502 | if lines:
503 | with open(path_out, 'a') as f:
504 | f.write('\n'.join(lines) + '\n')
505 | return len(lines)
506 |
507 | n_line = 0
508 | prev = None
509 | replies = []
510 | for line in open(path_in):
511 | if line.startswith('#'):
512 | continue
513 | ss = line.strip('\n').split('\t')
514 | turns = ss[0].split() # including both cxt and resp
515 | if len(turns) < 2:
516 | continue
517 | reply = turns[-1]
518 | try:
519 | score = int(ss[ix_feedback])
520 | except ValueError:
521 | continue
522 | parent = turns[-2]
523 | if parent == prev:
524 | replies.append((score, reply))
525 | else:
526 | if replies:
527 | n_line += match_time(replies, cxt)
528 | cxt = ' '.join(turns[:-1])
529 | prev = parent
530 | replies = [(score, reply)]
531 | if replies:
532 | n_line += match_time(replies, cxt)
533 |
534 | s = '[%s %s %s] %i pairs'%(sub, year, feedback, n_line)
535 | with open(path_done, 'w') as f:
536 | f.write(s)
537 | print(s)
538 |
539 |
540 | def add_seq(sub, year, feedback, overwrite=False):
541 | fname = '%i_%s'%(year, feedback)
542 | fld = '%s/%s'%(fld_subs, sub)
543 | turn_sep = ' 50256 '
544 | path_out = fld + '/%s_ids.tsv'%fname
545 | path_done = path_out + '.done'
546 |
547 | if os.path.exists(path_done) and not overwrite:
548 | return
549 | if not os.path.exists(fld + '/%s.tsv'%fname):
550 | return
551 |
552 | seq = dict()
553 | path = '%s/%i_txt.tsv'%(fld, year)
554 | if not os.path.exists(path):
555 | return
556 | for line in open(path, encoding='utf-8'):
557 | ss = line.strip('\n').split('\t')
558 | if len(ss) != 3:
559 | continue
560 | name, txt, ids = ss
561 | seq[name] = ids
562 | print('loaded %i seq'%len(seq))
563 | open(path_out, 'w', encoding='utf-8')
564 | print('[%s %s %s] adding seq'%(sub, year, feedback))
565 | path = fld + '/%s.tsv'%fname
566 | lines = []
567 | n = 0
568 | m = 0
569 | for line in open(path, encoding='utf-8'):
570 | line = line.strip('\n')
571 | if line.startswith('#'):
572 | continue
573 |
574 | n += 1
575 | ss = line.split('\t')
576 | if len(ss) < 7:
577 | continue
578 | name_cxt, name_pos, name_neg = ss[:3]
579 |
580 | cxt = []
581 | ok = True
582 | for name in name_cxt.split():
583 | if name in seq:
584 | cxt.append(seq[name])
585 | else:
586 | ok = False
587 | break
588 | if not ok:
589 | continue
590 | cxt = turn_sep.join(cxt)
591 |
592 | if name_pos in seq:
593 | reply_pos = seq[name_pos]
594 | else:
595 | continue
596 | if name_neg in seq:
597 | reply_neg = seq[name_neg]
598 | else:
599 | continue
600 |
601 | lines.append('\t'.join([
602 | cxt, reply_pos, reply_neg,
603 | name_cxt, name_pos, name_neg,
604 | ] + ss[3:]))
605 | m += 1
606 | if m % 1e4 == 0:
607 | with open(path_out, 'a', encoding='utf-8') as f:
608 | f.write('\n'.join(lines) + '\n')
609 | lines = []
610 |
611 | with open(path_out, 'a', encoding='utf-8') as f:
612 | f.write('\n'.join(lines))
613 |
614 | s = '[%s %s %s] pair seq %i/%i'%(sub, year, feedback, m, n)
615 | with open(path_done, 'w') as f:
616 | f.write(s)
617 | print(s)
618 |
619 |
620 | def combine_sub(year_from, year_to, feedback, overwrite=False, skip_same_pos=True):
621 | fld = '%s/%s'%(fld_out, feedback)
622 | os.makedirs(fld, exist_ok=True)
623 | path_out = fld + '/raw.tsv'
624 | path_done = path_out + '.done'
625 | if os.path.exists(path_done) and not overwrite:
626 | return path_out
627 |
628 | subs = sorted(os.listdir(fld_subs))
629 | open(path_out, 'w', encoding='utf-8')
630 | lines = []
631 | n = 0
632 | empty = True
633 | non_empty_subreddits = 0
634 | for sub in subs:
635 | empty = True
636 | for year in range(year_from, year_to + 1):
637 | path = '%s/%s/%i_%s_ids.tsv'%(fld_subs, sub, year, feedback)
638 | if not os.path.exists(path):
639 | continue
640 | for line in open(path, encoding='utf-8'):
641 | if line.startswith('#'):
642 | continue
643 | line = line.strip('\n')
644 | if not line:
645 | continue
646 | lines.append(line)
647 | empty = False
648 | n += 1
649 | if n % 1e5 == 0:
650 | with open(path_out, 'a', encoding='utf-8') as f:
651 | f.write('\n'.join(lines) + '\n')
652 | lines = []
653 | s = '[%i %s] saved %.2f M lines from %i subreddits, now is %s'%(year, feedback, n/1e6, non_empty_subreddits + 1, sub)
654 | print(s)
655 | if not empty:
656 | non_empty_subreddits += 1
657 |
658 | with open(path_out, 'a', encoding='utf-8') as f:
659 | f.write('\n'.join(lines))
660 | s = '[%i-%i %s] saved %.2f M lines from %i subreddits'%(year_from, year_to, feedback, n/1e6, non_empty_subreddits )
661 | with open(path_done, 'w') as f:
662 | f.write(s)
663 | print(s)
664 | return path_out
665 |
666 |
667 |
668 | def split_by_root(path, p_test=0.01):
669 |
670 | print('spliting by root '+path)
671 | lines = {
672 | 'train': [],
673 | 'vali': [],
674 | }
675 | prev = None
676 | n = 0
677 |
678 | for k in lines:
679 | if len(lines[k]) == 0:
680 | continue
681 | open(path + '.' + k, 'w', encoding='utf-8')
682 |
683 | for line in open(path, encoding='utf-8'):
684 | line = line.strip('\n')
685 | if not line:
686 | continue
687 | cxt = line.split('\t')[3]
688 | root = cxt.strip().split()[0]
689 | if root != prev:
690 | if np.random.random() < p_test:
691 | k = 'vali'
692 | else:
693 | k = 'train'
694 | #pdb.set_trace()
695 | lines[k].append(line)
696 | prev = root
697 | n += 1
698 | if n % 1e6 == 0:
699 | print('read %i M'%(n/1e6))
700 | for k in lines:
701 | if len(lines[k]) == 0:
702 | continue
703 | with open(path + '.' + k, 'a', encoding='utf-8') as f:
704 | f.write('\n'.join(lines[k]) + '\n')
705 | lines[k] = []
706 |
707 | for k in lines:
708 | if len(lines[k]) == 0:
709 | continue
710 | with open(path + '.' + k, 'a', encoding='utf-8') as f:
711 | f.write('\n'.join(lines[k]))
712 | lines[k] = []
713 |
714 |
715 | def shuffle(feedback, part, n_temp=10):
716 | fld = '%s/%s'%(fld_out, feedback)
717 | path = '%s/raw.tsv.%s'%(fld, part)
718 | path_out = '%s/%s.tsv'%(fld, part)
719 | fld_temp = '%s/temp/%s'%(fld_out, feedback)
720 |
721 | print('slicing '+path)
722 | os.makedirs(fld_temp, exist_ok=True)
723 | lines = [[] for _ in range(n_temp)]
724 |
725 | # split into n_temp files
726 | for i in range(n_temp):
727 | open(fld_temp + '/temp%i'%i, 'w', encoding='utf-8')
728 | n = 0
729 | count = [0] * n_temp
730 | rand = np.random.randint(0, n_temp, 202005)
731 | for line in open(path, encoding='utf-8'):
732 | line = line.strip('\n')
733 | if len(line) == 0:
734 | continue
735 | bucket = rand[n % len(rand)]
736 | lines[bucket].append(line)
737 | count[bucket] += 1
738 | n += 1
739 | if n % 1e6 == 0:
740 | print('read %i M'%(n/1e6))
741 | for i in range(n_temp):
742 | if len(lines[i]) == 0:
743 | continue
744 | with open(fld_temp + '/temp%i'%i, 'a', encoding='utf-8') as f:
745 | f.write('\n'.join(lines[i]) + '\n')
746 | lines[i] = []
747 |
748 | for i in range(n_temp):
749 | with open(fld_temp + '/temp%i'%i, 'a', encoding='utf-8') as f:
750 | f.write('\n'.join(lines[i]))
751 |
752 | # and then merge
753 | open(path_out, 'w', encoding='utf-8')
754 | print(fld_temp)
755 | for i in range(n_temp):
756 | print('reading temp%i'%i)
757 | lines = open(fld_temp + '/temp%i'%i, encoding='utf-8').readlines()
758 | print('shuffling')
759 | jj = list(range(len(lines)))
760 | np.random.shuffle(jj)
761 | print('writing')
762 | with open(path_out, 'a', encoding='utf-8') as f:
763 | f.write('\n'.join([lines[j].strip('\n') for j in jj]) + '\n')
764 |
765 | def get_subs():
766 | # return ['4chan']  # debug shortcut: uncomment to run on a single subreddit only
767 | print('collecting subs...')
768 | subs = sorted(os.listdir(fld_subs))
769 | print('collected %i subs'%len(subs))
770 | return subs
771 |
772 |
773 | def build_json(year):
774 | for date in get_dates(year):
775 | extract_rc(date)
776 | extract_rs(date)
777 |
778 |
779 | def build_basic(year):
780 | from transformers import GPT2Tokenizer
781 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
782 | subs = get_subs()
783 | for sub in subs:
784 | extract_time(sub, year)
785 | extract_txt(sub, year, tokenizer)
786 | extract_trees(sub, year)
787 | calc_feedback(sub, year, overwrite=False)
788 |
789 |
790 | def build_pairs(year_from, year_to, feedback):
791 | subs = get_subs()
792 | for year in range(year_from, year_to + 1):
793 | for sub in subs:
794 | create_pairs(year, sub, feedback, overwrite=False)
795 | add_seq(sub, year, feedback, overwrite=False)
796 | path = combine_sub(year_from, year_to, feedback)
797 | split_by_root(path)
798 | for part in ['train', 'vali']:
799 | shuffle(feedback, part)
800 |
801 |
802 | FLD = 'data'
803 | fld_bz2 = FLD + '/bz2/'
804 | fld_jsonl = FLD + '/jsonl/'
805 | fld_subs = FLD + '/subs/'
806 | fld_out = FLD + '/out/'
807 |
808 | if __name__ == "__main__":
809 | import argparse
810 | parser = argparse.ArgumentParser()
811 | parser.add_argument('task', type=str)
812 | parser.add_argument('year', type=int)
813 | parser.add_argument('--year_to', type=int)
814 | args = parser.parse_args()
815 | if args.task == 'bz2':
816 | build_json(args.year)
817 | elif args.task == 'basic':
818 | build_basic(args.year)
819 | elif args.task in ['updown', 'depth', 'width']:
820 | build_pairs(args.year, args.year_to, args.task)
821 | else:
822 | raise ValueError
--------------------------------------------------------------------------------
/src/data_new.py:
--------------------------------------------------------------------------------
1 | """
2 | data source:
3 | https://academictorrents.com/details/56aa49f9653ba545f48df2e33679f014d2829c10
4 | https://academictorrents.com/details/20520c420c6c846f555523babc8c059e9daa8fc5
5 | """
6 | import json, zstandard as zstd  # zstandard: `pip install zstandard`
7 | def zst2jsonl(path_zst):
8 | path_jsonl = path_zst + '.jsonl'
9 | open(path_jsonl, 'w')
10 | n_line = 0
11 | out = []
12 | with open(path_zst, 'rb') as fh:
13 | dctx = zstd.ZstdDecompressor(max_window_size=2147483648)
14 | with dctx.stream_reader(fh) as reader:
15 | previous_line = ""
16 | while True:
17 | chunk = reader.read(2**24) # 16mb chunks
18 | if not chunk:
19 | break
20 |
21 | string_data = chunk.decode('utf-8')
22 | lines = string_data.split("\n")
23 | for i, line in enumerate(lines[:-1]):
24 | if i == 0:
25 | line = previous_line + line
26 | object = json.loads(line)
27 | out.append(json.dumps(object, ensure_ascii=False))
28 | n_line += 1
29 | if n_line % 1e5 == 0:
30 | print(n_line)
31 | with open(path_zst + '.jsonl', 'a') as f:
32 | f.write('\n'.join(out) + '\n')
33 | out = []
34 | previous_line = lines[-1]
35 |
36 | if out:
37 | with open(path_zst + '.jsonl', 'a') as f:
38 | f.write('\n'.join(out) + '\n')
39 | print(n_line)
40 |
--------------------------------------------------------------------------------
/src/feeder.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import torch, os, pdb
5 | import numpy as np
6 |
7 |
8 | class Feeder:
9 | # load train/vali/test data
10 |
11 | def __init__(self, opt):
12 | self.opt = opt
13 | self.files = dict()
14 | if self.opt.mismatch:
15 | self.files_mismatch = dict()
16 | for sub in ['train', 'vali', 'test']:
17 | self.reset(sub)
18 | self.ix_EOS = 50256
19 | self.ix_OMT = 986
20 |
21 |
22 | def reset(self, sub):
23 | print('resetting '+sub)
24 | path = '%s/%s.tsv'%(self.opt.fld_data, sub)
25 | if os.path.exists(path):
26 | self.files[sub] = open(path)
27 | if self.opt.mismatch:
28 | self.files_mismatch[sub] = open(path)
29 | # assuming f is already shuffled, this step makes f and f_mismatch mismatch
30 | for _ in range(100):
31 | self.files[sub].readline()
32 |
33 |
34 | def get_batch(self, size, sub='train', min_score_gap=1, min_rank_gap=0):
35 | ids_pos = []
36 | len_pos = []
37 | ids_neg = []
38 | len_neg = []
39 | len_cxt = []
40 | score_pos = []
41 | score_neg = []
42 | rank_pos = []
43 | rank_neg = []
44 | hr_gap = []
45 | if sub != 'train':
46 | np.random.seed(2020)
47 |
48 | def ints(s):
49 | return [int(x) for x in s.split()]
50 | def pad(seq):
51 | return seq + [self.ix_EOS] * (self.opt.max_seq_len - len(seq))
52 |
53 | def read():
54 | total = 0
55 | used = 0
56 | for line in self.files[sub]:
57 | if line.startswith('#'):
58 | continue
59 | # old data is title + ' . ' + selftext, ' .' is 764 and often used as ' .jpg' thus misleading
60 | line = line.replace(' 764\t', '\t').replace(' 764 50256 ', ' 50256 ')
61 | total += 1
62 | ss = line.strip('\n').split('\t')
63 | cxt = ints(ss[0])
64 | reply_pos = ints(ss[1])
65 | # _score_pos, _score_neg, _rank_pos, _rank_neg = ss[-4:]
66 | try:
67 | _hr_gap = float(ss[-5])
68 | except ValueError:
69 | _hr_gap = np.nan
70 | _score_pos = int(ss[-4])
71 | _rank_pos = float(ss[-2])
72 |
73 | if self.opt.mismatch:
74 | _score_neg = np.nan
75 | _rank_neg = np.nan
76 | line_mismatch = self.files_mismatch[sub].readline()
77 | ss_mismatch = line_mismatch.strip('\n').split('\t')
78 | reply_neg = ints(ss_mismatch[1])
79 |
80 | else:
81 | reply_neg = ints(ss[2])
82 | _score_neg = int(ss[-3])
83 | _rank_neg = float(ss[-1])
84 | if _score_pos - _score_neg < min_score_gap:
85 | continue
86 | if _rank_pos - _rank_neg < min_rank_gap:
87 | continue
88 | if self.opt.max_hr_gap > 0 and _hr_gap > self.opt.max_hr_gap:
89 | continue
90 |
91 | pos = cxt + [self.ix_EOS] + reply_pos
92 | score_pos.append(_score_pos)
93 | rank_pos.append(_rank_pos)
94 |
95 | neg = cxt + [self.ix_EOS] + reply_neg
96 | score_neg.append(_score_neg)
97 | rank_neg.append(_rank_neg)
98 |
99 | # make sure cxt still same even after cut
100 | n_del = max(len(pos), len(neg)) - self.opt.max_seq_len
101 | if n_del > 0:
102 | pos = pos[n_del:]
103 | neg = neg[n_del:]
104 | cxt = cxt[n_del:]
105 |
106 | len_cxt.append(len(cxt))
107 | len_pos.append(len(pos))
108 | len_neg.append(len(neg))
109 | ids_pos.append(pad(pos))
110 | ids_neg.append(pad(neg))
111 | hr_gap.append(_hr_gap)
112 |
113 | used += 1
114 | if len(ids_pos) == size:
115 | break
116 |
117 | while True:
118 | read()
119 | if len(ids_pos) == size:
120 | break
121 | self.reset(sub)
122 |
123 | ids_pos = torch.LongTensor(ids_pos)
124 | ids_neg = torch.LongTensor(ids_neg)
125 | if self.opt.cuda:
126 | ids_pos = ids_pos.cuda()
127 | ids_neg = ids_neg.cuda()
128 | return {
129 | 'ids_pos':ids_pos,
130 | 'ids_neg':ids_neg,
131 | 'len_pos':len_pos,
132 | 'len_neg':len_neg,
133 | 'len_cxt':len_cxt,
134 | 'score_pos': score_pos,
135 | 'score_neg': score_neg,
136 | 'rank_pos': rank_pos,
137 | 'rank_neg': rank_neg,
138 | 'hr_gap': hr_gap,
139 | }
--------------------------------------------------------------------------------
/src/generation.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 | import torch, pdb
4 | import numpy as np
5 | from shared import download_model, EOS_token
6 |
7 | class GPT2Generator:
8 |
9 | def __init__(self, path, cuda):
10 | from transformers19 import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
11 | self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
12 | model_config = GPT2Config(n_embd=1024, n_layer=24, n_head=16)
13 | self.model = GPT2LMHeadModel(model_config)
14 | download_model(path)
15 | print('loading from '+path)
16 | weights = torch.load(path)
17 | if "lm_head.decoder.weight" in weights:
18 | weights["lm_head.weight"] = weights["lm_head.decoder.weight"]
19 | weights.pop("lm_head.decoder.weight",None)
20 | self.model.load_state_dict(weights)
21 | self.ix_EOS = 50256
22 | self.model.eval()
23 | self.cuda = cuda
24 | if self.cuda:
25 | self.model.cuda()
26 |
27 |
28 | def tokenize(self, cxt):
29 | turns = cxt.split(EOS_token)
30 | ids = []
31 | for turn in turns:
32 | ids += self.tokenizer.encode(turn.strip()) + [self.ix_EOS]
33 | ids = torch.tensor([ids]).view(1, -1)
34 | if self.cuda:
35 | ids = ids.cuda()
36 | return ids
37 |
38 |
39 | def predict_beam(self, cxt, topk=3, topp=0.8, beam=10, max_t=30):
40 |         """ beam-like search: keep the `beam` best hypotheses, expanding each with its top-k / top-p tokens at every step (a usage sketch follows this file's listing) """
41 |
42 | tokens = self.tokenize(cxt)
43 | len_cxt = tokens.shape[1]
44 | sum_logP = [0]
45 | finished = []
46 |
47 | for _ in range(max_t):
48 | outputs = self.model(tokens)
49 | predictions = outputs[0]
50 | logP = torch.log_softmax(predictions[:, -1, :], dim=-1)
51 | next_logP, next_token = torch.topk(logP, topk)
52 | sumlogP_ij = []
53 | sum_prob = 0
54 | for i in range(tokens.shape[0]):
55 | for j in range(topk):
56 |                     sum_prob += np.exp(next_logP[i, j].item())  # cumulative prob of the top-k candidates, used for the top-p cutoff
57 | if sum_prob > topp:
58 | break
59 | if next_token[i, j] == self.ix_EOS:
60 | seq = torch.cat([tokens[i, len_cxt:], next_token[i, j].view(1)], dim=-1)
61 | if self.cuda:
62 | seq = seq.cpu()
63 | seq = seq.detach().numpy().tolist()
64 | prob = np.exp((sum_logP[i] + next_logP[i, j].item()) / len(seq))
65 | hyp = self.tokenizer.decode(seq[:-1]) # don't include EOS
66 | finished.append((prob, hyp))
67 | else:
68 | sumlogP_ij.append((
69 | sum_logP[i] + next_logP[i, j].item(),
70 | i, j))
71 |
72 | if not sumlogP_ij:
73 | break
74 | sumlogP_ij = sorted(sumlogP_ij, reverse=True)[:min(len(sumlogP_ij), beam)]
75 | new_tokens = []
76 | new_sum_logP = []
77 | for _sum_logP, i, j in sumlogP_ij:
78 | new_tokens.append(
79 | torch.cat([tokens[i,:], next_token[i, j].view(1)], dim=-1).view(1, -1)
80 | )
81 | new_sum_logP.append(_sum_logP)
82 | tokens = torch.cat(new_tokens, dim=0)
83 | sum_logP = new_sum_logP
84 |
85 | return finished
86 |
87 |
88 | def predict_sampling(self, cxt, temperature=1, n_hyp=5, max_t=30):
89 |         """ sample tokens from the temperature-scaled predicted distribution """
90 |
91 | tokens = self.tokenize(cxt)
92 | tokens = tokens.repeat(n_hyp, 1)
93 | len_cxt = tokens.shape[1]
94 | sum_logP = [0] * n_hyp
95 | live = [True] * n_hyp
96 | seqs = [[] for _ in range(n_hyp)]
97 | np.random.seed(2020)
98 | for _ in range(max_t):
99 | outputs = self.model(tokens)
100 | predictions = outputs[0]
101 | prob = torch.softmax(predictions[:, -1, :] / temperature, dim=-1)
102 | if self.cuda:
103 | prob = prob.cpu()
104 | prob = prob.detach().numpy()
105 | vocab = prob.shape[-1]
106 | next_tokens = []
107 | for i in range(n_hyp):
108 | next_token = np.random.choice(vocab, p=prob[i,:])
109 | next_tokens.append(next_token)
110 | if not live[i]:
111 | continue
112 | sum_logP[i] += np.log(prob[i, next_token])
113 | seqs[i].append(next_token)
114 | if next_token == self.ix_EOS:
115 | live[i] = False
116 | continue
117 | next_tokens = torch.LongTensor(next_tokens).view(-1, 1)
118 | if self.cuda:
119 | next_tokens = next_tokens.cuda()
120 | tokens = torch.cat([tokens, next_tokens], dim=-1)
121 |
122 | ret = []
123 | for i in range(n_hyp):
124 | if live[i]: # only return hyp that ends with EOS
125 | continue
126 | prob = np.exp(sum_logP[i] / (len(seqs[i]) + 1))
127 | hyp = self.tokenizer.decode(seqs[i][:-1]) # strip EOS
128 | ret.append((prob, hyp))
129 | return ret
130 |
131 |
132 | def play(self, params):
133 | while True:
134 | cxt = input('\nContext:\t')
135 | if not cxt:
136 | break
137 | ret = self.predict(cxt, **params)
138 | for prob, hyp in sorted(ret, reverse=True):
139 | print('%.3f\t%s'%(prob, hyp))
140 |
141 |
142 | class Integrated:
143 | def __init__(self, generator, ranker):
144 | self.generator = generator
145 | self.ranker = ranker
146 |
147 | def predict(self, cxt, wt_ranker, params):
148 | with torch.no_grad():
149 | prob_hyp = self.generator.predict(cxt, **params)
150 | probs = np.array([prob for prob, _ in prob_hyp])
151 | hyps = [hyp for _, hyp in prob_hyp]
152 | if wt_ranker > 0:
153 | scores_ranker = self.ranker.predict(cxt, hyps)
154 | if isinstance(scores_ranker, dict):
155 | scores_ranker = scores_ranker['final']
156 | scores = wt_ranker * scores_ranker + (1 - wt_ranker) * probs
157 | else:
158 |             scores, scores_ranker = probs, np.zeros(len(hyps))  # ranker unused; keep the tuple below well-formed
159 | ret = []
160 | for i in range(len(hyps)):
161 | ret.append((scores[i], probs[i], scores_ranker[i], hyps[i]))
162 | ret = sorted(ret, reverse=True)
163 | return ret
164 |
165 |
166 | def play(self, wt_ranker, params):
167 | while True:
168 | cxt = input('\nContext:\t')
169 | if not cxt:
170 | break
171 | ret = self.predict(cxt, wt_ranker, params)
172 | for final, prob_gen, score_ranker, hyp in ret:
173 | print('%.3f gen %.3f ranker %.3f\t%s'%(final, prob_gen, score_ranker, hyp))
174 |
175 |
176 | def test(model, path_in, wt_ranker, params, max_n):
177 | lines = []
178 | for i, line in enumerate(open(path_in, encoding='utf-8')):
179 | print('processing %i-th context'%i)
180 | cxt = line.strip('\n').split('\t')[0]
181 |         ret = model.predict(cxt, wt_ranker, params)  # Integrated.predict takes the generation params as a dict
182 | cc = [cxt] + [tup[-1] for tup in ret]
183 | lines.append('\t'.join(cc))
184 | if i == max_n:
185 | break
186 | path_out = path_in + '.hyps'
187 | with open(path_out, 'w', encoding='utf-8') as f:
188 | f.write('\n'.join(lines))
189 | print('saved to '+path_out)
190 |
191 |
192 |
193 | if __name__ == "__main__":
194 | import argparse
195 | parser = argparse.ArgumentParser()
196 | parser.add_argument('task', type=str)
197 | parser.add_argument('--path_generator', '-pg', type=str)
198 | parser.add_argument('--path_ranker', '-pr', type=str)
199 | parser.add_argument('--path_test', type=str)
200 | parser.add_argument('--cpu', action='store_true')
201 | parser.add_argument('--sampling', action='store_true')
202 | parser.add_argument('--topk', type=int, default=3)
203 | parser.add_argument('--beam', type=int, default=3)
204 | parser.add_argument('--wt_ranker', type=float, default=1.)
205 | parser.add_argument('--topp', type=float, default=0.8)
206 | parser.add_argument('--max_n', type=int, default=-1)
207 | parser.add_argument('--temperature', type=float, default=0.5)
208 | parser.add_argument('--n_hyp', type=int, default=10)
209 | args = parser.parse_args()
210 |
211 | cuda = False if args.cpu else torch.cuda.is_available()
212 | generator = GPT2Generator(args.path_generator, cuda)
213 | if args.sampling:
214 | params = {'temperature': args.temperature, 'n_hyp': args.n_hyp}
215 | generator.predict = generator.predict_sampling
216 | else:
217 | params = {'topk': args.topk, 'beam': args.beam, 'topp': args.topp}
218 | generator.predict = generator.predict_beam
219 |
220 | if args.path_ranker is None:
221 | model = generator
222 | else:
223 | from score import get_model
224 | ranker = get_model(args.path_ranker, cuda)
225 | model = Integrated(generator, ranker)
226 |
227 | if args.task == 'play':
228 | model.play(args.wt_ranker, params)
229 | elif args.task == 'test':
230 |         test(model, args.path_test, args.wt_ranker, params, args.max_n)
--------------------------------------------------------------------------------
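
To make the pieces above concrete, here is a minimal usage sketch (run from `src/`; the checkpoint paths follow the download links in `shared.py`, while the context string, `wt_ranker`, and sampling settings are illustrative): the generator proposes hypotheses with `predict_sampling`, and `Integrated.predict` interpolates the ranker score with the generator probability via `wt_ranker`.

    # minimal sketch; assumes the checkpoints can be fetched by shared.download_model (wget)
    import torch
    from generation import GPT2Generator, Integrated
    from score import get_model

    cuda = torch.cuda.is_available()
    generator = GPT2Generator('restore/medium_ft.pkl', cuda)   # DialoGPT generator checkpoint
    generator.predict = generator.predict_sampling             # same binding as in __main__ above
    ranker = get_model('restore/updown.pth', cuda)             # DialogRPT 'updown' ranker

    model = Integrated(generator, ranker)
    params = {'temperature': 0.7, 'n_hyp': 5}
    for final, prob_gen, score_ranker, hyp in model.predict('Can we restart 2020?', wt_ranker=0.5, params=params):
        print('%.3f (gen %.3f, ranker %.3f)\t%s' % (final, prob_gen, score_ranker, hyp))
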
/src/main.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import argparse, torch, time, pdb
5 | import os, socket
6 | from master import Master
7 |
8 |
9 | class Option:
10 |
11 | def __init__(self, args):
12 | if args.cpu or not torch.cuda.is_available():
13 | self.cuda = False
14 | else:
15 | self.cuda = True
16 | self.task = args.task
17 | self.path_load = args.path_load
18 | self.batch = args.batch
19 | self.vali_size = max(self.batch, args.vali_size)
20 | self.vali_print = args.vali_print
21 | self.lr = args.lr
22 | self.max_seq_len = args.max_seq_len
23 | self.min_score_gap = args.min_score_gap
24 | self.min_rank_gap = args.min_rank_gap
25 | self.max_hr_gap = args.max_hr_gap
26 | self.mismatch = args.mismatch
27 | self.fld_data = args.data
28 | if args.task == 'train' or self.path_load is None:
29 | self.fld_out = 'out/%i'%time.time()
30 | else:
31 | self.fld_out = 'out/temp'
32 | os.makedirs(self.fld_out, exist_ok=True)
33 |
34 | self.clip = 1
35 | self.step_max = 1e6
36 | self.step_print = 10
37 | self.step_vali = 100
38 | self.step_save = 500
39 | self.len_acc = self.step_vali
40 |
41 |
42 | def save(self):
43 | d = self.__dict__
44 | lines = []
45 | for k in d:
46 | lines.append('%s\t%s'%(k, d[k]))
47 | with open(self.fld_out + '/opt.tsv', 'w') as f:
48 | f.write('\n'.join(lines))
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument('task', type=str)
54 | parser.add_argument('--data', type=str)
55 | parser.add_argument('--batch', type=int, default=256)
56 | parser.add_argument('--vali_size', type=int, default=1024)
57 | parser.add_argument('--vali_print', type=int, default=10)
58 | parser.add_argument('--lr', type=float, default=3e-5)
59 | parser.add_argument('--path_load','-p', type=str)
60 | parser.add_argument('--cpu', action='store_true')
61 | parser.add_argument('--max_seq_len', type=int, default=50)
62 | parser.add_argument('--mismatch', action='store_true')
63 | parser.add_argument('--min_score_gap', type=int)
64 | parser.add_argument('--min_rank_gap', type=float)
65 | parser.add_argument('--max_hr_gap', type=float, default=1)
66 | args = parser.parse_args()
67 |
68 | opt = Option(args)
69 | master = Master(opt)
70 | if args.task == 'train':
71 | master.train()
72 | elif args.task == 'vali':
73 | master.vali()
--------------------------------------------------------------------------------
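
For reference, a small sketch of how the CLI flags above map onto `Option`; the data path, checkpoint, and gap thresholds below are illustrative, not values prescribed by this file:

    # roughly equivalent to:
    #   python src/main.py train --data data/out/updown -p restore/medium_ft.pkl \
    #       --min_score_gap 20 --min_rank_gap 0.5
    from types import SimpleNamespace
    from main import Option

    args = SimpleNamespace(
        task='train', data='data/out/updown', batch=256, vali_size=1024,
        vali_print=10, lr=3e-5, path_load='restore/medium_ft.pkl', cpu=True,
        max_seq_len=50, mismatch=False, min_score_gap=20, min_rank_gap=0.5,
        max_hr_gap=1.0)
    opt = Option(args)      # creates out/<timestamp>/ and fills in the fixed schedule fields
    print(opt.fld_out, opt.vali_size, opt.step_vali)
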
/src/master.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import torch, os, pdb, time, sys, warnings
5 | import numpy as np
6 | from feeder import Feeder
7 | from model import Scorer, JointScorer
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | class Master:
12 |
13 | def __init__(self, opt):
14 | self.opt = opt
15 | if opt.path_load is not None and (opt.path_load.endswith('.yaml') or opt.path_load.endswith('.yml')):
16 | self._model = JointScorer(opt)
17 | else:
18 | self._model = Scorer(opt)
19 | if opt.path_load is not None:
20 | self._model.load(opt.path_load)
21 | self.parallel()
22 |
23 | if opt.task != 'play':
24 | if opt.fld_data is not None:
25 | self.feeder = Feeder(opt)
26 |
27 | if opt.task == 'train':
28 | opt.save()
29 | os.makedirs(opt.fld_out + '/ckpt', exist_ok=True)
30 | self.path_log = self.opt.fld_out + '/log.txt'
31 | else:
32 | self.path_log = self.opt.fld_out + '/log_infer.txt'
33 |
34 |
35 | def print(self, s=''):
36 | try:
37 | print(s)
38 | except UnicodeEncodeError:
39 | print('[UnicodeEncodeError]')
40 | pass
41 | with open(self.path_log, 'a', encoding='utf-8') as f:
42 | f.write(s+'\n')
43 |
44 |
45 | def parallel(self):
46 | if self.opt.cuda:
47 | self._model = self._model.cuda()
48 | n_gpu = torch.cuda.device_count()
49 | if self.opt.cuda and n_gpu > 1:
50 | print('paralleling on %i GPU'%n_gpu)
51 | self.model = torch.nn.DataParallel(self._model)
52 | # after DataParallel, a warning about RNN weights shows up every batch
53 | warnings.filterwarnings("ignore")
54 |             # after DataParallel, the original model's attributes live on self.model.module
55 | self._model = self.model.module
56 | self.model.core = self.model.module.core
57 | self.model.tokenizer = self._model.tokenizer
58 | else:
59 | self.model = self._model
60 | if self.opt.task == 'train':
61 | self.optimizer = torch.optim.Adam(self._model.parameters(), lr=self.opt.lr)
62 |
63 |
64 | def train(self):
65 | vali_loss, best_acc = self.vali()
66 | best_trained = 0
67 | step = 0
68 | n_trained = 0
69 | t0 = time.time()
70 |
71 | list_trained = [0]
72 | list_train_loss = [np.nan]
73 | list_train_acc = [np.nan]
74 | list_vali_loss = [vali_loss]
75 | list_vali_acc = [best_acc]
76 | acc_history = []
77 |
78 | while step < self.opt.step_max:
79 | self.model.train()
80 | self.optimizer.zero_grad()
81 | batch = self.feeder.get_batch(self.opt.batch)
82 | pred = self.model.forward(batch)
83 | loss = self.loss(pred)
84 | loss = loss.mean() # in case of parallel-training
85 |
86 | loss.backward()
87 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.opt.clip)
88 | self.optimizer.step()
89 |
90 | acc = (pred > 0.5).float().mean().item()
91 | acc_history.append(acc)
92 | if len(acc_history) > self.opt.len_acc:
93 | acc_history.pop(0)
94 | avg_train_acc = np.mean(acc_history)
95 | step += 1
96 | n_trained += self.opt.batch
97 | info = 'step %i trained %.3f best %.2f'%(step, n_trained/1e6, best_acc)
98 |
99 | if step % self.opt.step_print == 0:
100 | speed = (n_trained / 1e6) / ((time.time() - t0) / 3600)
101 |
102 | self.print('%s speed %.2f hr_gap %.2f score_gap %.2f rank_gap %.2f loss %.4f acc %.3f'%(
103 | info,
104 | speed,
105 | np.median(batch['hr_gap']),
106 | (np.array(batch['score_pos']) - np.array(batch['score_neg'])).mean(),
107 | (np.array(batch['rank_pos']) - np.array(batch['rank_neg'])).mean(),
108 | loss,
109 | avg_train_acc,
110 | ))
111 |
112 | if step % self.opt.step_vali == 0:
113 | vali_loss, vali_acc = self.vali(info)
114 | if vali_acc > best_acc:
115 | self.save(self.opt.fld_out + '/ckpt/best.pth')
116 | best_acc = vali_acc
117 | best_trained = n_trained
118 | sys.stdout.flush()
119 |
120 | list_trained.append(n_trained/1e6)
121 | list_train_loss.append(loss.item())
122 | list_train_acc.append(avg_train_acc)
123 | list_vali_loss.append(vali_loss)
124 | list_vali_acc.append(vali_acc)
125 | _, axs = plt.subplots(3, 1, sharex=True)
126 |
127 | axs[0].plot(list_trained, list_train_loss, 'b', label='train')
128 | axs[0].plot(list_trained, list_vali_loss, 'r', label='vali')
129 | axs[0].legend(loc='best')
130 | axs[0].set_ylabel('loss')
131 |
132 | axs[1].plot(list_trained, list_train_acc, 'b', label='train')
133 | axs[1].plot(list_trained, list_vali_acc, 'r', label='vali')
134 | axs[1].plot([best_trained/1e6, n_trained/1e6], [best_acc, best_acc], 'k:')
135 | axs[1].set_ylabel('acc')
136 |
137 | axs[-1].set_xlabel('trained (M)')
138 | axs[0].set_title(self.opt.fld_out + '\n' + self.opt.fld_data + '\nbest_acc = %.4f'%best_acc)
139 | plt.tight_layout()
140 | plt.savefig(self.opt.fld_out + '/log.png')
141 | plt.close()
142 |
143 | if step % self.opt.step_save == 0:
144 | self.save(self.opt.fld_out + '/ckpt/last.pth')
145 |
146 |
147 | def loss(self, pred):
148 |         # pred is the predicted probability that the positive response is ranked above the negative one for the same context (a toy check follows this file's listing)
149 | return - torch.log(pred).mean()
150 |
151 |
152 | def vali(self, info=''):
153 | n_print = min(self.opt.batch, self.opt.vali_print)
154 | self.model.eval()
155 | loss = 0
156 | acc = 0
157 | hr_gap = 0
158 | score_gap = 0
159 | rank_gap = 0
160 | n_batch = int(self.opt.vali_size/self.opt.batch)
161 | self.feeder.reset('vali')
162 |
163 | for _ in range(n_batch):
164 | batch = self.feeder.get_batch(self.opt.batch, sub='vali',
165 | min_score_gap=self.opt.min_score_gap, min_rank_gap=self.opt.min_rank_gap)
166 | with torch.no_grad():
167 | pred = self.model.forward(batch)
168 | loss += self.loss(pred)
169 | acc += (pred > 0.5).float().mean()
170 | score_gap += (np.array(batch['score_pos']) - np.array(batch['score_neg'])).mean()
171 | rank_gap += (np.array(batch['rank_pos']) - np.array(batch['rank_neg'])).mean()
172 | hr_gap += np.median(batch['hr_gap'])
173 |
174 | loss /= n_batch
175 | acc /= n_batch
176 | score_gap /= n_batch
177 | rank_gap /= n_batch
178 | hr_gap /= n_batch
179 | s = '%s hr_gap %.2f score_gap %.2f rank_gap %.2f loss %.4f acc %.3f'%(
180 | info,
181 | hr_gap,
182 | score_gap,
183 | rank_gap,
184 | loss,
185 | acc,
186 | )
187 | s = '[vali] ' + s.strip()
188 | if not n_print:
189 | self.print(s)
190 | return loss.mean().item(), acc
191 |
192 | with torch.no_grad():
193 | pred_pos = self.model.core(batch['ids_pos'], batch['len_pos'])
194 | pred_neg = self.model.core(batch['ids_neg'], batch['len_neg'])
195 |
196 | def to_np(ids):
197 | if self.opt.cuda:
198 | ids = ids.cpu()
199 | return ids.detach().numpy()
200 |
201 | ids_pos = to_np(batch['ids_pos'])
202 | ids_neg = to_np(batch['ids_neg'])
203 |
204 | for j in range(n_print):
205 | l_cxt = batch['len_cxt'][j]
206 | cxt = self.model.tokenizer.decode(ids_pos[j, :l_cxt])
207 |                 pos = self.model.tokenizer.decode(ids_pos[j, l_cxt:]).strip('<|endoftext|>')
208 |                 neg = self.model.tokenizer.decode(ids_neg[j, l_cxt:]).strip('<|endoftext|>')
209 | self.print(cxt)
210 | self.print('hr_gap %s'%batch['hr_gap'][j])
211 | self.print('%s\t%.2f\t%.3f\t%s'%(batch['score_pos'][j], batch['rank_pos'][j], pred_pos[j], pos))
212 | self.print('%s\t%.2f\t%.3f\t%s'%(batch['score_neg'][j], batch['rank_neg'][j], pred_neg[j], neg))
213 | self.print()
214 |
215 | self.print(s)
216 | return loss.mean().item(), acc
217 |
218 |
219 | def save(self, path):
220 | torch.save(self._model.state_dict(), path)
221 | self.print('saved to '+path)
222 |
--------------------------------------------------------------------------------
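
The objective trained above is a pairwise ranking loss: `Master.loss` takes the pairwise probability produced by `ScorerBase.forward` (next file) and minimizes its negative log, which equals softplus(logit_neg - logit_pos). A tiny self-contained check with toy logits (not real model outputs):

    import torch

    logit_pos = torch.tensor([2.0, 0.3])
    logit_neg = torch.tensor([1.0, 0.7])
    pred = torch.exp(logit_pos) / (torch.exp(logit_pos) + torch.exp(logit_neg))  # as in ScorerBase.forward
    loss = -torch.log(pred).mean()                                               # as in Master.loss
    assert torch.allclose(loss, torch.nn.functional.softplus(logit_neg - logit_pos).mean())
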
/src/model.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import torch, os, pdb
5 | import numpy as np
6 | from transformers19 import GPT2Tokenizer, GPT2Model, GPT2Config
7 | from shared import EOS_token
8 |
9 |
10 | class OptionInfer:
11 | def __init__(self, cuda=True):
12 | self.cuda = cuda
13 |
14 |
15 | class ScorerBase(torch.nn.Module):
16 | def __init__(self, opt):
17 | super().__init__()
18 | self.ix_EOS = 50256
19 | self.ix_OMT = 986
20 | self.opt = opt
21 | self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
22 |
23 |
24 | def core(self, ids, l_ids, return_logits=False):
25 | # to be implemented in child class
26 | return 0
27 |
28 |
29 | def predict(self, cxt, hyps, max_cxt_turn=None):
30 | # cxt = str
31 | # hyps = list of str
32 |
33 | self.eval()
34 | cxt_turns = cxt.split(EOS_token)
35 | if max_cxt_turn is not None:
36 | cxt_turns = cxt_turns[-min(max_cxt_turn, len(cxt_turns)):]
37 | ids_cxt = []
38 | for turn in cxt_turns:
39 | ids_cxt += self.tokenizer.encode(turn.strip()) + [self.ix_EOS]
40 | seqs = []
41 | lens = []
42 | for hyp in hyps:
43 | seq = ids_cxt + self.tokenizer.encode(hyp.strip())
44 | lens.append(len(seq))
45 | seqs.append(seq)
46 | max_len = max(lens)
47 | ids = []
48 | for seq in seqs:
49 | ids.append(seq + [self.ix_EOS] * (max_len - len(seq)))
50 | with torch.no_grad():
51 | ids = torch.LongTensor(ids)
52 | if self.opt.cuda:
53 | ids = ids.cuda()
54 | scores = self.core(ids, lens)
55 | if not isinstance(scores, dict):
56 | if self.opt.cuda:
57 | scores = scores.cpu()
58 | return scores.detach().numpy()
59 |
60 | for k in scores:
61 | if self.opt.cuda:
62 | scores[k] = scores[k].cpu()
63 | scores[k] = scores[k].detach().numpy()
64 | return scores
65 |
66 |
67 | def forward(self, batch):
68 | logits_pos = self.core(batch['ids_pos'], batch['len_pos'], return_logits=True)
69 | logits_neg = self.core(batch['ids_neg'], batch['len_neg'], return_logits=True)
70 |         # pairwise softmax over the two logits: the probability of ranking pos above neg, i.e. sigmoid(logits_pos - logits_neg)
71 | return torch.exp(logits_pos) / (torch.exp(logits_pos) + torch.exp(logits_neg))
72 |
73 |
74 |
75 | class Scorer(ScorerBase):
76 | def __init__(self, opt):
77 | super().__init__(opt)
78 | n_embd = 1024
79 | config = GPT2Config(n_embd=n_embd, n_layer=24, n_head=16)
80 | self.transformer = GPT2Model(config)
81 | self.score = torch.nn.Linear(n_embd, 1, bias=False)
82 |
83 |
84 | def core(self, ids, l_ids, return_logits=False):
85 | n = ids.shape[0]
86 | attention_mask = torch.ones_like(ids)
87 | for i in range(n):
88 | attention_mask[i, l_ids[i]:] *= 0
89 | hidden_states, _ = self.transformer(ids, attention_mask=attention_mask)
90 | logits = self.score(hidden_states).squeeze(-1)
91 | logits = torch.stack([logits[i, l_ids[i] - 1] for i in range(n)])
92 | if return_logits:
93 | return logits
94 | else:
95 | return torch.sigmoid(logits)
96 |
97 |
98 | def load(self, path):
99 | from shared import download_model
100 | download_model(path)
101 | print('loading from '+path)
102 | weights = torch.load(path, map_location=torch.device('cpu'))
103 | if path.endswith('.pkl'):
104 | # DialoGPT checkpoint
105 | weights['score.weight'] = weights['lm_head.decoder.weight'][self.ix_EOS: self.ix_EOS+1, :]
106 | del weights['lm_head.decoder.weight']
107 | self.load_state_dict(weights)
108 | if self.opt.cuda:
109 | self.cuda()
110 |
111 |
112 | class JointScorer(ScorerBase):
113 |
114 | def core(self, ids, l_ids, return_logits=False):
115 | assert(not return_logits)
116 | scores = dict()
117 | for k in self.kk['prior'] + self.kk['cond']:
118 | scorer = getattr(self, 'scorer_%s'%k)
119 | scores[k] = scorer.core(ids, l_ids)
120 |
121 | def avg_score(kk):
122 | if not kk:
123 | return 1
124 | sum_score_wt = 0
125 | sum_wt = 0
126 | for k in kk:
127 | sum_score_wt = sum_score_wt + scores[k] * self.wt[k]
128 | sum_wt += self.wt[k]
129 | return sum_score_wt / sum_wt
130 |
131 | prior = avg_score(self.kk['prior'])
132 | cond = avg_score(self.kk['cond'])
133 | scores['final'] = prior * cond
134 | return scores
135 |
136 |
137 | def load(self, path_config):
138 | import yaml
139 | with open(path_config, 'r') as stream:
140 | config = yaml.safe_load(stream)
141 | print(config)
142 |
143 | paths = dict()
144 | self.wt = dict()
145 | self.kk = dict()
146 | for prefix in ['prior', 'cond']:
147 | self.kk[prefix] = []
148 | for d in config[prefix]:
149 | k = d['name']
150 | self.kk[prefix].append(k)
151 | self.wt[k] = d['wt']
152 | paths[k] = d['path']
153 |
154 | for k in paths:
155 | path = paths[k]
156 | print('setting up model `%s`'%k)
157 | scorer = Scorer(OptionInfer(cuda=self.opt.cuda))
158 | scorer.load(path)
159 | if self.opt.cuda:
160 | scorer.cuda()
161 | setattr(self, 'scorer_%s'%k, scorer)
162 |
163 |
164 |
165 |
--------------------------------------------------------------------------------
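
`JointScorer.load` above expects a YAML config with two groups, `prior` and `cond`, each a list of `{name, wt, path}` entries; the final score is the weighted average of the `prior` scorers multiplied by the weighted average of the `cond` scorers. Below is the structure such a config parses into, shown as the Python dict `yaml.safe_load` would return; the names and paths follow the checkpoints listed in `shared.py`, but the weights are illustrative and need not match `restore/ensemble.yml`:

    config = {
        'prior': [
            {'name': 'human_vs_rand',    'wt': 0.5, 'path': 'restore/human_vs_rand.pth'},
            {'name': 'human_vs_machine', 'wt': 0.5, 'path': 'restore/human_vs_machine.pth'},
        ],
        'cond': [
            {'name': 'updown', 'wt': 1.0, 'path': 'restore/updown.pth'},
            {'name': 'depth',  'wt': 0.5, 'path': 'restore/depth.pth'},
            {'name': 'width',  'wt': 0.5, 'path': 'restore/width.pth'},
        ],
    }
    # JointScorer.core then returns a dict with one entry per name above,
    # plus 'final' = weighted_avg(prior scores) * weighted_avg(cond scores)
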
/src/score.py:
--------------------------------------------------------------------------------
1 | import torch, pdb, os, json
2 | from shared import _cat_
3 | import numpy as np
4 | from model import OptionInfer, Scorer
5 | from collections import defaultdict
6 |
7 |
8 | def get_model(path, cuda=True):
9 | opt = OptionInfer(cuda)
10 | if path.endswith('.yaml') or path.endswith('.yml'):
11 | from model import JointScorer
12 | model = JointScorer(opt)
13 | model.load(path)
14 |     elif path.endswith('.pth'):
15 | from model import Scorer
16 | model = Scorer(opt)
17 | model.load(path)
18 | if cuda:
19 | model.cuda()
20 | return model
21 |
22 |
23 | def predict(model, cxt, hyps, max_cxt_turn=None):
24 | # split into smaller batch to avoid OOM
25 | n = len(hyps)
26 | i0 = 0
27 | scores = []
28 | while i0 < n:
29 | i1 = min(i0 + 32, n)
30 | _scores = model.predict(cxt, hyps[i0: i1], max_cxt_turn=max_cxt_turn)
31 | scores.append(_scores)
32 | i0 = i1
33 | if isinstance(_scores, dict):
34 | d_scores = dict()
35 | for k in _scores:
36 |             d_scores[k] = np.concatenate([s[k] for s in scores])
37 | return d_scores
38 | else:
39 | return np.concatenate(scores)
40 |
41 |
42 |
43 | def eval_fake(fld, model, fake, max_n=-1, max_cxt_turn=None):
44 | """
45 |     for a given context, we rank k real and m fake responses together.
46 |     if x real responses appear in the top-k ranked responses, define acc = x/k, where k = # of real responses.
47 |     this can be seen as a generalized version of hits@k:
48 |     for a perfect ranking, x == k and thus acc == 1 (a worked toy example follows this file's listing).
49 | """
50 |
51 | assert(os.path.isdir(fld))
52 | def read_data(path, max_n=-1):
53 | cxts = dict()
54 | rsps = dict()
55 | for i, line in enumerate(open(path, encoding='utf-8')):
56 | ss = line.strip('\n').split('\t')
57 | ss0 = ss[0].split(_cat_)
58 | if len(ss0) == 2:
59 | cxt, cxt_id = ss0
60 | cxt_id = cxt_id.strip()
61 | else:
62 | cxt = ss0[0]
63 | cxt_id = cxt.strip().replace(' ','')
64 | cxts[cxt_id] = cxt.strip()
65 | rsps[cxt_id] = [s.split(_cat_)[0] for s in ss[1:]]
66 | if i == max_n:
67 | break
68 | return cxts, rsps
69 |
70 | print('evaluating %s'%fld)
71 | acc = []
72 | cxts, reals = read_data(fld + '/ref.tsv', max_n=max_n)
73 | _, fakes = read_data(fld + '/%s.tsv'%fake)
74 |
75 | n = 0
76 | for cxt_id in reals:
77 | if cxt_id not in fakes:
78 | print('[WARNING] could not find fake examples for [%s]'%cxt_id)
79 | #pdb.set_trace()
80 | continue
81 | scores = predict(model, cxts[cxt_id], reals[cxt_id] + fakes[cxt_id], max_cxt_turn=max_cxt_turn)
82 | ix_score = sorted([(scores[i], i) for i in range(len(scores))], reverse=True)
83 | k = len(reals[cxt_id])
84 | _acc = np.mean([i < k for _, i in ix_score[:k]])
85 | acc.append(_acc)
86 | n += 1
87 | if n % 10 == 0:
88 | print('evaluated %i, avg acc %.3f'%(n, np.mean(acc)))
89 | if n == max_n:
90 | break
91 |
92 | print('final acc is %.3f based on %i samples'%(np.mean(acc), n))
93 |
94 |
95 |
96 | def eval_feedback(path, model, max_n=-1, max_cxt_turn=None, min_rank_gap=0., min_score_gap=0, max_hr_gap=1):
97 | """
98 | for a given context, we compare two responses,
99 |     predict which one got better feedback (greater updown, depth, or width),
100 |     and return this pairwise accuracy
101 | """
102 | assert(path.endswith('.tsv'))
103 | assert(min_rank_gap is not None)
104 | assert(min_score_gap is not None)
105 |
106 | print('evaluating %s'%path)
107 | acc = []
108 | n = 0
109 | for line in open(path, encoding='utf-8'):
110 | cc = line.strip('\n').split('\t')
111 | if len(cc) != 11:
112 | continue
113 | cxt, pos, neg, _, _, _, hr_gap, pos_score, neg_score, pos_rank, neg_rank = cc
114 | if float(hr_gap) > max_hr_gap:
115 | continue
116 | if float(pos_rank) - float(neg_rank) < min_rank_gap:
117 | continue
118 | if int(pos_score) - int(neg_score) < min_score_gap:
119 | continue
120 |
121 | scores = predict(model, cxt, [pos, neg], max_cxt_turn=max_cxt_turn)
122 | score_pos = scores[0]
123 | score_neg = scores[1]
124 | acc.append(float(score_pos > score_neg))
125 | n += 1
126 | if n % 10 == 0:
127 | print('evaluated %i, avg acc %.3f'%(n, np.mean(acc)))
128 | if n == max_n:
129 | break
130 |
131 | print('final acc is %.3f based on %i samples'%(np.mean(acc), n))
132 |
133 |
134 |
135 | def rank_hyps(path, model, max_n=-1, max_cxt_turn=None):
136 | """
137 |     rank the responses for each given cxt with the model.
138 |     path is the input file: in each line, the 0-th column is the context and the rest are responses.
139 |     outputs a jsonl file, which can be read with the function `read_ranked_jsonl`
140 | """
141 |
142 | print('ranking %s'%path)
143 | lines = []
144 | n = 0
145 | sum_avg_score = 0
146 | sum_top_score = 0
147 | for i, line in enumerate(open(path, encoding='utf-8')):
148 | cc = line.strip('\n').split('\t')
149 | if len(cc) < 2:
150 | print('[WARNING] line %i only has %i columns, ignored'%(i, len(cc)))
151 | continue
152 | cxt = cc[0]
153 | hyps = cc[1:]
154 | scores = predict(model, cxt, hyps, max_cxt_turn=max_cxt_turn)
155 | d = {'line_id':i, 'cxt': cxt}
156 | scored = []
157 | if isinstance(scores, dict):
158 | sum_avg_score += np.mean(scores['final'])
159 | sum_top_score += np.max(scores['final'])
160 | for j, hyp in enumerate(hyps):
161 | tup = (
162 | float(scores['final'][j]),
163 | dict([(k, float(scores[k][j])) for k in scores]),
164 | hyp,
165 | )
166 | scored.append(tup)
167 | else:
168 | sum_avg_score += np.mean(scores)
169 | sum_top_score += np.max(scores)
170 | for j, hyp in enumerate(hyps):
171 | scored.append((float(scores[j]), hyp))
172 | d['hyps'] = list(sorted(scored, reverse=True))
173 | lines.append(json.dumps(d))
174 | n += 1
175 |
176 | if n % 10 == 0:
177 |             print('processed %i lines, avg_hyp_score %.3f, top_hyp_score %.3f'%(
178 | n,
179 | sum_avg_score/n,
180 | sum_top_score/n,
181 | ))
182 | if n == max_n:
183 | break
184 |     print('processed %i lines in total, avg_hyp_score %.3f, top_hyp_score %.3f'%(
185 | n,
186 | sum_avg_score/n,
187 | sum_top_score/n,
188 | ))
189 | path_out = path+'.ranked.jsonl'
190 | with open(path_out, 'w') as f:
191 | f.write('\n'.join(lines))
192 | print('results saved to '+path_out)
193 |
194 |
195 | def read_ranked_jsonl(path):
196 |     """ read the jsonl file output by the function rank_hyps """
197 | data = [json.loads(line) for line in open(path, encoding="utf-8")]
198 | n_hyp = [len(d['hyps']) for d in data]
199 | best = defaultdict(list)
200 | avg = defaultdict(list)
201 | for d in data:
202 | scores = defaultdict(list)
203 | for tup in d['hyps']:
204 | scores['_score'].append(tup[0])
205 | if isinstance(tup[1], dict):
206 | for k in tup[1]:
207 | scores[k].append(tup[1][k])
208 | for k in scores:
209 | best[k].append(max(scores[k]))
210 | avg[k].append(np.mean(scores[k]))
211 |
212 | print()
213 | width = 20
214 | print('\t|'.join([' '*width, 'best', 'avg']))
215 | print('-'*40)
216 | for k in best:
217 | print('%s\t|%.3f\t|%.3f'%(
218 | ' '*(width - len(k)) + k,
219 | np.mean(best[k]),
220 | np.mean(avg[k]),
221 | ))
222 | print('-'*40)
223 | print('n_cxt: %i'%len(data))
224 | print('avg n_hyp per cxt: %.2f'%np.mean(n_hyp))
225 | return data
226 |
227 |
228 |
229 | def play(model, max_cxt_turn=None):
230 | from shared import EOS_token
231 | model.eval()
232 | print('enter empty to stop')
233 |     print('use `%s` to delimit turns in a multi-turn context'%EOS_token)
234 | while True:
235 | print()
236 | cxt = input('Context: ')
237 | if not cxt:
238 | break
239 | hyp = input('Response: ')
240 | if not hyp:
241 | break
242 | score = model.predict(cxt, [hyp], max_cxt_turn=max_cxt_turn)
243 | if isinstance(score, dict):
244 | ss = ['%s = %.3f'%(k, score[k][0]) for k in score]
245 | print(', '.join(ss))
246 | else:
247 | print('score = %.3f'%score[0])
248 |
249 |
250 | if __name__ == "__main__":
251 | import argparse
252 | parser = argparse.ArgumentParser()
253 | parser.add_argument('task', type=str)
254 | parser.add_argument('--data', type=str)
255 | parser.add_argument('--max_cxt_turn', type=int, default=2)
256 | parser.add_argument('--path_pth', '-p', type=str)
257 | parser.add_argument('--cpu', action='store_true')
258 | parser.add_argument('--max_n', type=int, default=5000)
259 | parser.add_argument('--min_score_gap', type=int)
260 | parser.add_argument('--min_rank_gap', type=float)
261 | args = parser.parse_args()
262 |
263 | cuda = False if args.cpu else torch.cuda.is_available()
264 | if args.task != 'stats':
265 | model = get_model(args.path_pth, cuda)
266 |
267 | if args.task in ['eval_human_vs_rand', 'eval_human_vs_machine']:
268 | fake = args.task.split('_')[-1]
269 | eval_fake(args.data, model, fake, max_n=args.max_n, max_cxt_turn=args.max_cxt_turn)
270 |
271 | elif args.task == 'eval_human_feedback':
272 | eval_feedback(args.data, model, max_cxt_turn=args.max_cxt_turn,
273 | min_rank_gap=args.min_rank_gap, max_n=args.max_n, min_score_gap=args.min_score_gap)
274 |
275 | elif args.task == 'test':
276 | rank_hyps(args.data, model, max_n=args.max_n, max_cxt_turn=args.max_cxt_turn)
277 |
278 | elif args.task == 'play':
279 | play(model, max_cxt_turn=args.max_cxt_turn)
280 |
281 | elif args.task == 'stats':
282 | read_ranked_jsonl(args.data)
283 |
284 | else:
285 | raise ValueError
286 |
--------------------------------------------------------------------------------
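
A worked toy example of the generalized hits@k accuracy computed in `eval_fake` above (k = 2 real and m = 3 fake responses; the scores are made up):

    import numpy as np

    # ranker scores for [real_1, real_2, fake_1, fake_2, fake_3]
    scores = np.array([0.9, 0.4, 0.7, 0.2, 0.1])
    k = 2                                               # number of real responses
    ix_score = sorted([(scores[i], i) for i in range(len(scores))], reverse=True)
    acc = np.mean([i < k for _, i in ix_score[:k]])     # real responses sit at indices < k
    print(acc)                                          # 0.5: only real_1 made the top-2
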
/src/shared.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 | _cat_ = ' <-COL-> '
4 | #EOS_token = '_EOS_' # old version, before Nov 8 2020
5 | EOS_token = '<|endoftext|>'
6 |
7 |
8 | def download_model(path):
9 | if path is None:
10 | return
11 | import os, subprocess
12 | if os.path.exists(path):
13 | return
14 | links = dict()
15 | for k in ['updown', 'depth', 'width', 'human_vs_rand', 'human_vs_machine']:
16 | links['restore/%s.pth'%k] = 'https://xiagnlp2.blob.core.windows.net/dialogrpt/%s.pth'%k
17 | links['restore/medium_ft.pkl'] = 'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl'
18 | if path not in links:
19 | return
20 | cmd = [ 'wget', links[path], '-P', 'restore']
21 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
22 | process.communicate()
23 |
--------------------------------------------------------------------------------
/src/transformers19/__init__.py:
--------------------------------------------------------------------------------
1 | # copied from: https://github.com/huggingface/transformers/commit/4d456542e9d381090f9a00b2bcc5a4cb07f6f3f7
2 |
3 | from .tokenization_gpt2 import GPT2Tokenizer
4 | from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
5 | from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
6 | GPT2LMHeadModel, GPT2DoubleHeadsModel,
7 | #load_tf_weights_in_gpt2,
8 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
--------------------------------------------------------------------------------
/src/transformers19/configuration_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ OpenAI GPT-2 configuration """
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 |
25 | from .configuration_utils import PretrainedConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
30 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
31 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
32 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
33 |
34 | class GPT2Config(PretrainedConfig):
35 | """Configuration class to store the configuration of a `GPT2Model`.
36 |
37 | Args:
38 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
39 | n_positions: Number of positional embeddings.
40 | n_ctx: Size of the causal mask (usually same as n_positions).
41 | n_embd: Dimensionality of the embeddings and hidden states.
42 | n_layer: Number of hidden layers in the Transformer encoder.
43 | n_head: Number of attention heads for each attention layer in
44 | the Transformer encoder.
45 | layer_norm_epsilon: epsilon to use in the layer norm layers
46 |         resid_pdrop: The dropout probability for all fully connected
47 | layers in the embeddings, encoder, and pooler.
48 | attn_pdrop: The dropout ratio for the attention
49 | probabilities.
50 | embd_pdrop: The dropout ratio for the embeddings.
51 |         initializer_range: The stddev of the truncated_normal_initializer for
52 | initializing all weight matrices.
53 | """
54 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
55 |
56 | def __init__(
57 | self,
58 | vocab_size_or_config_json_file=50257,
59 | n_positions=1024,
60 | n_ctx=1024,
61 | n_embd=768,
62 | n_layer=12,
63 | n_head=12,
64 | resid_pdrop=0.1,
65 | embd_pdrop=0.1,
66 | attn_pdrop=0.1,
67 | layer_norm_epsilon=1e-5,
68 | initializer_range=0.02,
69 |
70 | num_labels=1,
71 | summary_type='cls_index',
72 | summary_use_proj=True,
73 | summary_activation=None,
74 | summary_proj_to_labels=True,
75 | summary_first_dropout=0.1,
76 | **kwargs
77 | ):
78 | """Constructs GPT2Config.
79 |
80 | Args:
81 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
82 | n_positions: Number of positional embeddings.
83 | n_ctx: Size of the causal mask (usually same as n_positions).
84 | n_embd: Dimensionality of the embeddings and hidden states.
85 | n_layer: Number of hidden layers in the Transformer encoder.
86 | n_head: Number of attention heads for each attention layer in
87 | the Transformer encoder.
88 | layer_norm_epsilon: epsilon to use in the layer norm layers
89 |             resid_pdrop: The dropout probability for all fully connected
90 | layers in the embeddings, encoder, and pooler.
91 | attn_pdrop: The dropout ratio for the attention
92 | probabilities.
93 | embd_pdrop: The dropout ratio for the embeddings.
94 |             initializer_range: The stddev of the truncated_normal_initializer for
95 | initializing all weight matrices.
96 | """
97 | super(GPT2Config, self).__init__(**kwargs)
98 |
99 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
100 | and isinstance(vocab_size_or_config_json_file, unicode)):
101 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
102 | json_config = json.loads(reader.read())
103 | for key, value in json_config.items():
104 | self.__dict__[key] = value
105 | elif isinstance(vocab_size_or_config_json_file, int):
106 | self.vocab_size = vocab_size_or_config_json_file
107 | self.n_ctx = n_ctx
108 | self.n_positions = n_positions
109 | self.n_embd = n_embd
110 | self.n_layer = n_layer
111 | self.n_head = n_head
112 | self.resid_pdrop = resid_pdrop
113 | self.embd_pdrop = embd_pdrop
114 | self.attn_pdrop = attn_pdrop
115 | self.layer_norm_epsilon = layer_norm_epsilon
116 | self.initializer_range = initializer_range
117 |
118 | self.num_labels = num_labels
119 | self.summary_type = summary_type
120 | self.summary_use_proj = summary_use_proj
121 | self.summary_activation = summary_activation
122 | self.summary_first_dropout = summary_first_dropout
123 | self.summary_proj_to_labels = summary_proj_to_labels
124 | else:
125 | raise ValueError(
126 | "First argument must be either a vocabulary size (int)"
127 | "or the path to a pretrained model config file (str)"
128 | )
129 |
130 | @property
131 | def max_position_embeddings(self):
132 | return self.n_positions
133 |
134 | @property
135 | def hidden_size(self):
136 | return self.n_embd
137 |
138 | @property
139 | def num_attention_heads(self):
140 | return self.n_head
141 |
142 | @property
143 | def num_hidden_layers(self):
144 | return self.n_layer
145 |
--------------------------------------------------------------------------------
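
For reference, the ranking and generation models in this repo instantiate this config with the GPT-2 medium geometry (see `model.py` and `generation.py` above); a quick sanity check of the convenience properties:

    from transformers19 import GPT2Config

    config = GPT2Config(n_embd=1024, n_layer=24, n_head=16)   # geometry used by Scorer / GPT2Generator
    print(config.hidden_size, config.num_hidden_layers, config.num_attention_heads)   # 1024 24 16
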
/src/transformers19/configuration_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Configuration base class and utilities."""
17 |
18 | from __future__ import (absolute_import, division, print_function,
19 | unicode_literals)
20 |
21 | import copy
22 | import json
23 | import logging
24 | import os
25 | from io import open
26 |
27 | from .file_utils import cached_path, CONFIG_NAME
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 | class PretrainedConfig(object):
32 | r""" Base class for all configuration classes.
33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34 |
35 | Note:
36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37 | It only affects the model's configuration.
38 |
39 | Class attributes (overridden by derived classes):
40 |             - ``pretrained_config_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
41 |
42 | Parameters:
43 | ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
44 | ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens)
45 |         ``output_attentions``: boolean, default `False`. Should the model return attention weights.
46 |         ``output_hidden_states``: boolean, default `False`. Should the model return all hidden-states.
47 |         ``torchscript``: boolean, default `False`. Whether the model is used with TorchScript.
48 | """
49 | pretrained_config_archive_map = {}
50 |
51 | def __init__(self, **kwargs):
52 | self.finetuning_task = kwargs.pop('finetuning_task', None)
53 | self.num_labels = kwargs.pop('num_labels', 2)
54 | self.output_attentions = kwargs.pop('output_attentions', False)
55 | self.output_hidden_states = kwargs.pop('output_hidden_states', False)
56 | self.output_past = kwargs.pop('output_past', True) # Not used by all models
57 | self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
58 | self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
59 | self.pruned_heads = kwargs.pop('pruned_heads', {})
60 |
61 | def save_pretrained(self, save_directory):
62 | """ Save a configuration object to the directory `save_directory`, so that it
63 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
64 | """
65 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
66 |
67 | # If we save using the predefined names, we can load using `from_pretrained`
68 | output_config_file = os.path.join(save_directory, CONFIG_NAME)
69 |
70 | self.to_json_file(output_config_file)
71 | logger.info("Configuration saved in {}".format(output_config_file))
72 |
73 | @classmethod
74 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
75 | r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
76 |
77 | Parameters:
78 | pretrained_model_name_or_path: either:
79 |
80 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
81 | - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
82 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
83 |
84 | cache_dir: (`optional`) string:
85 | Path to a directory in which a downloaded pre-trained model
86 | configuration should be cached if the standard cache should not be used.
87 |
88 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading.
89 |
90 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
91 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
92 |
93 | force_download: (`optional`) boolean, default False:
94 |                 Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
95 |
96 | proxies: (`optional`) dict, default None:
97 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
98 | The proxies are used on each request.
99 |
100 | return_unused_kwargs: (`optional`) bool:
101 |
102 | - If False, then this function returns just the final configuration object.
103 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
104 |
105 | Examples::
106 |
107 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
108 | # derived class: BertConfig
109 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
110 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
111 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
112 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
113 | assert config.output_attention == True
114 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
115 | foo=False, return_unused_kwargs=True)
116 | assert config.output_attention == True
117 | assert unused_kwargs == {'foo': False}
118 |
119 | """
120 | cache_dir = kwargs.pop('cache_dir', None)
121 | force_download = kwargs.pop('force_download', False)
122 | proxies = kwargs.pop('proxies', None)
123 | return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
124 |
125 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
126 | config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
127 | elif os.path.isdir(pretrained_model_name_or_path):
128 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
129 | else:
130 | config_file = pretrained_model_name_or_path
131 | # redirect to the cache, if necessary
132 | try:
133 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
134 | except EnvironmentError:
135 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
136 | msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
137 | config_file)
138 | else:
139 | msg = "Model name '{}' was not found in model name list ({}). " \
140 | "We assumed '{}' was a path or url to a configuration file named {} or " \
141 | "a directory containing such a file but couldn't find any such file at this path or url.".format(
142 | pretrained_model_name_or_path,
143 | ', '.join(cls.pretrained_config_archive_map.keys()),
144 | config_file, CONFIG_NAME)
145 | raise EnvironmentError(msg)
146 |
147 | if resolved_config_file == config_file:
148 | logger.info("loading configuration file {}".format(config_file))
149 | else:
150 | logger.info("loading configuration file {} from cache at {}".format(
151 | config_file, resolved_config_file))
152 |
153 | # Load config
154 | config = cls.from_json_file(resolved_config_file)
155 |
156 | if hasattr(config, 'pruned_heads'):
157 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
158 |
159 | # Update config with kwargs if needed
160 | to_remove = []
161 | for key, value in kwargs.items():
162 | if hasattr(config, key):
163 | setattr(config, key, value)
164 | to_remove.append(key)
165 | for key in to_remove:
166 | kwargs.pop(key, None)
167 |
168 | logger.info("Model config %s", str(config))
169 | if return_unused_kwargs:
170 | return config, kwargs
171 | else:
172 | return config
173 |
174 | @classmethod
175 | def from_dict(cls, json_object):
176 | """Constructs a `Config` from a Python dictionary of parameters."""
177 | config = cls(vocab_size_or_config_json_file=-1)
178 | for key, value in json_object.items():
179 | setattr(config, key, value)
180 | return config
181 |
182 | @classmethod
183 | def from_json_file(cls, json_file):
184 | """Constructs a `BertConfig` from a json file of parameters."""
185 | with open(json_file, "r", encoding='utf-8') as reader:
186 | text = reader.read()
187 | return cls.from_dict(json.loads(text))
188 |
189 | def __eq__(self, other):
190 | return self.__dict__ == other.__dict__
191 |
192 | def __repr__(self):
193 | return str(self.to_json_string())
194 |
195 | def to_dict(self):
196 | """Serializes this instance to a Python dictionary."""
197 | output = copy.deepcopy(self.__dict__)
198 | return output
199 |
200 | def to_json_string(self):
201 | """Serializes this instance to a JSON string."""
202 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
203 |
204 | def to_json_file(self, json_file_path):
205 | """ Save this instance to a json file."""
206 | with open(json_file_path, "w", encoding='utf-8') as writer:
207 | writer.write(self.to_json_string())
208 |
--------------------------------------------------------------------------------
/src/transformers19/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import sys
9 | import json
10 | import logging
11 | import os
12 | import six
13 | import shutil
14 | import tempfile
15 | import fnmatch
16 | from functools import wraps
17 | from hashlib import sha256
18 | from io import open
19 |
20 | import boto3
21 | from botocore.config import Config
22 | from botocore.exceptions import ClientError
23 | import requests
24 | from tqdm import tqdm
25 |
26 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
27 | _tf_available = False # pylint: disable=invalid-name
28 |
29 | try:
30 | import torch
31 | _torch_available = True # pylint: disable=invalid-name
32 | logger.info("PyTorch version {} available.".format(torch.__version__))
33 | except ImportError:
34 | _torch_available = False # pylint: disable=invalid-name
35 |
36 |
37 | try:
38 | from torch.hub import _get_torch_home
39 | torch_cache_home = _get_torch_home()
40 | except ImportError:
41 | torch_cache_home = os.path.expanduser(
42 | os.getenv('TORCH_HOME', os.path.join(
43 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
44 | default_cache_path = os.path.join(torch_cache_home, 'transformers')
45 |
46 | try:
47 | from urllib.parse import urlparse
48 | except ImportError:
49 | from urlparse import urlparse
50 |
51 | try:
52 | from pathlib import Path
53 | PYTORCH_PRETRAINED_BERT_CACHE = Path(
54 | os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
55 | except (AttributeError, ImportError):
56 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
57 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
58 | default_cache_path))
59 |
60 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
61 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
62 |
63 | WEIGHTS_NAME = "pytorch_model.bin"
64 | TF2_WEIGHTS_NAME = 'tf_model.h5'
65 | TF_WEIGHTS_NAME = 'model.ckpt'
66 | CONFIG_NAME = "config.json"
67 |
68 | def is_torch_available():
69 | return _torch_available
70 |
71 | def is_tf_available():
72 | return _tf_available
73 |
74 | if not six.PY2:
75 | def add_start_docstrings(*docstr):
76 | def docstring_decorator(fn):
77 | fn.__doc__ = ''.join(docstr) + fn.__doc__
78 | return fn
79 | return docstring_decorator
80 |
81 | def add_end_docstrings(*docstr):
82 | def docstring_decorator(fn):
83 | fn.__doc__ = fn.__doc__ + ''.join(docstr)
84 | return fn
85 | return docstring_decorator
86 | else:
87 | # Not possible to update class docstrings on python2
88 | def add_start_docstrings(*docstr):
89 | def docstring_decorator(fn):
90 | return fn
91 | return docstring_decorator
92 |
93 | def add_end_docstrings(*docstr):
94 | def docstring_decorator(fn):
95 | return fn
96 | return docstring_decorator
97 |
98 | def url_to_filename(url, etag=None):
99 | """
100 | Convert `url` into a hashed filename in a repeatable way.
101 | If `etag` is specified, append its hash to the url's, delimited
102 | by a period.
103 |     If the url ends with .h5 (Keras HDF5 weights), adds '.h5' to the name
104 | so that TF 2.0 can identify it as a HDF5 file
105 | (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
106 | """
107 | url_bytes = url.encode('utf-8')
108 | url_hash = sha256(url_bytes)
109 | filename = url_hash.hexdigest()
110 |
111 | if etag:
112 | etag_bytes = etag.encode('utf-8')
113 | etag_hash = sha256(etag_bytes)
114 | filename += '.' + etag_hash.hexdigest()
115 |
116 | if url.endswith('.h5'):
117 | filename += '.h5'
118 |
119 | return filename
120 |
121 |
122 | def filename_to_url(filename, cache_dir=None):
123 | """
124 | Return the url and etag (which may be ``None``) stored for `filename`.
125 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
126 | """
127 | if cache_dir is None:
128 | cache_dir = TRANSFORMERS_CACHE
129 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
130 | cache_dir = str(cache_dir)
131 |
132 | cache_path = os.path.join(cache_dir, filename)
133 | if not os.path.exists(cache_path):
134 | raise EnvironmentError("file {} not found".format(cache_path))
135 |
136 | meta_path = cache_path + '.json'
137 | if not os.path.exists(meta_path):
138 | raise EnvironmentError("file {} not found".format(meta_path))
139 |
140 | with open(meta_path, encoding="utf-8") as meta_file:
141 | metadata = json.load(meta_file)
142 | url = metadata['url']
143 | etag = metadata['etag']
144 |
145 | return url, etag
146 |
147 |
148 | def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
149 | """
150 | Given something that might be a URL (or might be a local path),
151 | determine which. If it's a URL, download the file and cache it, and
152 | return the path to the cached file. If it's already a local path,
153 | make sure the file exists and then return the path.
154 | Args:
155 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
156 |         force_download: if True, re-download the file even if it's already cached in the cache dir.
157 | """
158 | if cache_dir is None:
159 | cache_dir = TRANSFORMERS_CACHE
160 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
161 | url_or_filename = str(url_or_filename)
162 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
163 | cache_dir = str(cache_dir)
164 |
165 | parsed = urlparse(url_or_filename)
166 |
167 | if parsed.scheme in ('http', 'https', 's3'):
168 | # URL, so get it from the cache (downloading if necessary)
169 | return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
170 | elif os.path.exists(url_or_filename):
171 | # File, and it exists.
172 | return url_or_filename
173 | elif parsed.scheme == '':
174 | # File, but it doesn't exist.
175 | raise EnvironmentError("file {} not found".format(url_or_filename))
176 | else:
177 | # Something unknown
178 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
179 |
180 |
181 | def split_s3_path(url):
182 | """Split a full s3 path into the bucket name and path."""
183 | parsed = urlparse(url)
184 | if not parsed.netloc or not parsed.path:
185 | raise ValueError("bad s3 path {}".format(url))
186 | bucket_name = parsed.netloc
187 | s3_path = parsed.path
188 | # Remove '/' at beginning of path.
189 | if s3_path.startswith("/"):
190 | s3_path = s3_path[1:]
191 | return bucket_name, s3_path
192 |
193 |
194 | def s3_request(func):
195 | """
196 | Wrapper function for s3 requests in order to create more helpful error
197 | messages.
198 | """
199 |
200 | @wraps(func)
201 | def wrapper(url, *args, **kwargs):
202 | try:
203 | return func(url, *args, **kwargs)
204 | except ClientError as exc:
205 | if int(exc.response["Error"]["Code"]) == 404:
206 | raise EnvironmentError("file {} not found".format(url))
207 | else:
208 | raise
209 |
210 | return wrapper
211 |
212 |
213 | @s3_request
214 | def s3_etag(url, proxies=None):
215 | """Check ETag on S3 object."""
216 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
217 | bucket_name, s3_path = split_s3_path(url)
218 | s3_object = s3_resource.Object(bucket_name, s3_path)
219 | return s3_object.e_tag
220 |
221 |
222 | @s3_request
223 | def s3_get(url, temp_file, proxies=None):
224 | """Pull a file directly from S3."""
225 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
226 | bucket_name, s3_path = split_s3_path(url)
227 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
228 |
229 |
230 | def http_get(url, temp_file, proxies=None):
231 | req = requests.get(url, stream=True, proxies=proxies)
232 | content_length = req.headers.get('Content-Length')
233 | total = int(content_length) if content_length is not None else None
234 | progress = tqdm(unit="B", total=total)
235 | for chunk in req.iter_content(chunk_size=1024):
236 | if chunk: # filter out keep-alive new chunks
237 | progress.update(len(chunk))
238 | temp_file.write(chunk)
239 | progress.close()
240 |
241 |
242 | def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10):
243 | """
244 | Given a URL, look for the corresponding dataset in the local cache.
245 | If it's not there, download it. Then return the path to the cached file.
246 | """
247 | if cache_dir is None:
248 | cache_dir = TRANSFORMERS_CACHE
249 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
250 | cache_dir = str(cache_dir)
251 | if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
252 | cache_dir = str(cache_dir)
253 |
254 | if not os.path.exists(cache_dir):
255 | os.makedirs(cache_dir, exist_ok=True)
256 |
257 | # Get eTag to add to filename, if it exists.
258 | if url.startswith("s3://"):
259 | etag = s3_etag(url, proxies=proxies)
260 | else:
261 | try:
262 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
263 | if response.status_code != 200:
264 | etag = None
265 | else:
266 | etag = response.headers.get("ETag")
267 | except (EnvironmentError, requests.exceptions.Timeout):
268 | etag = None
269 |
270 | if sys.version_info[0] == 2 and etag is not None:
271 | etag = etag.decode('utf-8')
272 | filename = url_to_filename(url, etag)
273 |
274 | # get cache path to put the file
275 | cache_path = os.path.join(cache_dir, filename)
276 |
277 | # If we don't have a connection (etag is None) and can't identify the file
278 | # try to get the last downloaded one
279 | if not os.path.exists(cache_path) and etag is None:
280 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
281 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
282 | if matching_files:
283 | cache_path = os.path.join(cache_dir, matching_files[-1])
284 |
285 | if not os.path.exists(cache_path) or force_download:
286 | # Download to temporary file, then copy to cache dir once finished.
287 | # Otherwise you get corrupt cache entries if the download gets interrupted.
288 | with tempfile.NamedTemporaryFile() as temp_file:
289 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
290 |
291 | # GET file object
292 | if url.startswith("s3://"):
293 | s3_get(url, temp_file, proxies=proxies)
294 | else:
295 | http_get(url, temp_file, proxies=proxies)
296 |
297 | # we are copying the file before closing it, so flush to avoid truncation
298 | temp_file.flush()
299 | # shutil.copyfileobj() starts at the current position, so go to the start
300 | temp_file.seek(0)
301 |
302 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
303 | with open(cache_path, 'wb') as cache_file:
304 | shutil.copyfileobj(temp_file, cache_file)
305 |
306 | logger.info("creating metadata file for %s", cache_path)
307 | meta = {'url': url, 'etag': etag}
308 | meta_path = cache_path + '.json'
309 | with open(meta_path, 'w') as meta_file:
310 | output_string = json.dumps(meta)
311 | if sys.version_info[0] == 2 and isinstance(output_string, str):
312 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2
313 | meta_file.write(output_string)
314 |
315 | logger.info("removing temp file %s", temp_file.name)
316 |
317 | return cache_path
318 |
--------------------------------------------------------------------------------
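A minimal sketch of how the caching helpers in `file_utils.py` above fit together, assuming `src/` is on the import path, network access, and the default `TRANSFORMERS_CACHE` directory; `cached_path` is the entry point that `from_pretrained` in `modeling_utils.py` calls to resolve a weights URL to a local file:

    import os
    from transformers19.file_utils import cached_path, filename_to_url

    url = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"

    # First call downloads to a temp file, then copies it into the cache under
    # sha256(url) + '.' + sha256(etag); later calls reuse the cached copy.
    local_path = cached_path(url)

    # get_from_cache() also writes a '<filename>.json' sidecar with the url/etag,
    # which filename_to_url() reads back.
    original_url, etag = filename_to_url(os.path.basename(local_path))
    assert original_url == url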
/src/transformers19/modeling_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch OpenAI GPT-2 model."""
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import collections
21 | import json
22 | import logging
23 | import math
24 | import os
25 | import sys
26 | from io import open
27 |
28 | import torch
29 | import torch.nn as nn
30 | from torch.nn import CrossEntropyLoss
31 | from torch.nn.parameter import Parameter
32 |
33 | from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
34 | from .configuration_gpt2 import GPT2Config
35 | from .file_utils import add_start_docstrings
36 |
37 | logger = logging.getLogger(__name__)
38 |
39 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
40 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
41 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
42 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",}
43 |
44 |
45 | def gelu(x):
46 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
47 |
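# The gelu() above is the tanh approximation of the Gaussian Error Linear Unit:
# gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))), which
# approximates x * Phi(x), with Phi the standard normal CDF; e.g. gelu(1.0) is roughly 0.841.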
48 |
49 | class Attention(nn.Module):
50 | def __init__(self, nx, n_ctx, config, scale=False):
51 | super(Attention, self).__init__()
52 | self.output_attentions = config.output_attentions
53 |
54 | n_state = nx # in Attention: n_state=768 (nx=n_embd)
55 | # [switch nx => n_state from Block to Attention to keep identical to TF implem]
56 | assert n_state % config.n_head == 0
57 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
58 | self.n_head = config.n_head
59 | self.split_size = n_state
60 | self.scale = scale
61 |
62 | self.c_attn = Conv1D(n_state * 3, nx)
63 | self.c_proj = Conv1D(n_state, nx)
64 | self.attn_dropout = nn.Dropout(config.attn_pdrop)
65 | self.resid_dropout = nn.Dropout(config.resid_pdrop)
66 | self.pruned_heads = set()
67 |
68 | def prune_heads(self, heads):
69 | if len(heads) == 0:
70 | return
71 | mask = torch.ones(self.n_head, self.split_size // self.n_head)
72 |         heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
73 | for head in heads:
74 | # Compute how many pruned heads are before the head and move the index accordingly
75 | head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
76 | mask[head] = 0
77 | mask = mask.view(-1).contiguous().eq(1)
78 | index = torch.arange(len(mask))[mask].long()
79 | index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
80 |
81 | # Prune conv1d layers
82 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
83 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
84 |
85 | # Update hyper params
86 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
87 | self.n_head = self.n_head - len(heads)
88 | self.pruned_heads = self.pruned_heads.union(heads)
89 |
90 | def _attn(self, q, k, v, attention_mask=None, head_mask=None):
91 | w = torch.matmul(q, k)
92 | if self.scale:
93 | w = w / math.sqrt(v.size(-1))
94 | nd, ns = w.size(-2), w.size(-1)
95 | b = self.bias[:, :, ns-nd:ns, :ns]
96 | w = w * b - 1e4 * (1 - b)
97 |
98 | if attention_mask is not None:
99 | # Apply the attention mask
100 | w = w + attention_mask
101 |
102 | w = nn.Softmax(dim=-1)(w)
103 | w = self.attn_dropout(w)
104 |
105 | # Mask heads if we want to
106 | if head_mask is not None:
107 | w = w * head_mask
108 |
109 | outputs = [torch.matmul(w, v)]
110 | if self.output_attentions:
111 | outputs.append(w)
112 | return outputs
113 |
114 | def merge_heads(self, x):
115 | x = x.permute(0, 2, 1, 3).contiguous()
116 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
117 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
118 |
119 | def split_heads(self, x, k=False):
120 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
121 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
122 | if k:
123 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
124 | else:
125 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
126 |
127 | def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
128 | x = self.c_attn(x)
129 | query, key, value = x.split(self.split_size, dim=2)
130 | query = self.split_heads(query)
131 | key = self.split_heads(key, k=True)
132 | value = self.split_heads(value)
133 | if layer_past is not None:
134 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
135 | key = torch.cat((past_key, key), dim=-1)
136 | value = torch.cat((past_value, value), dim=-2)
137 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
138 |
139 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
140 | a = attn_outputs[0]
141 |
142 | a = self.merge_heads(a)
143 | a = self.c_proj(a)
144 | a = self.resid_dropout(a)
145 |
146 | outputs = [a, present] + attn_outputs[1:]
147 | return outputs # a, present, (attentions)
148 |
149 |
150 | class MLP(nn.Module):
151 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
152 | super(MLP, self).__init__()
153 | nx = config.n_embd
154 | self.c_fc = Conv1D(n_state, nx)
155 | self.c_proj = Conv1D(nx, n_state)
156 | self.act = gelu
157 | self.dropout = nn.Dropout(config.resid_pdrop)
158 |
159 | def forward(self, x):
160 | h = self.act(self.c_fc(x))
161 | h2 = self.c_proj(h)
162 | return self.dropout(h2)
163 |
164 |
165 | class Block(nn.Module):
166 | def __init__(self, n_ctx, config, scale=False):
167 | super(Block, self).__init__()
168 | nx = config.n_embd
169 | self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
170 | self.attn = Attention(nx, n_ctx, config, scale)
171 | self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
172 | self.mlp = MLP(4 * nx, config)
173 |
174 | def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
175 | output_attn = self.attn(self.ln_1(x),
176 | layer_past=layer_past,
177 | attention_mask=attention_mask,
178 | head_mask=head_mask)
179 | a = output_attn[0] # output_attn: a, present, (attentions)
180 |
181 | x = x + a
182 | m = self.mlp(self.ln_2(x))
183 | x = x + m
184 |
185 | outputs = [x] + output_attn[1:]
186 | return outputs # x, present, (attentions)
187 |
188 |
189 | class GPT2PreTrainedModel(PreTrainedModel):
190 | """ An abstract class to handle weights initialization and
191 |         a simple interface for downloading and loading pretrained models.
192 | """
193 | config_class = GPT2Config
194 | pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
195 | #load_tf_weights = load_tf_weights_in_gpt2
196 | base_model_prefix = "transformer"
197 |
198 | def __init__(self, *inputs, **kwargs):
199 | super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
200 |
201 | def _init_weights(self, module):
202 | """ Initialize the weights.
203 | """
204 | if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
205 | # Slightly different from the TF version which uses truncated_normal for initialization
206 | # cf https://github.com/pytorch/pytorch/pull/5617
207 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
208 | if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
209 | module.bias.data.zero_()
210 | elif isinstance(module, nn.LayerNorm):
211 | module.bias.data.zero_()
212 | module.weight.data.fill_(1.0)
213 |
214 |
215 | GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
216 | `Language Models are Unsupervised Multitask Learners`_
217 | by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
218 | It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
219 | corpus of ~40 GB of text data.
220 |
221 | This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
222 | refer to the PyTorch documentation for all matter related to general usage and behavior.
223 |
224 | .. _`Language Models are Unsupervised Multitask Learners`:
225 | https://openai.com/blog/better-language-models/
226 |
227 | .. _`torch.nn.Module`:
228 | https://pytorch.org/docs/stable/nn.html#module
229 |
230 | Parameters:
231 | config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
232 | Initializing with a config file does not load the weights associated with the model, only the configuration.
233 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
234 | """
235 |
236 | GPT2_INPUTS_DOCSTRING = r""" Inputs:
237 | **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
238 | Indices of input sequence tokens in the vocabulary.
239 | GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
240 | the right rather than the left.
241 | Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
242 | See :func:`transformers.PreTrainedTokenizer.encode` and
243 | :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
244 | **past**:
245 | list of ``torch.FloatTensor`` (one for each layer):
246 | that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
247 | (see `past` output below). Can be used to speed up sequential decoding.
248 | **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
249 | Mask to avoid performing attention on padding token indices.
250 | Mask values selected in ``[0, 1]``:
251 | ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
252 | **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
253 | A parallel sequence of tokens (can be used to indicate various portions of the inputs).
254 | The embeddings from these tokens will be summed with the respective token embeddings.
255 | Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
256 | **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
257 | Indices of positions of each input sequence tokens in the position embeddings.
258 | Selected in the range ``[0, config.max_position_embeddings - 1]``.
259 | **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
260 | Mask to nullify selected heads of the self-attention modules.
261 | Mask values selected in ``[0, 1]``:
262 | ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
263 | """
264 |
265 | @add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
266 | GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
267 | class GPT2Model(GPT2PreTrainedModel):
268 | r"""
269 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
270 | **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
271 | Sequence of hidden-states at the last layer of the model.
272 | **past**:
273 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
274 | that contains pre-computed hidden-states (key and values in the attention blocks).
275 | Can be used (see `past` input) to speed up sequential decoding.
276 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
277 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
278 | of shape ``(batch_size, sequence_length, hidden_size)``:
279 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
280 | **attentions**: (`optional`, returned when ``config.output_attentions=True``)
281 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
282 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
283 |
284 | Examples::
285 |
286 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
287 | model = GPT2Model.from_pretrained('gpt2')
288 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
289 | outputs = model(input_ids)
290 | last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
291 |
292 | """
293 | def __init__(self, config):
294 | super(GPT2Model, self).__init__(config)
295 | self.output_hidden_states = config.output_hidden_states
296 | self.output_attentions = config.output_attentions
297 | self.output_past = config.output_past
298 |
299 | self.wte = nn.Embedding(config.vocab_size, config.n_embd)
300 | self.wpe = nn.Embedding(config.n_positions, config.n_embd)
301 | self.drop = nn.Dropout(config.embd_pdrop)
302 | self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
303 | self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
304 |
305 | self.init_weights()
306 |
307 | def _resize_token_embeddings(self, new_num_tokens):
308 | self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
309 | return self.wte
310 |
311 | def _prune_heads(self, heads_to_prune):
312 | """ Prunes heads of the model.
313 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
314 | """
315 | for layer, heads in heads_to_prune.items():
316 | self.h[layer].attn.prune_heads(heads)
317 |
318 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
319 | input_shape = input_ids.size()
320 | input_ids = input_ids.view(-1, input_shape[-1])
321 | if token_type_ids is not None:
322 | token_type_ids = token_type_ids.view(-1, input_shape[-1])
323 | if position_ids is not None:
324 | position_ids = position_ids.view(-1, input_shape[-1])
325 |
326 | if past is None:
327 | past_length = 0
328 | past = [None] * len(self.h)
329 | else:
330 | past_length = past[0][0].size(-2)
331 | if position_ids is None:
332 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
333 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
334 |
335 | # Attention mask.
336 | if attention_mask is not None:
337 | attention_mask = attention_mask.view(-1, input_shape[-1])
338 | # We create a 3D attention mask from a 2D tensor mask.
339 | # Sizes are [batch_size, 1, 1, to_seq_length]
340 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
341 | # this attention mask is more simple than the triangular masking of causal attention
342 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
343 | attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
344 |
345 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
346 | # masked positions, this operation will create a tensor which is 0.0 for
347 | # positions we want to attend and -10000.0 for masked positions.
348 | # Since we are adding it to the raw scores before the softmax, this is
349 | # effectively the same as removing these entirely.
350 | attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
351 | attention_mask = (1.0 - attention_mask) * -10000.0
352 |
353 | # Prepare head mask if needed
354 | # 1.0 in head_mask indicate we keep the head
355 | # attention_probs has shape bsz x n_heads x N x N
356 | # head_mask has shape n_layer x batch x n_heads x N x N
357 | if head_mask is not None:
358 | if head_mask.dim() == 1:
359 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
360 | head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
361 | elif head_mask.dim() == 2:
362 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
363 |                 head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
364 | else:
365 | head_mask = [None] * self.config.n_layer
366 |
367 | inputs_embeds = self.wte(input_ids)
368 | position_embeds = self.wpe(position_ids)
369 | if token_type_ids is not None:
370 | token_type_embeds = self.wte(token_type_ids)
371 | else:
372 | token_type_embeds = 0
373 | hidden_states = inputs_embeds + position_embeds + token_type_embeds
374 | hidden_states = self.drop(hidden_states)
375 |
376 | output_shape = input_shape + (hidden_states.size(-1),)
377 |
378 | presents = ()
379 | all_attentions = []
380 | all_hidden_states = ()
381 | for i, (block, layer_past) in enumerate(zip(self.h, past)):
382 | if self.output_hidden_states:
383 | all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
384 |
385 | outputs = block(hidden_states,
386 | layer_past=layer_past,
387 | attention_mask=attention_mask,
388 | head_mask=head_mask[i])
389 |
390 | hidden_states, present = outputs[:2]
391 | if self.output_past:
392 | presents = presents + (present,)
393 |
394 | if self.output_attentions:
395 | all_attentions.append(outputs[2])
396 |
397 | hidden_states = self.ln_f(hidden_states)
398 |
399 | hidden_states = hidden_states.view(*output_shape)
400 | # Add last hidden state
401 | if self.output_hidden_states:
402 | all_hidden_states = all_hidden_states + (hidden_states,)
403 |
404 | outputs = (hidden_states,)
405 | if self.output_past:
406 | outputs = outputs + (presents,)
407 | if self.output_hidden_states:
408 | outputs = outputs + (all_hidden_states,)
409 | if self.output_attentions:
410 | # let the number of heads free (-1) so we can extract attention even after head pruning
411 | attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
412 | all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
413 | outputs = outputs + (all_attentions,)
414 | return outputs # last hidden state, (presents), (all hidden_states), (attentions)
415 |
416 |
417 | @add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
418 | (linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
419 | class GPT2LMHeadModel(GPT2PreTrainedModel):
420 | r"""
421 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
422 | Labels for language modeling.
423 | Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
424 | Indices are selected in ``[-1, 0, ..., config.vocab_size]``
425 | All labels set to ``-1`` are ignored (masked), the loss is only
426 | computed for labels in ``[0, ..., config.vocab_size]``
427 |
428 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
429 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
430 | Language modeling loss.
431 | **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
432 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
433 | **past**:
434 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
435 | that contains pre-computed hidden-states (key and values in the attention blocks).
436 | Can be used (see `past` input) to speed up sequential decoding.
437 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
438 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
439 | of shape ``(batch_size, sequence_length, hidden_size)``:
440 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
441 | **attentions**: (`optional`, returned when ``config.output_attentions=True``)
442 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
443 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
444 |
445 | Examples::
446 |
447 | import torch
448 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
449 |
450 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
451 | model = GPT2LMHeadModel.from_pretrained('gpt2')
452 |
453 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
454 | outputs = model(input_ids, labels=input_ids)
455 | loss, logits = outputs[:2]
456 |
457 | """
458 | def __init__(self, config):
459 | super(GPT2LMHeadModel, self).__init__(config)
460 | self.transformer = GPT2Model(config)
461 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
462 |
463 | self.init_weights()
464 | self.tie_weights()
465 |
466 | def tie_weights(self):
467 | """ Make sure we are sharing the input and output embeddings.
468 | Export to TorchScript can't handle parameter sharing so we are cloning them instead.
469 | """
470 | self._tie_or_clone_weights(self.lm_head,
471 | self.transformer.wte)
472 |
473 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
474 | labels=None):
475 | transformer_outputs = self.transformer(input_ids,
476 | past=past,
477 | attention_mask=attention_mask,
478 | token_type_ids=token_type_ids,
479 | position_ids=position_ids,
480 | head_mask=head_mask)
481 | hidden_states = transformer_outputs[0]
482 |
483 | lm_logits = self.lm_head(hidden_states)
484 |
485 | outputs = (lm_logits,) + transformer_outputs[1:]
486 | if labels is not None:
487 | # Shift so that tokens < n predict n
488 | shift_logits = lm_logits[..., :-1, :].contiguous()
489 | shift_labels = labels[..., 1:].contiguous()
490 | # Flatten the tokens
491 | loss_fct = CrossEntropyLoss(ignore_index=-1)
492 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
493 | shift_labels.view(-1))
494 | outputs = (loss,) + outputs
495 |
496 | return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
497 |
498 |
499 | @add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
500 | head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
501 | The language modeling head has its weights tied to the input embeddings,
502 | the classification head takes as input the input of a specified classification token index in the input sequence).
503 | """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
504 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
505 | r"""
506 |         **mc_token_ids**: (`optional`, defaults to the index of the last token of the input) ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
507 | Index of the classification token in each input sequence.
508 | Selected in the range ``[0, input_ids.size(-1) - 1[``.
509 | **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
510 | Labels for language modeling.
511 | Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
512 | Indices are selected in ``[-1, 0, ..., config.vocab_size]``
513 | All labels set to ``-1`` are ignored (masked), the loss is only
514 | computed for labels in ``[0, ..., config.vocab_size]``
515 | **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
516 | Labels for computing the multiple choice classification loss.
517 | Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
518 | of the input tensors. (see `input_ids` above)
519 |
520 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
521 | **lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
522 | Language modeling loss.
523 | **mc_loss**: (`optional`, returned when ``multiple_choice_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
524 | Multiple choice classification loss.
525 | **lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
526 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
527 | **mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
528 |             Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
529 | **past**:
530 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
531 | that contains pre-computed hidden-states (key and values in the attention blocks).
532 | Can be used (see `past` input) to speed up sequential decoding.
533 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
534 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
535 | of shape ``(batch_size, sequence_length, hidden_size)``:
536 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
537 | **attentions**: (`optional`, returned when ``config.output_attentions=True``)
538 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
539 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
540 |
541 | Examples::
542 |
543 | import torch
544 | from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
545 |
546 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
547 | model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
548 |
549 | # Add a [CLS] to the vocabulary (we should train it also!)
550 | tokenizer.add_special_tokens({'cls_token': '[CLS]'})
551 | model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
552 |         print(tokenizer.cls_token_id, len(tokenizer)) # The newly added token is the last token of the vocabulary
553 |
554 | choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
555 | encoded_choices = [tokenizer.encode(s) for s in choices]
556 | cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
557 |
558 | input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
559 | mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
560 |
561 | outputs = model(input_ids, mc_token_ids=mc_token_ids)
562 | lm_prediction_scores, mc_prediction_scores = outputs[:2]
563 |
564 | """
565 | def __init__(self, config):
566 | super(GPT2DoubleHeadsModel, self).__init__(config)
567 | self.transformer = GPT2Model(config)
568 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
569 | self.multiple_choice_head = SequenceSummary(config)
570 |
571 | self.init_weights()
572 | self.tie_weights()
573 |
574 | def tie_weights(self):
575 | """ Make sure we are sharing the input and output embeddings.
576 | Export to TorchScript can't handle parameter sharing so we are cloning them instead.
577 | """
578 | self._tie_or_clone_weights(self.lm_head,
579 | self.transformer.wte)
580 |
581 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
582 | mc_token_ids=None, lm_labels=None, mc_labels=None):
583 | transformer_outputs = self.transformer(input_ids,
584 | past=past,
585 | attention_mask=attention_mask,
586 | token_type_ids=token_type_ids,
587 | position_ids=position_ids,
588 | head_mask=head_mask)
589 |
590 | hidden_states = transformer_outputs[0]
591 |
592 | lm_logits = self.lm_head(hidden_states)
593 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
594 |
595 | outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
596 | if mc_labels is not None:
597 | loss_fct = CrossEntropyLoss()
598 | loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
599 | mc_labels.view(-1))
600 | outputs = (loss,) + outputs
601 | if lm_labels is not None:
602 | shift_logits = lm_logits[..., :-1, :].contiguous()
603 | shift_labels = lm_labels[..., 1:].contiguous()
604 | loss_fct = CrossEntropyLoss(ignore_index=-1)
605 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
606 | shift_labels.view(-1))
607 | outputs = (loss,) + outputs
608 |
609 | return outputs # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
610 |
--------------------------------------------------------------------------------
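A minimal usage sketch for the GPT-2 classes defined in `modeling_gpt2.py` above, assuming `src/` is on the import path and that `GPT2Tokenizer` and `GPT2LMHeadModel` are re-exported by `transformers19/__init__.py` (an assumption here); it simply scores a sentence with the language-modeling loss to show the forward API:

    import torch
    from transformers19 import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')  # weights resolved via cached_path() on first use
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("Hello, how are you? I am fine, thank you.")])
    with torch.no_grad():
        loss, logits = model(input_ids, labels=input_ids)[:2]  # labels are shifted inside the model
    print(loss.item())  # average per-token cross-entropy under GPT-2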
/src/transformers19/modeling_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch BERT model."""
17 |
18 | from __future__ import (absolute_import, division, print_function,
19 | unicode_literals)
20 |
21 | import copy
22 | import json
23 | import logging
24 | import os
25 | from io import open
26 |
27 | import six
28 | import torch
29 | from torch import nn
30 | from torch.nn import CrossEntropyLoss
31 | from torch.nn import functional as F
32 |
33 | from .configuration_utils import PretrainedConfig
34 | from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | try:
40 | from torch.nn import Identity
41 | except ImportError:
42 | # Older PyTorch compatibility
43 | class Identity(nn.Module):
44 | r"""A placeholder identity operator that is argument-insensitive.
45 | """
46 | def __init__(self, *args, **kwargs):
47 | super(Identity, self).__init__()
48 |
49 | def forward(self, input):
50 | return input
51 |
52 | class PreTrainedModel(nn.Module):
53 | r""" Base class for all models.
54 |
55 | :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
56 |         as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention modules.
57 |
58 | Class attributes (overridden by derived classes):
59 | - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
60 |             - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
61 | - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
62 |
63 | - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
64 | - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
65 | - ``path``: a path (string) to the TensorFlow checkpoint.
66 |
67 | - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
68 | """
69 | config_class = None
70 | pretrained_model_archive_map = {}
71 | load_tf_weights = lambda model, config, path: None
72 | base_model_prefix = ""
73 |
74 | def __init__(self, config, *inputs, **kwargs):
75 | super(PreTrainedModel, self).__init__()
76 | if not isinstance(config, PretrainedConfig):
77 | raise ValueError(
78 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
79 | "To create a model from a pretrained model use "
80 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
81 | self.__class__.__name__, self.__class__.__name__
82 | ))
83 | # Save config in model
84 | self.config = config
85 |
86 | def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
87 | """ Build a resized Embedding Module from a provided token Embedding Module.
88 | Increasing the size will add newly initialized vectors at the end
89 | Reducing the size will remove vectors from the end
90 |
91 | Args:
92 | new_num_tokens: (`optional`) int
93 | New number of tokens in the embedding matrix.
94 | Increasing the size will add newly initialized vectors at the end
95 | Reducing the size will remove vectors from the end
96 | If not provided or None: return the provided token Embedding Module.
97 | Return: ``torch.nn.Embeddings``
98 | Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
99 | """
100 | if new_num_tokens is None:
101 | return old_embeddings
102 |
103 | old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
104 | if old_num_tokens == new_num_tokens:
105 | return old_embeddings
106 |
107 | # Build new embeddings
108 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
109 | new_embeddings.to(old_embeddings.weight.device)
110 |
111 | # initialize all new embeddings (in particular added tokens)
112 | self._init_weights(new_embeddings)
113 |
114 | # Copy word embeddings from the previous weights
115 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
116 | new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
117 |
118 | return new_embeddings
119 |
120 | def _tie_or_clone_weights(self, first_module, second_module):
121 |         """ Tie or clone module weights depending on whether we are using TorchScript or not
122 | """
123 | if self.config.torchscript:
124 | first_module.weight = nn.Parameter(second_module.weight.clone())
125 | else:
126 | first_module.weight = second_module.weight
127 |
128 | if hasattr(first_module, 'bias') and first_module.bias is not None:
129 | first_module.bias.data = torch.nn.functional.pad(
130 | first_module.bias.data,
131 | (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
132 | 'constant',
133 | 0
134 | )
135 |
136 | def resize_token_embeddings(self, new_num_tokens=None):
137 | """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
138 | Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
139 |
140 | Arguments:
141 |
142 | new_num_tokens: (`optional`) int:
143 | New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
144 | If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
145 |
146 | Return: ``torch.nn.Embeddings``
147 | Pointer to the input tokens Embeddings Module of the model
148 | """
149 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
150 | model_embeds = base_model._resize_token_embeddings(new_num_tokens)
151 | if new_num_tokens is None:
152 | return model_embeds
153 |
154 | # Update base model and current model config
155 | self.config.vocab_size = new_num_tokens
156 | base_model.vocab_size = new_num_tokens
157 |
158 | # Tie weights again if needed
159 | if hasattr(self, 'tie_weights'):
160 | self.tie_weights()
161 |
162 | return model_embeds
163 |
164 | def init_weights(self):
165 | """ Initialize and prunes weights if needed. """
166 | # Initialize weights
167 | self.apply(self._init_weights)
168 |
169 | # Prune heads if needed
170 | if self.config.pruned_heads:
171 | self.prune_heads(self.config.pruned_heads)
172 |
173 | def prune_heads(self, heads_to_prune):
174 | """ Prunes heads of the base model.
175 |
176 | Arguments:
177 |
178 | heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
179 | E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
180 | """
181 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
182 |
183 | # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
184 | for layer, heads in heads_to_prune.items():
185 | union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
186 | self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
187 |
188 | base_model._prune_heads(heads_to_prune)
189 |
190 | def save_pretrained(self, save_directory):
191 | """ Save a model and its configuration file to a directory, so that it
192 |         can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
193 | """
194 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
195 |
196 |         # Only save the model itself if we are using distributed training
197 | model_to_save = self.module if hasattr(self, 'module') else self
198 |
199 | # Save configuration file
200 | model_to_save.config.save_pretrained(save_directory)
201 |
202 | # If we save using the predefined names, we can load using `from_pretrained`
203 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
204 | torch.save(model_to_save.state_dict(), output_model_file)
205 | logger.info("Model weights saved in {}".format(output_model_file))
206 |
207 | @classmethod
208 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
209 | r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
210 |
211 | The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
212 | To train the model, you should first set it back in training mode with ``model.train()``
213 |
214 | The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
215 | It is up to you to train those weights with a downstream fine-tuning task.
216 |
217 | The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
218 |
219 | Parameters:
220 | pretrained_model_name_or_path: either:
221 |
222 | - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
223 | - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
224 | - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
225 | - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
226 |
227 | model_args: (`optional`) Sequence of positional arguments:
228 |                 All remaining positional arguments will be passed to the underlying model's ``__init__`` method
229 |
230 | config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
231 |                 Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
232 |
233 | - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
234 |                     - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
235 |                     - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
236 |
237 | state_dict: (`optional`) dict:
238 |                 an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
239 | This option can be used if you want to create a model from a pretrained configuration but load your own weights.
240 | In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
241 |
242 | cache_dir: (`optional`) string:
243 | Path to a directory in which a downloaded pre-trained model
244 | configuration should be cached if the standard cache should not be used.
245 |
246 | force_download: (`optional`) boolean, default False:
247 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
248 |
249 | proxies: (`optional`) dict, default None:
250 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
251 | The proxies are used on each request.
252 |
253 | output_loading_info: (`optional`) boolean:
254 |                 Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
255 |
256 | kwargs: (`optional`) Remaining dictionary of keyword arguments:
257 |                 Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:
258 |
259 | - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
260 | - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
261 |
262 | Examples::
263 |
264 | model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
265 | model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
266 | model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
267 | assert model.config.output_attention == True
268 | # Loading from a TF checkpoint file instead of a PyTorch model (slower)
269 | config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
270 | model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
271 |
272 | """
273 | config = kwargs.pop('config', None)
274 | state_dict = kwargs.pop('state_dict', None)
275 | cache_dir = kwargs.pop('cache_dir', None)
276 | from_tf = kwargs.pop('from_tf', False)
277 | force_download = kwargs.pop('force_download', False)
278 | proxies = kwargs.pop('proxies', None)
279 | output_loading_info = kwargs.pop('output_loading_info', False)
280 |
281 | # Load config
282 | if config is None:
283 | config, model_kwargs = cls.config_class.from_pretrained(
284 | pretrained_model_name_or_path, *model_args,
285 | cache_dir=cache_dir, return_unused_kwargs=True,
286 | force_download=force_download,
287 | **kwargs
288 | )
289 | else:
290 | model_kwargs = kwargs
291 |
292 | # Load model
293 | if pretrained_model_name_or_path is not None:
294 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
295 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
296 | elif os.path.isdir(pretrained_model_name_or_path):
297 | if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
298 | # Load from a TF 1.0 checkpoint
299 | archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
300 | elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
301 | # Load from a TF 2.0 checkpoint
302 | archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
303 | elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
304 | # Load from a PyTorch checkpoint
305 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
306 | else:
307 | raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
308 | [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
309 | pretrained_model_name_or_path))
310 | elif os.path.isfile(pretrained_model_name_or_path):
311 | archive_file = pretrained_model_name_or_path
312 | else:
313 | assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
314 | archive_file = pretrained_model_name_or_path + ".index"
315 |
316 | # redirect to the cache, if necessary
317 | try:
318 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
319 | except EnvironmentError:
320 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
321 | msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
322 | archive_file)
323 | else:
324 | msg = "Model name '{}' was not found in model name list ({}). " \
325 | "We assumed '{}' was a path or url to model weight files named one of {} but " \
326 | "couldn't find any such file at this path or url.".format(
327 | pretrained_model_name_or_path,
328 | ', '.join(cls.pretrained_model_archive_map.keys()),
329 | archive_file,
330 | [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
331 | raise EnvironmentError(msg)
332 |
333 | if resolved_archive_file == archive_file:
334 | logger.info("loading weights file {}".format(archive_file))
335 | else:
336 | logger.info("loading weights file {} from cache at {}".format(
337 | archive_file, resolved_archive_file))
338 | else:
339 | resolved_archive_file = None
340 |
341 | # Instantiate model.
342 | model = cls(config, *model_args, **model_kwargs)
343 |
344 | if state_dict is None and not from_tf:
345 | state_dict = torch.load(resolved_archive_file, map_location='cpu')
346 |
347 | missing_keys = []
348 | unexpected_keys = []
349 | error_msgs = []
350 |
351 | if from_tf:
352 | if resolved_archive_file.endswith('.index'):
353 | # Load from a TensorFlow 1.X checkpoint - provided by original authors
354 | model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
355 | else:
356 | # Load from our TensorFlow 2.0 checkpoints
357 | try:
358 | from transformers import load_tf2_checkpoint_in_pytorch_model
359 | model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
360 | except ImportError as e:
361 |                     logger.error("Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
362 | "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
363 | raise e
364 | else:
365 | # Convert old format to new format if needed from a PyTorch state_dict
366 | old_keys = []
367 | new_keys = []
368 | for key in state_dict.keys():
369 | new_key = None
370 | if 'gamma' in key:
371 | new_key = key.replace('gamma', 'weight')
372 | if 'beta' in key:
373 | new_key = key.replace('beta', 'bias')
374 | if new_key:
375 | old_keys.append(key)
376 | new_keys.append(new_key)
377 | for old_key, new_key in zip(old_keys, new_keys):
378 | state_dict[new_key] = state_dict.pop(old_key)
379 |
380 | # copy state_dict so _load_from_state_dict can modify it
381 | metadata = getattr(state_dict, '_metadata', None)
382 | state_dict = state_dict.copy()
383 | if metadata is not None:
384 | state_dict._metadata = metadata
385 |
386 | def load(module, prefix=''):
387 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
388 | module._load_from_state_dict(
389 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
390 | for name, child in module._modules.items():
391 | if child is not None:
392 | load(child, prefix + name + '.')
393 |
394 | # Make sure we are able to load base models as well as derived models (with heads)
395 | start_prefix = ''
396 | model_to_load = model
397 | if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
398 | start_prefix = cls.base_model_prefix + '.'
399 | if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
400 | model_to_load = getattr(model, cls.base_model_prefix)
401 |
402 | load(model_to_load, prefix=start_prefix)
403 | if len(missing_keys) > 0:
404 | logger.info("Weights of {} not initialized from pretrained model: {}".format(
405 | model.__class__.__name__, missing_keys))
406 | if len(unexpected_keys) > 0:
407 | logger.info("Weights from pretrained model not used in {}: {}".format(
408 | model.__class__.__name__, unexpected_keys))
409 | if len(error_msgs) > 0:
410 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
411 | model.__class__.__name__, "\n\t".join(error_msgs)))
412 |
413 | if hasattr(model, 'tie_weights'):
414 | model.tie_weights() # make sure word embedding weights are still tied
415 |
416 |         # Set model in evaluation mode to deactivate DropOut modules by default
417 | model.eval()
418 |
419 | if output_loading_info:
420 | loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
421 | return model, loading_info
422 |
423 | return model
424 |
425 |
426 | class Conv1D(nn.Module):
427 | def __init__(self, nf, nx):
428 | """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
429 | Basically works like a Linear layer but the weights are transposed
430 | """
431 | super(Conv1D, self).__init__()
432 | self.nf = nf
433 | w = torch.empty(nx, nf)
434 | nn.init.normal_(w, std=0.02)
435 | self.weight = nn.Parameter(w)
436 | self.bias = nn.Parameter(torch.zeros(nf))
437 |
438 | def forward(self, x):
439 | size_out = x.size()[:-1] + (self.nf,)
440 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
441 | x = x.view(*size_out)
442 | return x
443 |
444 |
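# Note on the Conv1D layer above (illustrative): it behaves like a transposed nn.Linear.
# Its weight has shape (nx, nf) and forward() computes x @ weight + bias, so a Conv1D(nf, nx)
# whose weight equals linear.weight.t() gives the same output as nn.Linear(nx, nf), e.g.:
#     lin = nn.Linear(3, 5)
#     conv = Conv1D(5, 3)
#     conv.weight.data.copy_(lin.weight.data.t()); conv.bias.data.copy_(lin.bias.data)
#     x = torch.randn(2, 3)
#     assert torch.allclose(conv(x), lin(x), atol=1e-6)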
445 | class PoolerStartLogits(nn.Module):
446 | """ Compute SQuAD start_logits from sequence hidden states. """
447 | def __init__(self, config):
448 | super(PoolerStartLogits, self).__init__()
449 | self.dense = nn.Linear(config.hidden_size, 1)
450 |
451 | def forward(self, hidden_states, p_mask=None):
452 | """ Args:
453 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
454 | Mask of invalid positions such as query and special symbols (PAD, SEP, CLS);
455 | 1.0 means the token should be masked.
456 | """
457 | x = self.dense(hidden_states).squeeze(-1)
458 |
459 | if p_mask is not None:
460 | if next(self.parameters()).dtype == torch.float16:
461 | x = x * (1 - p_mask) - 65500 * p_mask
462 | else:
463 | x = x * (1 - p_mask) - 1e30 * p_mask
464 |
465 | return x
466 |
467 |
468 | class PoolerEndLogits(nn.Module):
469 | """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
470 | """
471 | def __init__(self, config):
472 | super(PoolerEndLogits, self).__init__()
473 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
474 | self.activation = nn.Tanh()
475 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476 | self.dense_1 = nn.Linear(config.hidden_size, 1)
477 |
478 | def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
479 | """ Args:
480 | One of ``start_states``, ``start_positions`` should be not None.
481 | If both are set, ``start_positions`` overrides ``start_states``.
482 |
483 | **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``.
484 | hidden states of the first tokens for the labeled span.
485 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
486 | position of the first token for the labeled span.
487 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
488 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
489 | 1.0 means token should be masked.
490 | """
491 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
492 | if start_positions is not None:
493 | slen, hsz = hidden_states.shape[-2:]
494 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
495 | start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
496 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
497 |
498 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
499 | x = self.activation(x)
500 | x = self.LayerNorm(x)
501 | x = self.dense_1(x).squeeze(-1)
502 |
503 | if p_mask is not None:
504 | if next(self.parameters()).dtype == torch.float16:
505 | x = x * (1 - p_mask) - 65500 * p_mask
506 | else:
507 | x = x * (1 - p_mask) - 1e30 * p_mask
508 |
509 | return x
510 |
511 |
512 | class PoolerAnswerClass(nn.Module):
513 | """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
514 | def __init__(self, config):
515 | super(PoolerAnswerClass, self).__init__()
516 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
517 | self.activation = nn.Tanh()
518 | self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
519 |
520 | def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
521 | """
522 | Args:
523 | One of ``start_states``, ``start_positions`` should be not None.
524 | If both are set, ``start_positions`` overrides ``start_states``.
525 |
526 | **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``.
527 | hidden states of the first tokens for the labeled span.
528 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
529 | position of the first token for the labeled span.
530 | **cls_index**: torch.LongTensor of shape ``(batch_size,)``
531 | position of the CLS token. If None, take the last token.
532 |
533 | note(Original repo):
534 | no dependency on end_feature so that we can obtain one single `cls_logits`
535 | for each sample
536 | """
537 | hsz = hidden_states.shape[-1]
538 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
539 | if start_positions is not None:
540 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
541 | start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
542 |
543 | if cls_index is not None:
544 | cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
545 | cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
546 | else:
547 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
548 |
549 | x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
550 | x = self.activation(x)
551 | x = self.dense_1(x).squeeze(-1)
552 |
553 | return x
554 |
555 |
556 | class SQuADHead(nn.Module):
557 | r""" A SQuAD head inspired by XLNet.
558 |
559 | Parameters:
560 | config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
561 |
562 | Inputs:
563 | **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
564 | hidden states of sequence tokens
565 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
566 | position of the first token for the labeled span.
567 | **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
568 | position of the last token for the labeled span.
569 | **cls_index**: torch.LongTensor of shape ``(batch_size,)``
570 | position of the CLS token. If None, take the last token.
571 | **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
572 | Whether the question has a possible answer in the paragraph or not.
573 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
574 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
575 | 1.0 means token should be masked.
576 |
577 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
578 | **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
579 | Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
580 | **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
581 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
582 | Log probabilities for the top config.start_n_top start token possibilities (beam-search).
583 | **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
584 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
585 | Indices for the top config.start_n_top start token possibilities (beam-search).
586 | **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
587 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
588 | Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
589 | **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
590 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
591 | Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
592 | **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
593 | ``torch.FloatTensor`` of shape ``(batch_size,)``
594 | Log probabilities for the ``is_impossible`` label of the answers.
595 | """
596 | def __init__(self, config):
597 | super(SQuADHead, self).__init__()
598 | self.start_n_top = config.start_n_top
599 | self.end_n_top = config.end_n_top
600 |
601 | self.start_logits = PoolerStartLogits(config)
602 | self.end_logits = PoolerEndLogits(config)
603 | self.answer_class = PoolerAnswerClass(config)
604 |
605 | def forward(self, hidden_states, start_positions=None, end_positions=None,
606 | cls_index=None, is_impossible=None, p_mask=None):
607 | outputs = ()
608 |
609 | start_logits = self.start_logits(hidden_states, p_mask=p_mask)
610 |
611 | if start_positions is not None and end_positions is not None:
612 | # If we are on multi-GPU, let's remove the dimension added by batch splitting
613 | for x in (start_positions, end_positions, cls_index, is_impossible):
614 | if x is not None and x.dim() > 1:
615 | x.squeeze_(-1)
616 |
617 | # during training, compute the end logits based on the ground truth of the start position
618 | end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
619 |
620 | loss_fct = CrossEntropyLoss()
621 | start_loss = loss_fct(start_logits, start_positions)
622 | end_loss = loss_fct(end_logits, end_positions)
623 | total_loss = (start_loss + end_loss) / 2
624 |
625 | if cls_index is not None and is_impossible is not None:
626 | # Predict answerability from the representation of CLS and START
627 | cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
628 | loss_fct_cls = nn.BCEWithLogitsLoss()
629 | cls_loss = loss_fct_cls(cls_logits, is_impossible)
630 |
631 | # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
632 | total_loss += cls_loss * 0.5
633 |
634 | outputs = (total_loss,) + outputs
635 |
636 | else:
637 | # during inference, compute the end logits based on beam search
638 | bsz, slen, hsz = hidden_states.size()
639 | start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
640 |
641 | start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
642 | start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
643 | start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
644 | start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
645 |
646 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
647 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
648 | end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
649 | end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
650 |
651 | end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
652 | end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
653 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
654 |
655 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
656 | cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
657 |
658 | outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
659 |
660 | # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
661 | # or (if labels are provided) (total_loss,)
662 | return outputs
663 |
664 |
665 | class SequenceSummary(nn.Module):
666 | r""" Compute a single vector summary of a sequence's hidden states according to various possibilities:
667 | Args of the config class:
668 | summary_type:
669 | - 'last' => [default] take the last token hidden state (like XLNet)
670 | - 'first' => take the first token hidden state (like Bert)
671 | - 'mean' => take the mean of all tokens hidden states
672 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
673 | - 'attn' => Not implemented now, use multi-head attention
674 | summary_use_proj: Add a projection after the vector extraction
675 | summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
676 | summary_activation: 'tanh' => add a tanh activation to the output; anything else => no activation (default).
677 | summary_first_dropout: Add a dropout before the projection and activation
678 | summary_last_dropout: Add a dropout after the projection and activation
679 | """
680 | def __init__(self, config):
681 | super(SequenceSummary, self).__init__()
682 |
683 | self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
684 | if self.summary_type == 'attn':
685 | # We should use a standard multi-head attention module with absolute positional embedding for that.
686 | # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
687 | # We can probably just use the multi-head attention module of PyTorch >=1.1.0
688 | raise NotImplementedError
689 |
690 | self.summary = Identity()
691 | if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
692 | if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
693 | num_classes = config.num_labels
694 | else:
695 | num_classes = config.hidden_size
696 | self.summary = nn.Linear(config.hidden_size, num_classes)
697 |
698 | self.activation = Identity()
699 | if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
700 | self.activation = nn.Tanh()
701 |
702 | self.first_dropout = Identity()
703 | if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
704 | self.first_dropout = nn.Dropout(config.summary_first_dropout)
705 |
706 | self.last_dropout = Identity()
707 | if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
708 | self.last_dropout = nn.Dropout(config.summary_last_dropout)
709 |
710 | def forward(self, hidden_states, cls_index=None):
711 | """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
712 | cls_index: [optional] position of the classification token if summary_type == 'cls_index',
713 | shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
714 | if summary_type == 'cls_index' and cls_index is None:
715 | we take the last token of the sequence as classification token
716 | """
717 | if self.summary_type == 'last':
718 | output = hidden_states[:, -1]
719 | elif self.summary_type == 'first':
720 | output = hidden_states[:, 0]
721 | elif self.summary_type == 'mean':
722 | output = hidden_states.mean(dim=1)
723 | elif self.summary_type == 'cls_index':
724 | if cls_index is None:
725 | cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
726 | else:
727 | cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
728 | cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
729 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
730 | output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
731 | elif self.summary_type == 'attn':
732 | raise NotImplementedError
733 |
734 | output = self.first_dropout(output)
735 | output = self.summary(output)
736 | output = self.activation(output)
737 | output = self.last_dropout(output)
738 |
739 | return output
740 |
741 |
742 | def prune_linear_layer(layer, index, dim=0):
743 | """ Prune a linear layer (a model parameter) to keep only entries in index.
744 | Return the pruned layer as a new layer with requires_grad=True.
745 | Used to remove heads.
746 | """
747 | index = index.to(layer.weight.device)
748 | W = layer.weight.index_select(dim, index).clone().detach()
749 | if layer.bias is not None:
750 | if dim == 1:
751 | b = layer.bias.clone().detach()
752 | else:
753 | b = layer.bias[index].clone().detach()
754 | new_size = list(layer.weight.size())
755 | new_size[dim] = len(index)
756 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
757 | new_layer.weight.requires_grad = False
758 | new_layer.weight.copy_(W.contiguous())
759 | new_layer.weight.requires_grad = True
760 | if layer.bias is not None:
761 | new_layer.bias.requires_grad = False
762 | new_layer.bias.copy_(b.contiguous())
763 | new_layer.bias.requires_grad = True
764 | return new_layer
765 |
766 |
767 | def prune_conv1d_layer(layer, index, dim=1):
768 | """ Prune a Conv1D layer (a model parameter) to keep only entries in index.
769 | A Conv1D works like a Linear layer (as used e.g. in BERT) but the weights are transposed.
770 | Return the pruned layer as a new layer with requires_grad=True.
771 | Used to remove heads.
772 | """
773 | index = index.to(layer.weight.device)
774 | W = layer.weight.index_select(dim, index).clone().detach()
775 | if dim == 0:
776 | b = layer.bias.clone().detach()
777 | else:
778 | b = layer.bias[index].clone().detach()
779 | new_size = list(layer.weight.size())
780 | new_size[dim] = len(index)
781 | new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
782 | new_layer.weight.requires_grad = False
783 | new_layer.weight.copy_(W.contiguous())
784 | new_layer.weight.requires_grad = True
785 | new_layer.bias.requires_grad = False
786 | new_layer.bias.copy_(b.contiguous())
787 | new_layer.bias.requires_grad = True
788 | return new_layer
789 |
790 |
791 | def prune_layer(layer, index, dim=None):
792 | """ Prune a Conv1D or nn.Linear layer (a model parameter) to keep only entries in index.
793 | Return the pruned layer as a new layer with requires_grad=True.
794 | Used to remove heads.
795 | """
796 | if isinstance(layer, nn.Linear):
797 | return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
798 | elif isinstance(layer, Conv1D):
799 | return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
800 | else:
801 | raise ValueError("Can't prune layer of class {}".format(layer.__class__))
802 |
--------------------------------------------------------------------------------
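Note: the `Conv1D` and head-pruning helpers in `modeling_utils.py` above behave like a transposed `nn.Linear`; pruning along `dim=1` keeps a subset of output features while reproducing exactly the corresponding columns of the full layer's output. A minimal sketch of that behavior, assuming PyTorch is installed and the vendored package is importable as `transformers19` (e.g. by running from this repo's `src/` directory):

```python
import torch
from transformers19.modeling_utils import Conv1D, prune_conv1d_layer

layer = Conv1D(nf=4, nx=3)           # maps 3 input features -> 4 output features
x = torch.randn(2, 3)
full = layer(x)                      # shape (2, 4)

keep = torch.tensor([0, 2])          # indices of the output features to keep
pruned = prune_conv1d_layer(layer, keep, dim=1)
small = pruned(x)                    # shape (2, 2)

# the pruned layer reproduces the kept columns of the full output
assert torch.allclose(full[:, keep], small)
```
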
/src/transformers19/tokenization_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import sys
20 | import json
21 | import logging
22 | import os
23 | import regex as re
24 | from io import open
25 |
26 | try:
27 | from functools import lru_cache
28 | except ImportError:
29 | # Just a dummy decorator to get the checks to run on python2
30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
31 | def lru_cache():
32 | return lambda func: func
33 |
34 | from .tokenization_utils import PreTrainedTokenizer
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 | VOCAB_FILES_NAMES = {
39 | 'vocab_file': 'vocab.json',
40 | 'merges_file': 'merges.txt',
41 | }
42 |
43 | PRETRAINED_VOCAB_FILES_MAP = {
44 | 'vocab_file':
45 | {
46 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
47 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
48 | 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
49 | 'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
50 | },
51 | 'merges_file':
52 | {
53 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
54 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
55 | 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
56 | 'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
57 | },
58 | }
59 |
60 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
61 | 'gpt2': 1024,
62 | 'gpt2-medium': 1024,
63 | 'gpt2-large': 1024,
64 | 'distilgpt2': 1024,
65 | }
66 |
67 | @lru_cache()
68 | def bytes_to_unicode():
69 | """
70 | Returns a mapping from utf-8 bytes to unicode strings.
71 | We specifically avoid mapping to whitespace/control characters that the bpe code barfs on.
72 |
73 | The reversible bpe codes work on unicode strings.
74 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
75 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
76 | This is a significant percentage of your normal, say, 32K bpe vocab.
77 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
78 | """
79 | _chr = unichr if sys.version_info[0] == 2 else chr
80 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
81 | cs = bs[:]
82 | n = 0
83 | for b in range(2**8):
84 | if b not in bs:
85 | bs.append(b)
86 | cs.append(2**8+n)
87 | n += 1
88 | cs = [_chr(n) for n in cs]
89 | return dict(zip(bs, cs))
90 |
91 | def get_pairs(word):
92 | """Return set of symbol pairs in a word.
93 |
94 | Word is represented as tuple of symbols (symbols being variable-length strings).
95 | """
96 | pairs = set()
97 | prev_char = word[0]
98 | for char in word[1:]:
99 | pairs.add((prev_char, char))
100 | prev_char = char
101 | return pairs
102 |
103 | class GPT2Tokenizer(PreTrainedTokenizer):
104 | """
105 | GPT-2 BPE tokenizer. Peculiarities:
106 | - Byte-level Byte-Pair-Encoding
107 | - Requires a space to start the input string => the encoding methods should be called with the
108 | ``add_prefix_space`` flag set to ``True``.
109 | Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not preserve
110 | the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
111 | """
112 | vocab_files_names = VOCAB_FILES_NAMES
113 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
114 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
115 |
116 | def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>",
117 | bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
118 | super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
119 | self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
120 | self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
121 |
122 | self.encoder = json.load(open(vocab_file, encoding="utf-8"))
123 | self.decoder = {v: k for k, v in self.encoder.items()}
124 | self.errors = errors # how to handle errors in decoding
125 | self.byte_encoder = bytes_to_unicode()
126 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
127 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
128 | bpe_merges = [tuple(merge.split()) for merge in bpe_data]
129 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
130 | self.cache = {}
131 |
132 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
133 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
134 |
135 | @property
136 | def vocab_size(self):
137 | return len(self.encoder)
138 |
139 | def bpe(self, token):
140 | if token in self.cache:
141 | return self.cache[token]
142 | word = tuple(token)
143 | pairs = get_pairs(word)
144 |
145 | if not pairs:
146 | return token
147 |
148 | while True:
149 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
150 | if bigram not in self.bpe_ranks:
151 | break
152 | first, second = bigram
153 | new_word = []
154 | i = 0
155 | while i < len(word):
156 | try:
157 | j = word.index(first, i)
158 | new_word.extend(word[i:j])
159 | i = j
160 | except ValueError:
161 | new_word.extend(word[i:])
162 | break
163 |
164 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
165 | new_word.append(first+second)
166 | i += 2
167 | else:
168 | new_word.append(word[i])
169 | i += 1
170 | new_word = tuple(new_word)
171 | word = new_word
172 | if len(word) == 1:
173 | break
174 | else:
175 | pairs = get_pairs(word)
176 | word = ' '.join(word)
177 | self.cache[token] = word
178 | return word
179 |
180 | def _tokenize(self, text, add_prefix_space=False):
181 | """ Tokenize a string.
182 | Args:
183 | - add_prefix_space (boolean, default False):
184 | Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
185 | """
186 | if add_prefix_space:
187 | text = ' ' + text
188 |
189 | bpe_tokens = []
190 | for token in re.findall(self.pat, text):
191 | if sys.version_info[0] == 2:
192 | token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
193 | else:
194 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
195 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
196 | return bpe_tokens
197 |
198 | def _convert_token_to_id(self, token):
199 | """ Converts a token (str/unicode) to an id using the vocab. """
200 | return self.encoder.get(token, self.encoder.get(self.unk_token))
201 |
202 | def _convert_id_to_token(self, index):
203 | """Converts an index (integer) to a token (string/unicode) using the vocab."""
204 | return self.decoder.get(index)
205 |
206 | def convert_tokens_to_string(self, tokens):
207 | """ Converts a sequence of tokens (string) into a single string. """
208 | text = ''.join(tokens)
209 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
210 | return text
211 |
212 | def save_vocabulary(self, save_directory):
213 | """Save the tokenizer vocabulary and merge files to a directory."""
214 | if not os.path.isdir(save_directory):
215 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
216 | return
217 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
218 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
219 |
220 | with open(vocab_file, 'w', encoding='utf-8') as f:
221 | f.write(json.dumps(self.encoder, ensure_ascii=False))
222 |
223 | index = 0
224 | with open(merge_file, "w", encoding="utf-8") as writer:
225 | writer.write(u'#version: 0.2\n')
226 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
227 | if index != token_index:
228 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
229 | " Please check that the tokenizer is not corrupted!".format(merge_file))
230 | index = token_index
231 | writer.write(' '.join(bpe_tokens) + u'\n')
232 | index += 1
233 |
234 | return vocab_file, merge_file
--------------------------------------------------------------------------------
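For reference, a minimal round-trip with the byte-level BPE tokenizer above. This is a sketch, not part of the repo: the `vocab.json` / `merges.txt` paths are hypothetical local files (they can be fetched from the URLs listed in `PRETRAINED_VOCAB_FILES_MAP`), the import path assumes you run from this repo's `src/` directory, and it assumes `tokenize` forwards `add_prefix_space` to `_tokenize` as in the upstream library.

```python
from transformers19.tokenization_gpt2 import GPT2Tokenizer

# hypothetical local copies of the GPT-2 vocabulary and BPE merges
tok = GPT2Tokenizer(vocab_file='vocab.json', merges_file='merges.txt')

tokens = tok.tokenize("Hello world", add_prefix_space=True)   # byte-level BPE pieces
ids = tok.convert_tokens_to_ids(tokens)
print(tok.decode(ids))   # expected: " Hello world" (leading space, per the class docstring)
```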