├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── data.sh
├── doc
│   ├── DialogRPT-EMNLP.pdf
│   ├── demo.PNG
│   ├── icon.png
│   ├── toy.tsv
│   ├── toy.tsv.ensemble.jsonl
│   └── toy.tsv.updown.jsonl
├── requirements.txt
├── restore
│   └── ensemble.yml
└── src
    ├── attic
    │   └── data.py
    ├── data_new.py
    ├── feeder.py
    ├── generation.py
    ├── main.py
    ├── master.py
    ├── model.py
    ├── score.py
    ├── shared.py
    └── transformers19
        ├── __init__.py
        ├── configuration_gpt2.py
        ├── configuration_utils.py
        ├── file_utils.py
        ├── modeling_gpt2.py
        ├── modeling_utils.py
        ├── tokenization_gpt2.py
        └── tokenization_utils.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | *.pyc
3 | *.pth
4 | out/
5 | data/
6 | .vscode/
7 | test/
8 | doc/arxiv/
9 | model_card/
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Sean Xiang Gao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 |
4 |
5 |
6 |
7 | # DialogRPT: Dialog Ranking Pretrained Transformers
8 |
9 | [DialogRPT](https://arxiv.org/abs/2009.06978/) predicts human feedback (upvotes👍 or replies💬) of dialogue responses.
10 |
11 | It is a set of dialog response ranking models proposed by the [Microsoft Research NLP Group](https://www.microsoft.com/en-us/research/group/natural-language-processing/), trained on more than 100 million human feedback examples, and accepted to appear at [EMNLP'20](https://2020.emnlp.org/).
12 | It can be used to improve existing dialog generation models (e.g., [DialoGPT](https://github.com/microsoft/DialoGPT)) by re-ranking the generated response candidates.
13 | This repo provides a PyTorch implementation and pretrained models.
14 |
15 | Quick links:
16 | * [Paper](https://arxiv.org/abs/2009.06978/)
17 | * [Intro Talk](https://slideslive.com/38938970/dialogue-response-ranking-training-with-largescale-human-feedback-data) and [Slides](https://github.com/golsun/DialogRPT/blob/master/doc/DialogRPT-EMNLP.pdf)
18 | * Demo: [original](https://colab.research.google.com/drive/1jQXzTYsgdZIQjJKrX4g3CP0_PGCeVU3C?usp=sharing) or [HuggingFace](https://colab.research.google.com/drive/1cAtfkbhqsRsT59y3imjR1APw3MHDMkuV?usp=sharing)
19 | * [Dataset](https://dialogfeedback.github.io/data.html)
20 |
21 | We considered the following tasks and provide the corresponding pretrained models.
22 | (Click 💾 to download the original PyTorch checkpoint for this repo, or click 🤗 to use the HuggingFace model card; a usage sketch follows the table below.)
23 |
24 |
25 | | Task | Description | Pretrained model |
26 | | :------------- | :----------- | :-----------: |
27 | | **Human feedback** | | |
28 | | `updown` | How likely is the response to get the most upvotes? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/updown.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-updown?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
29 | | `width` | How likely is the response to get the most direct replies? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/width.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-width?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
30 | | `depth` | How likely is the response to get the longest follow-up thread? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/depth.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-depth?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
31 | | **Human-like** (human vs fake) | | |
32 | | `human_vs_rand` | How relevant is the response to the given context? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/human_vs_rand.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-human-vs-rand?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
33 | | `human_vs_machine` | How likely is the response to be human-written rather than machine-generated? | [💾](https://xiagnlp2.blob.core.windows.net/dialogrpt/human_vs_machine.pth) / [🤗](https://huggingface.co/microsoft/DialogRPT-human-vs-machine?text=I+love+NLP%21+<%7Cendoftext%7C>+Me+too%21) |
34 |
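For the 🤗 model cards, below is a minimal usage sketch with the HuggingFace `transformers` library. The `context <|endoftext|> response` input format is the same one used in the model-card links above; mapping the raw logit to a [0, 1] score with a sigmoid is an assumption consistent with the scores shown elsewhere in this README.
```python
# Minimal sketch: score a (context, response) pair with a DialogRPT model card.
# The sigmoid on the logit is an assumption, chosen to match the [0, 1] scores
# reported in this README.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "microsoft/DialogRPT-updown"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

def score(context, response):
    inputs = tokenizer.encode(context + " <|endoftext|> " + response, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs, return_dict=True).logits
    return torch.sigmoid(logits).item()

print(score("I love NLP!", "Me too!"))
```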
35 |
36 | ## Contents:
37 |
38 | * [Quick Start](#Quick-Start)
39 | * [Install](#Install), try [this demo](https://colab.research.google.com/drive/1jQXzTYsgdZIQjJKrX4g3CP0_PGCeVU3C?usp=sharing), or use the Hugging Face model card with this [demo](https://colab.research.google.com/drive/1cAtfkbhqsRsT59y3imjR1APw3MHDMkuV?usp=sharing)
40 | * [Use rankers only](#Use-rankers-only): use DialogRPT as an evaluation metric
41 | * [Use generator + ranker](#Use-generator-+-ranker): improve a generator by reranking hypotheses with DialogRPT
42 | * [Data](#Data)
43 | * [Training](#Training)
44 | * [Evaluation](#Evaluation)
45 | * [Human feedback prediction](#Human-feedback-prediction)
46 | * [Human-like classification](#Human-like-classification)
47 | * [Citation](#Citation)
48 |
49 |
50 |
51 |
52 | ## Quick Start
53 |
54 |
55 | ### Install
56 |
57 | **Option 1**: run locally
58 | ```
59 | git clone https://github.com/golsun/DialogRPT
60 | cd DialogRPT
61 | conda create -n dialogrpt python=3.6
62 | conda activate dialogrpt
63 | pip install -r requirements.txt
64 | ```
65 |
66 | **Option 2**: run on Colab Notebook. You can either use [Demo (original)](https://colab.research.google.com/drive/1jQXzTYsgdZIQjJKrX4g3CP0_PGCeVU3C?usp=sharing) or [Demo (HuggingFace)](https://colab.research.google.com/drive/1cAtfkbhqsRsT59y3imjR1APw3MHDMkuV?usp=sharing)
67 |
68 |
69 |
70 | ### Use rankers only
71 | In the following example, the model predicts that, given the same context *"I love NLP!"*, the response *"Here’s a free textbook (URL) in case anyone needs it."* is more likely to get upvotes than the response *"Me too!"*.
72 | ```bash
73 | python src/score.py play -p=restore/updown.pth
74 | #
75 | # Context: I love NLP!
76 | # Response: Here’s a free textbook (URL) in case anyone needs it.
77 | # score = 0.613
78 |
79 | # Context: I love NLP!
80 | # Response: Me too!
81 | # score = 0.111
82 | ```
83 | You can also play with the ensemble model, which combines multiple models as defined in its [config file](restore/ensemble.yml) (see that file for details).
84 | ```bash
85 | python src/main.py play -p=restore/ensemble.yml
86 | ```
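Per the comments in [restore/ensemble.yml](restore/ensemble.yml), the final ensemble score is the product of a `prior` score (weighted average of the `human_vs_rand` and `human_vs_machine` predictions) and a `cond` score (weighted average of the `updown`, `depth`, and `width` predictions). A minimal sketch, assuming each weighted average is normalized by the sum of its weights (this reproduces the numbers in [doc/toy.tsv.ensemble.jsonl](https://github.com/golsun/DialogRPT/blob/master/doc/toy.tsv.ensemble.jsonl)):
```python
# Sketch of the ensemble rule described in restore/ensemble.yml (see
# model.py/JointScorer for the real implementation). Normalizing each weighted
# average by the sum of its weights is an assumption inferred from the example
# outputs in doc/toy.tsv.ensemble.jsonl.
def ensemble_score(scores, prior_wt, cond_wt):
    weighted_avg = lambda wt: sum(w * scores[k] for k, w in wt.items()) / sum(wt.values())
    return weighted_avg(prior_wt) * weighted_avg(cond_wt)

scores = {"human_vs_rand": 0.876, "human_vs_machine": 0.687,
          "updown": 0.613, "depth": 0.497, "width": 0.368}
prior_wt = {"human_vs_rand": 0.5, "human_vs_machine": 0.5}
cond_wt = {"updown": 1.0, "depth": 0.48, "width": -0.5}
print(ensemble_score(scores, prior_wt, cond_wt))  # ~0.532, as in the toy example
```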
87 | To score a list of (context, response) pairs, please provide an input file (`--data`), which is tab-separated in the format `context \t response0 \t response1 ...`. See the example [input file](https://github.com/golsun/DialogRPT/blob/master/doc/toy.tsv). A sketch for reading the output files follows the examples below.
88 | * Using a single ranker (see [expected output](https://github.com/golsun/DialogRPT/blob/master/doc/toy.tsv.updown.jsonl))
89 | ```bash
90 | python src/score.py test --data=doc/toy.tsv -p=restore/updown.pth
91 | # downloading pretrained model to restore/updown.pth
92 | # 100% [....................] 1520029114 / 1520029114
93 | # loading from restore/updown.pth
94 | # ranking doc/toy.tsv
95 | # totally processed 2 line, avg_hyp_score 0.264, top_hyp_score 0.409
96 | # results saved to doc/toy.tsv.ranked.jsonl
97 | ```
98 | * Using an ensemble model (see [expected output](https://github.com/golsun/DialogRPT/blob/master/doc/toy.tsv.ensemble.jsonl))
99 | ```bash
100 | python src/score.py test --data=doc/toy.tsv -p=restore/ensemble.yml
101 | ```
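Each line of the output `.jsonl` file is a JSON object with the context (`cxt`) and the scored hypotheses (`hyps`), best first; the ensemble output additionally stores the per-model scores (see the expected-output files linked above). A minimal sketch for reading it:
```python
import json

# Print the top-scored hypothesis for each context in a DialogRPT output file.
# Field names follow doc/toy.tsv.updown.jsonl and doc/toy.tsv.ensemble.jsonl.
with open("doc/toy.tsv.updown.jsonl") as f:
    for line in f:
        d = json.loads(line)
        best = d["hyps"][0]             # hypotheses are listed best-first
        score, hyp = best[0], best[-1]  # ensemble files insert per-model scores in between
        print("%.3f  %s  ->  %s" % (score, d["cxt"], hyp))
```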
102 | Statistics of the scoring results can be shown with the following command, e.g. for `doc/toy.tsv.ensemble.jsonl`
103 | ```bash
104 | python src/score.py stats --data=doc/toy.tsv.ensemble.jsonl
105 | # |best |avg
106 | # ----------------------------------------
107 | # _score |0.339 |0.206
108 | # human_vs_rand |0.928 |0.861
109 | # human_vs_machine |0.575 |0.525
110 | # updown |0.409 |0.264
111 | # depth |0.304 |0.153
112 | # width |0.225 |0.114
113 | # final |0.339 |0.206
114 | # ----------------------------------------
115 | # n_cxt: 2
116 | # avg n_hyp per cxt: 2.50
117 | ```
118 |
119 |
120 | ### Use generator + ranker
121 | Dialog generation models can be improved by integrating them with the response ranking models.
122 | For example, given the context *"Can we restart 2020?"*, DialoGPT may return the following responses with sampling decoding (or you can try beam search without `--sampling`). Some of them, e.g., *"Yes, we can."*, have a high generation probability (`gen 0.496`) but are less interesting (`ranker 0.302`). So the rankers will rank them lower than responses that are more likely to be upvoted, e.g., *"I think we should go back to the beginning, and start from the beginning."*, which is relatively less likely to be generated (`gen 0.383`) but seems more interesting (`ranker 0.431`). A minimal re-ranking sketch follows the example output below.
123 | ```bash
124 | python src/generation.py play -pg=restore/medium_ft.pkl -pr=restore/updown.pth --sampling
125 | #
126 | # Context: Can we restart 2020?
127 | # 0.431 gen 0.383 ranker 0.431 I think we should go back to the beginning, and start from the beginning.
128 | # 0.429 gen 0.227 ranker 0.429 I think I'll just sit here and wait for 2020
129 | # 0.377 gen 0.249 ranker 0.377 Yeah, let's just start from the beginning
130 | # 0.323 gen 0.195 ranker 0.323 I think we should just give up and let the year just pass.
131 | # 0.304 gen 0.395 ranker 0.304 Yes. We can.
132 | # 0.302 gen 0.496 ranker 0.302 Yes, we can.
133 | # 0.283 gen 0.351 ranker 0.283 It's been a while since we've seen a good reboot.
134 | # 0.174 gen 0.306 ranker 0.174 I'm up for it
135 | # 0.168 gen 0.463 ranker 0.168 I'm down
136 | # 0.153 gen 0.328 ranker 0.153 I think so, yes.
137 | # ...
138 | ```
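The re-ranking step itself is simple: generate candidates, score each one with a ranker, and sort by the ranker score (as in the output above). A minimal sketch with hypothetical `generate_candidates` and `ranker_score` helpers (not this repo's API; `src/generation.py` implements the real pipeline):
```python
# Hypothetical helpers for illustration only; src/generation.py implements the
# real pipeline (DialoGPT generation + DialogRPT ranking).
def rerank(context, generate_candidates, ranker_score, n_hyp=10):
    hyps = generate_candidates(context, n_hyp)                 # [(gen_prob, hyp), ...]
    scored = [(ranker_score(context, h), p, h) for p, h in hyps]
    return sorted(scored, reverse=True)                        # best ranker score first
```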
139 | Similarly, you can use the [ensemble model](restore/ensemble.yml).
140 | ```
141 | python src/generation.py play -pg=restore/medium_ft.pkl -pr=restore/ensemble.yml
142 | ```
143 | To generate from a list of contexts stored in a line-separated file, provide it with `--path_test` and use the command below:
144 | ```
145 | python src/generation.py test --path_test=path/to/list/of/contexts -pg=restore/medium_ft.pkl -pr=restore/ensemble.yml
146 | ```
147 |
148 |
149 | ## Data
150 |
151 | As the Pushshift Reddit dataset was deleted from [this server](https://files.pushshift.io/reddit), the data extraction pipeline of this release no longer works. As an alternative, you may want to use the Pushshift [API](https://github.com/pushshift/api).
152 |
153 | ## Training
154 | We use [DialoGPT](https://github.com/microsoft/DialoGPT) to initialize the model. Please download with
155 | ```
156 | wget https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl -P restore
157 | ```
158 | For the human feedback prediction tasks, we specify `min_score_gap` and `min_rank_gap` to validate only on less-noisy samples (these filters are not applied to training). A sketch of the filter follows the commands below.
159 | ```
160 | python src/main.py train --data=data/out/updown -p=restore/medium_ft.pkl --min_score_gap=20 --min_rank_gap=0.5
161 | python src/main.py train --data=data/out/depth -p=restore/medium_ft.pkl --min_score_gap=4 --min_rank_gap=0.5
162 | python src/main.py train --data=data/out/width -p=restore/medium_ft.pkl --min_score_gap=4 --min_rank_gap=0.5
163 | ```
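Concretely, a (positive, negative) training pair is kept for validation only if the feedback gaps are large enough and the two replies were posted close enough in time; the sketch below mirrors the checks in `src/feeder.py`.
```python
# Sketch of the validation-pair filter (mirrors the checks in src/feeder.py).
def keep_pair(score_pos, score_neg, rank_pos, rank_neg, hr_gap,
              min_score_gap, min_rank_gap, max_hr_gap=1):
    if score_pos - score_neg < min_score_gap:    # e.g. >= 20 net upvotes apart for `updown`
        return False
    if rank_pos - rank_neg < min_rank_gap:       # normalized-rank gap, e.g. >= 0.5
        return False
    if max_hr_gap > 0 and hr_gap > max_hr_gap:   # replies posted too far apart in time
        return False
    return True
```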
164 | For the `human_vs_rand` task, use the `--mismatch` flag to feed random human responses as negative examples. We can reuse a previous dataset (e.g., `data/out/updown`).
165 | ```
166 | python src/main.py train --data=data/out/updown -p=restore/medium_ft.pkl --mismatch
167 | ```
168 | For the `human_vs_machine` task, we build the dataset by pairing each human response with a response generated by [DialoGPT](https://github.com/microsoft/DialoGPT) using top-k decoding.
169 | ```
170 | python src/main.py train --data=data/out/human_vs_machine -p=restore/medium_ft.pkl
171 | ```
172 |
173 | We trained all models on 4 Nvidia V100 GPUs (32 GB memory each) with the following hyperparameters; the checkpoint with the best validation accuracy is used as the final model. A sketch of the `max_seq_len` truncation rule follows the table.
174 | | Argument | Value | Description |
175 | | :------------- | :-----------: |:------------- |
176 | | `batch` | 256 | total batch size for all GPUs. |
177 | | `vali_size` | 1024 | number of samples used for validation (i.e. dev set size). |
178 | | `lr` | 3e-05 | learning rate |
179 | | `max_seq_len` | 50 | max allowed sequence length; if longer, leading tokens will be truncated |
180 | | `max_hr_gap` | 1 | max allowed hour difference between positive and negative samples; if larger, the pair is discarded for train/vali |
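The `max_seq_len` truncation keeps the two candidates comparable: if either sequence is too long, leading tokens are dropped from both candidates and from the context, so the pair still shares the same (truncated) context. A sketch mirroring `src/feeder.py`:
```python
# Sketch of the left-truncation rule (mirrors src/feeder.py): token ids of the
# context and of the two candidate sequences (context + reply) are cut from the
# left so that neither candidate exceeds max_seq_len.
def truncate_left(cxt, pos, neg, max_seq_len=50):
    n_del = max(len(pos), len(neg)) - max_seq_len
    if n_del > 0:
        pos, neg, cxt = pos[n_del:], neg[n_del:], cxt[n_del:]
    return cxt, pos, neg
```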
181 |
182 |
183 | ## Evaluation
184 |
185 | ### Human feedback prediction
186 |
187 | The performance on `updown`, `depth`, and `width` can be measured with the following commands, respectively.
188 | The `--min_score_gap` and `--min_rank_gap` arguments are consistent with the values used to measure validation loss during training.
189 | ```
190 | python src/score.py eval_human_feedback -p=restore/updown.pth --data=test/human_feedback/updown.tsv --min_score_gap=20 --min_rank_gap=0.5
191 | python src/score.py eval_human_feedback -p=restore/depth.pth --data=test/human_feedback/depth.tsv --min_score_gap=4 --min_rank_gap=0.5
192 | python src/score.py eval_human_feedback -p=restore/width.pth --data=test/human_feedback/width.tsv --min_score_gap=4 --min_rank_gap=0.5
193 | ```
194 |
195 | The expected pairwise accuracy on 5000 test samples is listed in the table below (from Table 5 of the [paper](https://arxiv.org/abs/2009.06978)). Note that even a random guess achieves an accuracy of 0.500. A sketch of the pairwise-accuracy metric follows the table.
196 | | human feedback | `updown` | `depth` | `width` |
197 | | :------------- | :------: |:------------: |:--------: |
198 | | Dialog ppl. | 0.488 | 0.508 | 0.513 |
199 | | Reverse dialog ppl. | 0.560 | 0.557 | 0.571 |
200 | | **DialogRPT** (ours)| **0.683** | **0.695** | **0.752** |
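Here pairwise accuracy simply counts how often the model scores the response with better human feedback above its counterpart. A minimal sketch (with a hypothetical `score_fn`; not the repo's own evaluation code in `src/score.py`):
```python
# Pairwise accuracy: a pair counts as correct when the feedback-preferred
# (positive) response gets a higher score than the negative one.
def pairwise_accuracy(pairs, score_fn):
    # pairs: list of (context, positive_reply, negative_reply)
    correct = sum(score_fn(cxt, pos) > score_fn(cxt, neg) for cxt, pos, neg in pairs)
    return correct / len(pairs)
```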
201 |
202 | ### Human-like classification
203 |
204 | * `human_vs_rand` task: Although the model is trained on the `reddit` corpus only, we measured its **zero-shot** performance on several unseen corpora (`twitter`, `dailydialog`, and `personachat`).
205 | ```bash
206 | python src/score.py eval_human_vs_rand -p=restore/human_vs_rand.pth --data=test/human_vs_fake/reddit
207 | python src/score.py eval_human_vs_rand -p=restore/human_vs_rand.pth --data=test/human_vs_fake/dailydialog
208 | python src/score.py eval_human_vs_rand -p=restore/human_vs_rand.pth --data=test/human_vs_fake/twitter
209 | python src/score.py eval_human_vs_rand -p=restore/human_vs_rand.pth --data=test/human_vs_fake/personachat
210 | ```
211 | The expected `hits@k` metric on 5000 test samples is listed in the table below (from Table 7 of the [paper](https://arxiv.org/abs/2009.06978)).
212 | `hits@k` measures, for the same context, given `k` positive responses and `n` negative responses, how many of the positive responses appear in the top-`k` of the ranked responses. A sketch of this metric follows the table.
213 | | `human_vs_rand` | `reddit` | `dailydialog` | `twitter` | `personachat` |
214 | | :------------- | :------: |:------------: |:--------: |:------------: |
215 | | BM25 | 0.309 | 0.182 | 0.178 | 0.117 |
216 | | Dialog ppl. | 0.560 | 0.176 | 0.107 | 0.108 |
217 | | Reverse dialog ppl. | 0.775 | 0.457 | 0.440 | 0.449 |
218 | | [ConveRT](https://arxiv.org/abs/1911.03688) | 0.760 | 0.380 | 0.439 | 0.197 |
219 | | **DialogRPT** (ours)| **0.886** | **0.621** | **0.548** | **0.479** |
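A minimal sketch of the `hits@k` computation described above (illustration only, not the repo's evaluation code):
```python
# hits@k for one context: rank all k positive and n negative responses by score,
# then report the fraction of positives that appear in the top-k.
def hits_at_k(scored):
    # scored: list of (score, is_positive) pairs for a single context
    k = sum(1 for _, is_pos in scored if is_pos)
    topk = sorted(scored, key=lambda x: x[0], reverse=True)[:k]
    return sum(1 for _, is_pos in topk if is_pos) / k
```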
220 |
221 | * `human_vs_machine` task: its performance is evaluated on the `reddit` corpus only.
222 | ```bash
223 | python src/score.py --task=eval_human_vs_machine -p=restore/human_vs_machine.pth --data=test/human_vs_fake/reddit
224 | # expecting accuracy ~0.98
225 | ```
226 |
227 |
228 | ## Citation
229 | If you use our dataset or model, please cite our [paper](https://arxiv.org/abs/2009.06978)
230 |
231 | ```
232 | @inproceedings{gao2020dialogrpt,
233 | title={Dialogue Response Ranking Training with Large-Scale Human Feedback Data},
234 | author={Xiang Gao and Yizhe Zhang and Michel Galley and Chris Brockett and Bill Dolan},
235 | year={2020},
236 | booktitle={EMNLP}
237 | }
238 | ```
239 |
--------------------------------------------------------------------------------
/data.sh:
--------------------------------------------------------------------------------
1 | # step 0. create the data folder
2 |
3 | mkdir -p "data/bz2"
4 |
5 | # Step 1. Download raw data from a third party dump: https://files.pushshift.io/reddit
6 |
7 | # download comments for year 2011
8 | wget https://files.pushshift.io/reddit/comments/RC_2011-01.bz2 -P data/bz2
9 | wget https://files.pushshift.io/reddit/comments/RC_2011-02.bz2 -P data/bz2
10 | wget https://files.pushshift.io/reddit/comments/RC_2011-03.bz2 -P data/bz2
11 | wget https://files.pushshift.io/reddit/comments/RC_2011-04.bz2 -P data/bz2
12 | wget https://files.pushshift.io/reddit/comments/RC_2011-05.bz2 -P data/bz2
13 | wget https://files.pushshift.io/reddit/comments/RC_2011-06.bz2 -P data/bz2
14 | wget https://files.pushshift.io/reddit/comments/RC_2011-07.bz2 -P data/bz2
15 | wget https://files.pushshift.io/reddit/comments/RC_2011-08.bz2 -P data/bz2
16 | wget https://files.pushshift.io/reddit/comments/RC_2011-09.bz2 -P data/bz2
17 | wget https://files.pushshift.io/reddit/comments/RC_2011-10.bz2 -P data/bz2
18 | wget https://files.pushshift.io/reddit/comments/RC_2011-11.bz2 -P data/bz2
19 | wget https://files.pushshift.io/reddit/comments/RC_2011-12.bz2 -P data/bz2
20 |
21 | # download comments for year 2012
22 | wget https://files.pushshift.io/reddit/comments/RC_2012-01.bz2 -P data/bz2
23 | wget https://files.pushshift.io/reddit/comments/RC_2012-02.bz2 -P data/bz2
24 | wget https://files.pushshift.io/reddit/comments/RC_2012-03.bz2 -P data/bz2
25 | wget https://files.pushshift.io/reddit/comments/RC_2012-04.bz2 -P data/bz2
26 | wget https://files.pushshift.io/reddit/comments/RC_2012-05.bz2 -P data/bz2
27 | wget https://files.pushshift.io/reddit/comments/RC_2012-06.bz2 -P data/bz2
28 | wget https://files.pushshift.io/reddit/comments/RC_2012-07.bz2 -P data/bz2
29 | wget https://files.pushshift.io/reddit/comments/RC_2012-08.bz2 -P data/bz2
30 | wget https://files.pushshift.io/reddit/comments/RC_2012-09.bz2 -P data/bz2
31 | wget https://files.pushshift.io/reddit/comments/RC_2012-10.bz2 -P data/bz2
32 | wget https://files.pushshift.io/reddit/comments/RC_2012-11.bz2 -P data/bz2
33 | wget https://files.pushshift.io/reddit/comments/RC_2012-12.bz2 -P data/bz2
34 |
35 | # download submissions for year 2011
36 | wget https://files.pushshift.io/reddit/submissions/RS_2011-01.bz2 -P data/bz2
37 | wget https://files.pushshift.io/reddit/submissions/RS_2011-02.bz2 -P data/bz2
38 | wget https://files.pushshift.io/reddit/submissions/RS_2011-03.bz2 -P data/bz2
39 | wget https://files.pushshift.io/reddit/submissions/RS_2011-04.bz2 -P data/bz2
40 | wget https://files.pushshift.io/reddit/submissions/RS_2011-05.bz2 -P data/bz2
41 | wget https://files.pushshift.io/reddit/submissions/RS_2011-06.bz2 -P data/bz2
42 | wget https://files.pushshift.io/reddit/submissions/RS_2011-07.bz2 -P data/bz2
43 | wget https://files.pushshift.io/reddit/submissions/RS_2011-08.bz2 -P data/bz2
44 | wget https://files.pushshift.io/reddit/submissions/RS_2011-09.bz2 -P data/bz2
45 | wget https://files.pushshift.io/reddit/submissions/RS_2011-10.bz2 -P data/bz2
46 | wget https://files.pushshift.io/reddit/submissions/RS_2011-11.bz2 -P data/bz2
47 | wget https://files.pushshift.io/reddit/submissions/RS_2011-12.bz2 -P data/bz2
48 |
49 | # download submissions for year 2012
50 | wget https://files.pushshift.io/reddit/submissions/RS_2012-01.bz2 -P data/bz2
51 | wget https://files.pushshift.io/reddit/submissions/RS_2012-02.bz2 -P data/bz2
52 | wget https://files.pushshift.io/reddit/submissions/RS_2012-03.bz2 -P data/bz2
53 | wget https://files.pushshift.io/reddit/submissions/RS_2012-04.bz2 -P data/bz2
54 | wget https://files.pushshift.io/reddit/submissions/RS_2012-05.bz2 -P data/bz2
55 | wget https://files.pushshift.io/reddit/submissions/RS_2012-06.bz2 -P data/bz2
56 | wget https://files.pushshift.io/reddit/submissions/RS_2012-07.bz2 -P data/bz2
57 | wget https://files.pushshift.io/reddit/submissions/RS_2012-08.bz2 -P data/bz2
58 | wget https://files.pushshift.io/reddit/submissions/RS_2012-09.bz2 -P data/bz2
59 | wget https://files.pushshift.io/reddit/submissions/RS_2012-10.bz2 -P data/bz2
60 | wget https://files.pushshift.io/reddit/submissions/RS_2012-11.bz2 -P data/bz2
61 | wget https://files.pushshift.io/reddit/submissions/RS_2012-12.bz2 -P data/bz2
62 |
63 | # Step 2. Read the `.bz2` files and group items from the same subreddit
64 |
65 | python src/data.py bz2 2011
66 | python src/data.py bz2 2012
67 |
68 | # Step 3. extract basic attributes and dialog trees.
69 |
70 | python src/data.py basic 2011
71 | python src/data.py basic 2012
72 |
73 | # Step 4. Build training and testing data for different feedback signals.
74 |
75 | python src/data.py updown 2011 --year_to=2012
76 | python src/data.py depth 2011 --year_to=2012
77 | python src/data.py width 2011 --year_to=2012
--------------------------------------------------------------------------------
/doc/DialogRPT-EMNLP.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/golsun/DialogRPT/ac9954a784cc88071e5fa309c2afdace2e9b38d7/doc/DialogRPT-EMNLP.pdf
--------------------------------------------------------------------------------
/doc/demo.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/golsun/DialogRPT/ac9954a784cc88071e5fa309c2afdace2e9b38d7/doc/demo.PNG
--------------------------------------------------------------------------------
/doc/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/golsun/DialogRPT/ac9954a784cc88071e5fa309c2afdace2e9b38d7/doc/icon.png
--------------------------------------------------------------------------------
/doc/toy.tsv:
--------------------------------------------------------------------------------
1 | I love NLP! Here’s a free textbook (URL) in case anyone needs it. Me too! It’s super useful and more and more powerful!
2 | How are you? Not bad. pretty good! :)
--------------------------------------------------------------------------------
/doc/toy.tsv.ensemble.jsonl:
--------------------------------------------------------------------------------
1 | {"line_id": 0, "cxt": "I love NLP!", "hyps": [[0.5319748520851135, {"human_vs_rand": 0.875923216342926, "human_vs_machine": 0.6866590976715088, "updown": 0.6128572225570679, "depth": 0.4968094825744629, "width": 0.3681032657623291, "final": 0.5319748520851135}, "Here\u2019s a free textbook (URL) in case anyone needs it."], [0.22798165678977966, {"human_vs_rand": 0.7438808679580688, "human_vs_machine": 0.6686411499977112, "updown": 0.29050877690315247, "depth": 0.10862986743450165, "width": 0.052612606436014175, "final": 0.22798165678977966}, "It\u2019s super useful and more and more powerful!"], [0.0703396126627922, {"human_vs_rand": 0.630566418170929, "human_vs_machine": 0.6046908497810364, "updown": 0.11055243015289307, "depth": 0.031935129314661026, "width": 0.028544878587126732, "final": 0.0703396126627922}, "Me too!"]]}
2 | {"line_id": 1, "cxt": "How are you?", "hyps": [[0.14587044715881348, {"human_vs_rand": 0.9802649021148682, "human_vs_machine": 0.3306310772895813, "updown": 0.20513156056404114, "depth": 0.111829474568367, "width": 0.08141987770795822, "final": 0.14587044715881348}, "pretty good! :)"], [0.12652477622032166, {"human_vs_rand": 0.962352991104126, "human_vs_machine": 0.462787926197052, "updown": 0.17496052384376526, "depth": 0.0771029144525528, "width": 0.07592014223337173, "final": 0.12652477622032166}, "Not bad."]]}
--------------------------------------------------------------------------------
/doc/toy.tsv.updown.jsonl:
--------------------------------------------------------------------------------
1 | {"line_id": 0, "cxt": "I love NLP!", "hyps": [[0.6128574013710022, "Here\u2019s a free textbook (URL) in case anyone needs it."], [0.2905086874961853, "It\u2019s super useful and more and more powerful!"], [0.11055240780115128, "Me too!"]]}
2 | {"line_id": 1, "cxt": "How are you?", "hyps": [[0.20513157546520233, "pretty good! :)"], [0.17496058344841003, "Not bad."]]}
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | boto3==1.14.59
2 | botocore==1.17.59
3 | certifi==2020.6.20
4 | chardet==3.0.4
5 | cycler==0.10.0
6 | docker==3.2.1
7 | docutils==0.15.2
8 | future==0.18.2
9 | idna==2.10
10 | iterable-queue==1.2.2
11 | jmespath==0.10.0
12 | kiwisolver==1.2.0
13 | matplotlib==3.3.1
14 | ntlm-auth==1.2.0
15 | numpy==1.19.2
16 | Pillow==7.2.0
17 | pyparsing==2.4.7
18 | python-dateutil==2.8.1
19 | PyYAML==5.3.1
20 | regex==2020.7.14
21 | requests==2.24.0
22 | requests-ntlm==1.1.0
23 | s3transfer==0.3.3
24 | scipy==1.5.2
25 | six==1.15.0
26 | smmap2==2.0.3
27 | torch==1.6.0
28 | tqdm==4.48.2
29 | urllib3==1.25.10
30 | wincertstore==0.2
31 |
--------------------------------------------------------------------------------
/restore/ensemble.yml:
--------------------------------------------------------------------------------
1 | # see model.py/JointScorer.core for details
2 | # the `prior` score is the weighted average of `human_vs_rand` and `human_vs_machine` predictions,
3 | # and `cond` is the weighted average of the `updown`, `depth` and `width` predictions.
4 | # The final score is the product of `prior` score and `cond` score
5 |
6 | prior:
7 |
8 | - name: human_vs_rand
9 | wt: 0.5
10 | path: restore/human_vs_rand.pth
11 |
12 | - name: human_vs_machine
13 | wt: 0.5
14 | path: restore/human_vs_machine.pth
15 |
16 | cond:
17 |
18 | - name: updown
19 | wt: 1
20 | path: restore/updown.pth
21 |
22 | - name: depth
23 | wt: 0.48
24 | path: restore/depth.pth
25 |
26 | - name: width
27 | wt: -0.5
28 | path: restore/width.pth
--------------------------------------------------------------------------------
/src/attic/data.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import bz2, json, os, pickle, pdb, time, random
5 | import urllib.request
6 | import numpy as np
7 | from shared import _cat_
8 |
9 |
10 | def valid_sub(sub):
11 | if sub.upper() in [
12 | 'CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7',
13 | 'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9']:
14 | # not allowed by Windows system
15 | return False
16 | if ':' in sub:
17 | return False
18 | return True
19 |
20 |
21 | def get_dates(year_from, year_to=None):
22 | if year_to is None:
23 | year_to = year_from
24 | dates = []
25 | for year in range(year_from, year_to + 1):
26 | for _mo in range(1, 12 + 1):
27 | mo = str(_mo)
28 | if len(mo) == 1:
29 | mo = '0' + mo
30 | dates.append(str(year) + '-' + mo)
31 | return dates
32 |
33 |
34 | def extract_rc(date):
35 | path_bz2 = '%s/RC_%s.bz2'%(fld_bz2, date)
36 | nodes = dict()
37 | edges = dict()
38 | subs = set()
39 | n = 0
40 | m = 0
41 | kk = ['body', 'link_id', 'name', 'parent_id', 'subreddit']
42 |
43 | def save(nodes, edges):
44 | for sub in nodes:
45 | fld = fld_jsonl + '/' + sub
46 | try:
47 | os.makedirs(fld, exist_ok=True)
48 | except NotADirectoryError as e:
49 | print(e)
50 | continue
51 | if sub not in subs:
52 | open(fld + '/%s_nodes.jsonl'%date, 'w', encoding="utf-8")
53 | open(fld + '/%s_edges.tsv'%date, 'w', encoding="utf-8")
54 | subs.add(sub)
55 | with open(fld + '/%s_nodes.jsonl'%date, 'a', encoding="utf-8") as f:
56 | f.write('\n'.join(nodes[sub]) + '\n')
57 | with open(fld + '/%s_edges.tsv'%date, 'a', encoding="utf-8") as f:
58 | f.write('\n'.join(edges[sub]) + '\n')
59 |
60 | for line in bz2.open(path_bz2, 'rt', encoding="utf-8"):
61 | n += 1
62 | line = line.strip('\n')
63 | try:
64 | node = json.loads(line)
65 | except Exception:
66 | continue
67 |
68 | ok = True
69 | for k in kk:
70 | if k not in node:
71 | ok = False
72 | break
73 | if not ok:
74 | continue  # skip this comment if a required field is missing
75 |
76 | if not valid_sub(node['subreddit']):
77 | continue
78 |
79 | if node['subreddit'] not in nodes:
80 | nodes[node['subreddit']] = []
81 | edges[node['subreddit']] = []
82 | nodes[node['subreddit']].append(line)
83 | edges[node['subreddit']].append('%s\t%s\t%s'%(node['link_id'], node['parent_id'], node['name']))
84 |
85 | m += 1
86 | if m % 1e5 == 0:
87 | save(nodes, edges)
88 | print('[RC_%s] saved %.2f/%.2f M, %i subreddits'%(date, m/1e6, n/1e6, len(subs)))
89 | nodes = dict()
90 | edges = dict()
91 |
92 | save(nodes, edges)
93 | print('[RC_%s] FINAL %.2f/%.2f M, %i subreddits ================'%(date, m/1e6, n/1e6, len(subs)))
94 | with open(fld_jsonl + '/readme.txt', 'a', encoding='utf-8') as f:
95 | f.write('[%s] saved %i/%i\n'%(date, m, n))
96 |
97 |
98 | def extract_rs(date):
99 | path_bz2 = '%s/RS_%s.bz2'%(fld_bz2, date)
100 | roots = dict()
101 | subs = set()
102 | n = 0
103 | m = 0
104 | kk = ['selftext', 'id', 'title', 'subreddit']
105 |
106 | def save(roots):
107 | for sub in roots:
108 | fld = fld_jsonl + '/' + sub
109 | try:
110 | os.makedirs(fld, exist_ok=True)
111 | except NotADirectoryError as e:
112 | print(e)
113 | continue
114 | if sub not in subs:
115 | open(fld + '/%s_roots.jsonl'%date, 'w', encoding="utf-8")
116 | subs.add(sub)
117 | with open(fld + '/%s_roots.jsonl'%date, 'a', encoding="utf-8") as f:
118 | f.write('\n'.join(roots[sub]) + '\n')
119 |
120 | for line in bz2.open(path_bz2, 'rt', encoding="utf-8"):
121 | n += 1
122 | line = line.strip('\n')
123 | try:
124 | root = json.loads(line)
125 | except Exception:
126 | continue
127 |
128 | ok = True
129 | for k in kk:
130 | if k not in root:
131 | ok = False
132 | break
133 | if not ok:
134 | continue  # skip this submission if a required field is missing
135 | if not valid_sub(root['subreddit']):
136 | continue
137 |
138 | # some bz2, e.g. 2012-09, doesn't have the `name` entry
139 | if 'name' not in root:
140 | root['name'] = 't3_' + root['id']
141 |
142 | if root['subreddit'] not in roots:
143 | roots[root['subreddit']] = []
144 | roots[root['subreddit']].append(line)
145 |
146 | m += 1
147 | if m % 1e4 == 0:
148 | save(roots)
149 | print('[RS_%s] saved %.2f/%.2f M, %i subreddits'%(date, m/1e6, n/1e6, len(subs)))
150 | roots = dict()
151 |
152 | save(roots)
153 | print('[RS_%s] FINAL %.2f/%.2f M, %i subreddits ================'%(
154 | date, m/1e6, n/1e6, len(subs)))
155 | with open(fld_jsonl + '/readme_roots.txt', 'a', encoding='utf-8') as f:
156 | f.write('[%s] saved %i/%i\n'%(date, m, n))
157 |
158 |
159 |
160 | def extract_txt(sub, year, tokenizer, overwrite=False, max_subword=3):
161 | fld = '%s/%s'%(fld_subs, sub)
162 | os.makedirs(fld, exist_ok=True)
163 | path_out = '%s/%i_txt.tsv'%(fld, year)
164 | path_done = path_out + '.done'
165 | if not overwrite and os.path.exists(path_done):
166 | return
167 |
168 | dates = get_dates(year)
169 | open(path_out, 'w', encoding='utf-8')
170 |
171 | def clean(txt):
172 | if txt.strip() in ['[deleted]', '[removed]']:
173 | return None
174 | if '>' in txt or '&gt;' in txt: # no quoted comment in line ('&gt;' means '>')
175 | return None
176 |
177 | # deal with URL
178 | txt = txt.replace('](','] (')
179 | ww = []
180 | for w in txt.split():
181 | if len(w) == 0:
182 | continue
183 | if '://' in w.lower() or 'http' in w.lower():
184 | ww.append('(URL)')
185 | else:
186 | ww.append(w)
187 | if not ww:
188 | return None
189 | if len(ww) > 30: # focus on dialog, so ignore long txt
190 | return None
191 | if len(ww) < 1:
192 | return None
193 | txt = ' '.join(ww)
194 | for c in ['\t', '\n', '\r']: # delimiter or newline
195 | txt = txt.replace(c, ' ')
196 |
197 | ids = tokenizer.encode(txt)
198 | if len(ids) / len(ww) > max_subword: # usually < 1.5. too large means too many unknown words
199 | return None
200 |
201 | ids = ' '.join([str(x) for x in ids])
202 | return txt, ids
203 |
204 | lines = []
205 | m = 0
206 | n = 0
207 | name_set = set()
208 | for date in dates:
209 | path = '%s/%s/%s_nodes.jsonl'%(fld_jsonl, sub, date)
210 | if not os.path.exists(path):
211 | continue
212 | for line in open(path, encoding='utf-8'):
213 | n += 1
214 | d = json.loads(line.strip('\n'))
215 | if d['name'] in name_set:
216 | continue
217 | name_set.add(d['name'])
218 | txt_ids = clean(d['body'])
219 | if txt_ids is not None:
220 | txt, ids = txt_ids
221 | lines.append('%s\t%s\t%s'%(d['name'], txt, ids))
222 | m += 1
223 | if m % 1e4 == 0:
224 | with open(path_out, 'a', encoding='utf-8') as f:
225 | f.write('\n'.join(lines) + '\n')
226 | lines = []
227 |
228 | for date in dates:
229 | path = '%s/%s/%s_roots.jsonl'%(fld_jsonl, sub, date)
230 | if not os.path.exists(path):
231 | continue
232 | for line in open(path, encoding='utf-8'):
233 | n += 1
234 | d = json.loads(line.strip('\n'))
235 | if 'name' not in d:
236 | d['name'] = 't3_' + d['id']
237 | if d['name'] in name_set:
238 | continue
239 | name_set.add(d['name'])
240 | txt_ids = clean(d['title'] + ' ' + d['selftext'])
241 | if txt_ids is not None:
242 | txt, ids = txt_ids
243 | lines.append('%s\t%s\t%s'%(d['name'], txt, ids))
244 | m += 1
245 | if m % 1e4 == 0:
246 | with open(path_out, 'a', encoding='utf-8') as f:
247 | f.write('\n'.join(lines) + '\n')
248 | lines = []
249 | if lines:
250 | with open(path_out, 'a', encoding='utf-8') as f:
251 | f.write('\n'.join(lines))
252 |
253 | s = '[%s %s] txt kept %i/%i'%(sub, year, m, n)
254 | with open(path_done, 'w') as f:
255 | f.write(s)
256 | print(s)
257 |
258 |
259 | def extract_trees(sub, year):
260 | fld = '%s/%s'%(fld_subs, sub)
261 | os.makedirs(fld, exist_ok=True)
262 | path_out = '%s/%i_trees.pkl'%(fld, year)
263 | if os.path.exists(path_out):
264 | return
265 |
266 | trees = dict()
267 | n = 0
268 | for date in get_dates(year):
269 | path = '%s/%s/%s_edges.tsv'%(fld_jsonl, sub, date)
270 | if not os.path.exists(path):
271 | #print('no such file: '+path)
272 | continue
273 | for line in open(path, encoding='utf-8'):
274 | n += 1
275 | link, parent, child = line.strip('\n').split('\t')
276 | if link not in trees:
277 | trees[link] = dict()
278 | trees[link][(parent, child)] = date
279 |
280 | if not trees:
281 | return
282 |
283 | print('[%s %i] %i trees %.1f nodes/tree'%(sub, year, len(trees), n/len(trees)))
284 | os.makedirs(fld, exist_ok=True)
285 | pickle.dump(trees, open(path_out, 'wb'))
286 |
287 |
288 | def extract_time(sub, year, overwrite=False):
289 | fld = '%s/%s'%(fld_subs, sub)
290 | os.makedirs(fld, exist_ok=True)
291 | path_out = '%s/%i_time.tsv'%(fld, year)
292 | path_done = path_out + '.done'
293 | if not overwrite and os.path.exists(path_done):
294 | return
295 | dates = get_dates(year)
296 | suffix = 'nodes'
297 | os.makedirs(fld, exist_ok=True)
298 | open(path_out, 'w', encoding='utf-8')
299 |
300 | lines = []
301 | m = 0
302 | n = 0
303 | name_set = set()
304 | for date in dates:
305 | path = '%s/%s/%s_%s.jsonl'%(fld_jsonl, sub, date, suffix)
306 | if not os.path.exists(path):
307 | continue
308 | for line in open(path, encoding='utf-8'):
309 | n += 1
310 | d = json.loads(line.strip('\n'))
311 | if 'name' not in d:
312 | d['name'] = 't3_' + d['id']
313 | if d['name'] in name_set:
314 | continue
315 | name_set.add(d['name'])
316 | t = d['created_utc']
317 | lines.append('%s\t%s'%(d['name'], t))
318 | m += 1
319 | if m % 1e4 == 0:
320 | with open(path_out, 'a', encoding='utf-8') as f:
321 | f.write('\n'.join(lines) + '\n')
322 | lines = []
323 | with open(path_out, 'a', encoding='utf-8') as f:
324 | f.write('\n'.join(lines))
325 |
326 | s = '[%s %s] time kept %i/%i'%(sub, year, m, n)
327 | with open(path_done, 'w') as f:
328 | f.write(s)
329 | print(s)
330 |
331 |
332 |
333 | def calc_feedback(sub, year, overwrite=False):
334 | fld = '%s/%s'%(fld_subs, sub)
335 | path_out = '%s/%i_feedback.tsv'%(fld, year)
336 | path_done = path_out + '.done'
337 | if not overwrite and os.path.exists(path_done):
338 | return
339 |
340 | path_pkl = '%s/%i_trees.pkl'%(fld, year)
341 | if not os.path.exists(path_pkl):
342 | return
343 | trees = pickle.load(open(path_pkl, 'rb'))
344 | if not trees:
345 | return
346 |
347 | dates = get_dates(year)
348 | updown = dict()
349 | for date in dates:
350 | path = '%s/%s/%s_nodes.jsonl'%(fld_jsonl, sub, date)
351 | if not os.path.exists(path):
352 | continue
353 | for line in open(path, encoding='utf-8'):
354 | d = json.loads(line.strip('\n'))
355 | updown[d['name']] = d['ups'] - d['downs']
356 |
357 | if not updown:
358 | print('empty updown:')
359 | return
360 |
361 | with open(path_out, 'w', encoding='utf-8') as f:
362 | f.write('\t'.join(['#path', 'vol', 'width', 'depth', 'updown']) + '\n')
363 |
364 | print('[%s %s] calculating scores for %i trees'%(sub, year, len(trees)))
365 |
366 | n_tree = 0
367 | n_node = 0
368 | for root in trees:
369 | tree = trees[root]
370 | children = dict()
371 | for parent, child in tree:
372 | if parent not in children:
373 | children[parent] = []
374 | children[parent].append(child)
375 | if root not in children:
376 | continue
377 |
378 | # BFS to get all paths from root to leaf
379 | q = [[root]]
380 | paths = []
381 | while q:
382 | qsize = len(q)
383 | for _ in range(qsize):
384 | path = q.pop(0)
385 | head = path[-1]
386 | if head not in children: # then head is a leaf
387 | paths.append(path)
388 | continue
389 | for child in children[head]:
390 | q.append(path + [child])
391 |
392 | prev = dict()
393 | for path in paths:
394 | for i in range(1, len(path)):
395 | prev[path[i]] = ' '.join(path[:i + 1])
396 |
397 | descendant = dict()
398 | longest_subpath = dict()
399 | while paths:
400 | path = paths.pop(0)
401 | node = path[0]
402 | if node not in descendant:
403 | descendant[node] = set()
404 | longest_subpath[node] = 0
405 | descendant[node] |= set(path[1:])
406 | longest_subpath[node] = max(longest_subpath[node], len(path) - 1)
407 | if len(path) > 1:
408 | paths.append(path[1:])
409 |
410 | sorted_nodes = sorted([(len(prev[node].split()), prev[node], node) for node in prev])
411 | if not sorted_nodes:
412 | continue
413 |
414 | n_tree += 1
415 | lines = []
416 | for _, _, node in sorted_nodes:
417 | if node == root:
418 | continue
419 | if node not in updown:
420 | continue
421 | n_node += 1
422 | lines.append('%s\t%i\t%i\t%i\t%i'%(
423 | prev[node], # turns: path from its root to this node
424 | len(descendant[node]), # vol: num of descendants of this node
425 | len(children.get(node, [])), # width: num of direct children of this node
426 | longest_subpath[node], # depth: num of longest subpath of this node
427 | updown[node], # updown: `upvotes - downvotes` of this node
428 | ))
429 | with open(path_out, 'a', encoding='utf-8') as f:
430 | f.write('\n'.join(lines) + '\n')
431 |
432 | if n_tree:
433 | s = '[%s %s] %i tree %i nodes'%(sub, year, n_tree, n_node)
434 | else:
435 | s = '[%s %s] trees are empty'%(sub, year)
436 | with open(path_done, 'w') as f:
437 | f.write(s)
438 | print(s)
439 |
440 |
441 | def create_pairs(year, sub, feedback, overwrite=False):
442 | fld = '%s/%s'%(fld_subs, sub)
443 | path_out = '%s/%i_%s.tsv'%(fld, year, feedback)
444 | path_done = path_out + '.done'
445 | if not overwrite and os.path.exists(path_done):
446 | return
447 |
448 | ix_feedback = ['vol', 'width', 'depth', 'updown'].index(feedback) + 1
449 | path_in = '%s/%i_feedback.tsv'%(fld, year)
450 | if not os.path.exists(path_in):
451 | return
452 |
453 | time = dict()
454 | path_time = '%s/%i_time.tsv'%(fld, year)
455 | if not os.path.exists(path_time):
456 | return
457 | for line in open(path_time):
458 | ss = line.strip('\n').split('\t')
459 | if len(ss) == 2:
460 | name, t = ss
461 | time[name] = int(t)
462 |
463 | open(path_out, 'w', encoding='utf-8')
464 | print('[%s %s] creating pairs...'%(sub, year))
465 |
466 | def match_time(replies, cxt):
467 | scores = sorted(set([score for score, _ in replies]))
468 | m = len(scores)
469 | if m < 2:
470 | return 0 # can't create pairs if m < 2
471 | cand = []
472 | for score, reply in replies:
473 | if reply not in time:
474 | continue
475 | cand.append((time[reply], score, reply))
476 | cand = sorted(cand)
477 | rank = [scores.index(score) / (m - 1) for _, score, _ in cand]
478 | lines = []
479 | for i in range(len(cand) - 1):
480 | t_a, score_a, a = cand[i]
481 | t_b, score_b, b = cand[i + 1]
482 | rank_a = rank[i]
483 | rank_b = rank[i + 1]
484 | if score_a == score_b:
485 | continue
486 | hr = (t_b - t_a)/3600
487 | if score_b > score_a:
488 | score_a, score_b = score_b, score_a
489 | a, b = b, a
490 | rank_a, rank_b = rank_b, rank_a
491 | lines.append('\t'.join([
492 | cxt,
493 | a,
494 | b,
495 | '%.2f'%hr,
496 | '%i'%score_a,
497 | '%i'%score_b,
498 | '%.4f'%rank_a,
499 | '%.4f'%rank_b,
500 | ]))
501 | #pdb.set_trace()
502 | if lines:
503 | with open(path_out, 'a') as f:
504 | f.write('\n'.join(lines) + '\n')
505 | return len(lines)
506 |
507 | n_line = 0
508 | prev = None
509 | replies = []
510 | for line in open(path_in):
511 | if line.startswith('#'):
512 | continue
513 | ss = line.strip('\n').split('\t')
514 | turns = ss[0].split() # including both cxt and resp
515 | if len(turns) < 2:
516 | continue
517 | reply = turns[-1]
518 | try:
519 | score = int(ss[ix_feedback])
520 | except ValueError:
521 | continue
522 | parent = turns[-2]
523 | if parent == prev:
524 | replies.append((score, reply))
525 | else:
526 | if replies:
527 | n_line += match_time(replies, cxt)
528 | cxt = ' '.join(turns[:-1])
529 | prev = parent
530 | replies = [(score, reply)]
531 | if replies:
532 | n_line += match_time(replies, cxt)
533 |
534 | s = '[%s %s %s] %i pairs'%(sub, year, feedback, n_line)
535 | with open(path_done, 'w') as f:
536 | f.write(s)
537 | print(s)
538 |
539 |
540 | def add_seq(sub, year, feedback, overwrite=False):
541 | fname = '%i_%s'%(year, feedback)
542 | fld = '%s/%s'%(fld_subs, sub)
543 | turn_sep = ' 50256 '
544 | path_out = fld + '/%s_ids.tsv'%fname
545 | path_done = path_out + '.done'
546 |
547 | if os.path.exists(path_done) and not overwrite:
548 | return
549 | if not os.path.exists(fld + '/%s.tsv'%fname):
550 | return
551 |
552 | seq = dict()
553 | path = '%s/%i_txt.tsv'%(fld, year)
554 | if not os.path.exists(path):
555 | return
556 | for line in open(path, encoding='utf-8'):
557 | ss = line.strip('\n').split('\t')
558 | if len(ss) != 3:
559 | continue
560 | name, txt, ids = ss
561 | seq[name] = ids
562 | print('loaded %i seq'%len(seq))
563 | open(path_out, 'w', encoding='utf-8')
564 | print('[%s %s %s] adding seq'%(sub, year, feedback))
565 | path = fld + '/%s.tsv'%fname
566 | lines = []
567 | n = 0
568 | m = 0
569 | for line in open(path, encoding='utf-8'):
570 | line = line.strip('\n')
571 | if line.startswith('#'):
572 | continue
573 |
574 | n += 1
575 | ss = line.split('\t')
576 | if len(ss) < 7:
577 | continue
578 | name_cxt, name_pos, name_neg = ss[:3]
579 |
580 | cxt = []
581 | ok = True
582 | for name in name_cxt.split():
583 | if name in seq:
584 | cxt.append(seq[name])
585 | else:
586 | ok = False
587 | break
588 | if not ok:
589 | continue
590 | cxt = turn_sep.join(cxt)
591 |
592 | if name_pos in seq:
593 | reply_pos = seq[name_pos]
594 | else:
595 | continue
596 | if name_neg in seq:
597 | reply_neg = seq[name_neg]
598 | else:
599 | continue
600 |
601 | lines.append('\t'.join([
602 | cxt, reply_pos, reply_neg,
603 | name_cxt, name_pos, name_neg,
604 | ] + ss[3:]))
605 | m += 1
606 | if m % 1e4 == 0:
607 | with open(path_out, 'a', encoding='utf-8') as f:
608 | f.write('\n'.join(lines) + '\n')
609 | lines = []
610 |
611 | with open(path_out, 'a', encoding='utf-8') as f:
612 | f.write('\n'.join(lines))
613 |
614 | s = '[%s %s %s] pair seq %i/%i'%(sub, year, feedback, m, n)
615 | with open(path_done, 'w') as f:
616 | f.write(s)
617 | print(s)
618 |
619 |
620 | def combine_sub(year_from, year_to, feedback, overwrite=False, skip_same_pos=True):
621 | fld = '%s/%s'%(fld_out, feedback)
622 | os.makedirs(fld, exist_ok=True)
623 | path_out = fld + '/raw.tsv'
624 | path_done = path_out + '.done'
625 | if os.path.exists(path_done) and not overwrite:
626 | return path_out
627 |
628 | subs = sorted(os.listdir(fld_subs))
629 | open(path_out, 'w', encoding='utf-8')
630 | lines = []
631 | n = 0
632 | empty = True
633 | non_empty_subreddits = 0
634 | for sub in subs:
635 | empty = True
636 | for year in range(year_from, year_to + 1):
637 | path = '%s/%s/%i_%s_ids.tsv'%(fld_subs, sub, year, feedback)
638 | if not os.path.exists(path):
639 | continue
640 | for line in open(path, encoding='utf-8'):
641 | if line.startswith('#'):
642 | continue
643 | line = line.strip('\n')
644 | if not line:
645 | continue
646 | lines.append(line)
647 | empty = False
648 | n += 1
649 | if n % 1e5 == 0:
650 | with open(path_out, 'a', encoding='utf-8') as f:
651 | f.write('\n'.join(lines) + '\n')
652 | lines = []
653 | s = '[%i %s] saved %.2f M lines from %i subreddits, now is %s'%(year, feedback, n/1e6, non_empty_subreddits + 1, sub)
654 | print(s)
655 | if not empty:
656 | non_empty_subreddits += 1
657 |
658 | with open(path_out, 'a', encoding='utf-8') as f:
659 | f.write('\n'.join(lines))
660 | s = '[%i-%i %s] saved %.2f M lines from %i subreddits'%(year_from, year_to, feedback, n/1e6, non_empty_subreddits )
661 | with open(path_done, 'w') as f:
662 | f.write(s)
663 | print(s)
664 | return path_out
665 |
666 |
667 |
668 | def split_by_root(path, p_test=0.01):
669 |
670 | print('spliting by root '+path)
671 | lines = {
672 | 'train': [],
673 | 'vali': [],
674 | }
675 | prev = None
676 | n = 0
677 |
678 | for k in lines:
679 | if len(lines[k]) == 0:
680 | continue
681 | open(path + '.' + k, 'w', encoding='utf-8')
682 |
683 | for line in open(path, encoding='utf-8'):
684 | line = line.strip('\n')
685 | if not line:
686 | continue
687 | cxt = line.split('\t')[3]
688 | root = cxt.strip().split()[0]
689 | if root != prev:
690 | if np.random.random() < p_test:
691 | k = 'vali'
692 | else:
693 | k = 'train'
694 | #pdb.set_trace()
695 | lines[k].append(line)
696 | prev = root
697 | n += 1
698 | if n % 1e6 == 0:
699 | print('read %i M'%(n/1e6))
700 | for k in lines:
701 | if len(lines[k]) == 0:
702 | continue
703 | with open(path + '.' + k, 'a', encoding='utf-8') as f:
704 | f.write('\n'.join(lines[k]) + '\n')
705 | lines[k] = []
706 |
707 | for k in lines:
708 | if len(lines[k]) == 0:
709 | continue
710 | with open(path + '.' + k, 'a', encoding='utf-8') as f:
711 | f.write('\n'.join(lines[k]))
712 | lines[k] = []
713 |
714 |
715 | def shuffle(feedback, part, n_temp=10):
716 | fld = '%s/%s'%(fld_out, feedback)
717 | path = '%s/raw.tsv.%s'%(fld, part)
718 | path_out = '%s/%s.tsv'%(fld, part)
719 | fld_temp = '%s/temp/%s'%(fld_out, feedback)
720 |
721 | print('slicing '+path)
722 | os.makedirs(fld_temp, exist_ok=True)
723 | lines = [[] for _ in range(n_temp)]
724 |
725 | # split into n_temp files
726 | for i in range(n_temp):
727 | open(fld_temp + '/temp%i'%i, 'w', encoding='utf-8')
728 | n = 0
729 | count = [0] * n_temp
730 | rand = np.random.randint(0, n_temp, 202005)
731 | for line in open(path, encoding='utf-8'):
732 | line = line.strip('\n')
733 | if len(line) == 0:
734 | continue
735 | bucket = rand[n % len(rand)]
736 | lines[bucket].append(line)
737 | count[bucket] += 1
738 | n += 1
739 | if n % 1e6 == 0:
740 | print('read %i M'%(n/1e6))
741 | for i in range(n_temp):
742 | if len(lines[i]) == 0:
743 | continue
744 | with open(fld_temp + '/temp%i'%i, 'a', encoding='utf-8') as f:
745 | f.write('\n'.join(lines[i]) + '\n')
746 | lines[i] = []
747 |
748 | for i in range(n_temp):
749 | with open(fld_temp + '/temp%i'%i, 'a', encoding='utf-8') as f:
750 | f.write('\n'.join(lines[i]))
751 |
752 | # and then merge
753 | open(path_out, 'w', encoding='utf-8')
754 | print(fld_temp)
755 | for i in range(n_temp):
756 | print('reading temp%i'%i)
757 | lines = open(fld_temp + '/temp%i'%i, encoding='utf-8').readlines()
758 | print('shuffling')
759 | jj = list(range(len(lines)))
760 | np.random.shuffle(jj)
761 | print('writing')
762 | with open(path_out, 'a', encoding='utf-8') as f:
763 | f.write('\n'.join([lines[j].strip('\n') for j in jj]) + '\n')
764 |
765 | def get_subs():
766 | # return ['4chan']  # debug shortcut: uncomment to run on a single subreddit only
767 | print('collecting subs...')
768 | subs = sorted(os.listdir(fld_subs))
769 | print('collected %i subs'%len(subs))
770 | return subs
771 |
772 |
773 | def build_json(year):
774 | for date in get_dates(year):
775 | extract_rc(date)
776 | extract_rs(date)
777 |
778 |
779 | def build_basic(year):
780 | from transformers import GPT2Tokenizer
781 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
782 | subs = get_subs()
783 | for sub in subs:
784 | extract_time(sub, year)
785 | extract_txt(sub, year, tokenizer)
786 | extract_trees(sub, year)
787 | calc_feedback(sub, year, overwrite=False)
788 |
789 |
790 | def build_pairs(year_from, year_to, feedback):
791 | subs = get_subs()
792 | for year in range(year_from, year_to + 1):
793 | for sub in subs:
794 | create_pairs(year, sub, feedback, overwrite=False)
795 | add_seq(sub, year, feedback, overwrite=False)
796 | path = combine_sub(year_from, year_to, feedback)
797 | split_by_root(path)
798 | for part in ['train', 'vali']:
799 | shuffle(feedback, part)
800 |
801 |
802 | FLD = 'data'
803 | fld_bz2 = FLD + '/bz2/'
804 | fld_jsonl = FLD + '/jsonl/'
805 | fld_subs = FLD + '/subs/'
806 | fld_out = FLD + '/out/'
807 |
808 | if __name__ == "__main__":
809 | import argparse
810 | parser = argparse.ArgumentParser()
811 | parser.add_argument('task', type=str)
812 | parser.add_argument('year', type=int)
813 | parser.add_argument('--year_to', type=int)
814 | args = parser.parse_args()
815 | if args.task == 'bz2':
816 | build_json(args.year)
817 | elif args.task == 'basic':
818 | build_basic(args.year)
819 | elif args.task in ['updown', 'depth', 'width']:
820 | build_pairs(args.year, args.year_to, args.task)
821 | else:
822 | raise ValueError
--------------------------------------------------------------------------------
/src/data_new.py:
--------------------------------------------------------------------------------
1 | """
2 | data source:
3 | https://academictorrents.com/details/56aa49f9653ba545f48df2e33679f014d2829c10
4 | https://academictorrents.com/details/20520c420c6c846f555523babc8c059e9daa8fc5
5 | """
6 | import json, zstandard as zstd  # zstandard: `pip install zstandard`
7 | def zst2jsonl(path_zst):
8 | path_jsonl = path_zst + '.jsonl'
9 | open(path_jsonl, 'w')
10 | n_line = 0
11 | out = []
12 | with open(path_zst, 'rb') as fh:
13 | dctx = zstd.ZstdDecompressor(max_window_size=2147483648)
14 | with dctx.stream_reader(fh) as reader:
15 | previous_line = ""
16 | while True:
17 | chunk = reader.read(2**24) # 16mb chunks
18 | if not chunk:
19 | break
20 |
21 | string_data = chunk.decode('utf-8')
22 | lines = string_data.split("\n")
23 | for i, line in enumerate(lines[:-1]):
24 | if i == 0:
25 | line = previous_line + line
26 | object = json.loads(line)
27 | out.append(json.dumps(object, ensure_ascii=False))
28 | n_line += 1
29 | if n_line % 1e5 == 0:
30 | print(n_line)
31 | with open(path_zst + '.jsonl', 'a') as f:
32 | f.write('\n'.join(out) + '\n')
33 | out = []
34 | previous_line = lines[-1]
35 |
36 | if out:
37 | with open(path_zst + '.jsonl', 'a') as f:
38 | f.write('\n'.join(out) + '\n')
39 | print(n_line)
40 |
--------------------------------------------------------------------------------
/src/feeder.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import torch, os, pdb
5 | import numpy as np
6 |
7 |
8 | class Feeder:
9 | # load train/vali/test data
10 |
11 | def __init__(self, opt):
12 | self.opt = opt
13 | self.files = dict()
14 | if self.opt.mismatch:
15 | self.files_mismatch = dict()
16 | for sub in ['train', 'vali', 'test']:
17 | self.reset(sub)
18 | self.ix_EOS = 50256
19 | self.ix_OMT = 986
20 |
21 |
22 | def reset(self, sub):
23 | print('resetting '+sub)
24 | path = '%s/%s.tsv'%(self.opt.fld_data, sub)
25 | if os.path.exists(path):
26 | self.files[sub] = open(path)
27 | if self.opt.mismatch:
28 | self.files_mismatch[sub] = open(path)
29 | # assuming f is already shuffled, this step makes f and f_mismatch mismatch
30 | for _ in range(100):
31 | self.files[sub].readline()
32 |
33 |
34 | def get_batch(self, size, sub='train', min_score_gap=1, min_rank_gap=0):
35 | ids_pos = []
36 | len_pos = []
37 | ids_neg = []
38 | len_neg = []
39 | len_cxt = []
40 | score_pos = []
41 | score_neg = []
42 | rank_pos = []
43 | rank_neg = []
44 | hr_gap = []
45 | if sub != 'train':
46 | np.random.seed(2020)
47 |
48 | def ints(s):
49 | return [int(x) for x in s.split()]
50 | def pad(seq):
51 | return seq + [self.ix_EOS] * (self.opt.max_seq_len - len(seq))
52 |
53 | def read():
54 | total = 0
55 | used = 0
56 | for line in self.files[sub]:
57 | if line.startswith('#'):
58 | continue
59 | # old data is title + ' . ' + selftext, ' .' is 764 and often used as ' .jpg' thus misleading
60 | line = line.replace(' 764\t', '\t').replace(' 764 50256 ', ' 50256 ')
61 | total += 1
62 | ss = line.strip('\n').split('\t')
63 | cxt = ints(ss[0])
64 | reply_pos = ints(ss[1])
65 | # _score_pos, _score_neg, _rank_pos, _rank_neg = ss[-4:]
66 | try:
67 | _hr_gap = float(ss[-5])
68 | except ValueError:
69 | _hr_gap = np.nan
70 | _score_pos = int(ss[-4])
71 | _rank_pos = float(ss[-2])
72 |
73 | if self.opt.mismatch:
74 | _score_neg = np.nan
75 | _rank_neg = np.nan
76 | line_mismatch = self.files_mismatch[sub].readline()
77 | ss_mismatch = line_mismatch.strip('\n').split('\t')
78 | reply_neg = ints(ss_mismatch[1])
79 |
80 | else:
81 | reply_neg = ints(ss[2])
82 | _score_neg = int(ss[-3])
83 | _rank_neg = float(ss[-1])
84 | if _score_pos - _score_neg < min_score_gap:
85 | continue
86 | if _rank_pos - _rank_neg < min_rank_gap:
87 | continue
88 | if self.opt.max_hr_gap > 0 and _hr_gap > self.opt.max_hr_gap:
89 | continue
90 |
91 | pos = cxt + [self.ix_EOS] + reply_pos
92 | score_pos.append(_score_pos)
93 | rank_pos.append(_rank_pos)
94 |
95 | neg = cxt + [self.ix_EOS] + reply_neg
96 | score_neg.append(_score_neg)
97 | rank_neg.append(_rank_neg)
98 |
99 | # make sure cxt still same even after cut
100 | n_del = max(len(pos), len(neg)) - self.opt.max_seq_len
101 | if n_del > 0:
102 | pos = pos[n_del:]
103 | neg = neg[n_del:]
104 | cxt = cxt[n_del:]
105 |
106 | len_cxt.append(len(cxt))
107 | len_pos.append(len(pos))
108 | len_neg.append(len(neg))
109 | ids_pos.append(pad(pos))
110 | ids_neg.append(pad(neg))
111 | hr_gap.append(_hr_gap)
112 |
113 | used += 1
114 | if len(ids_pos) == size:
115 | break
116 |
117 | while True:
118 | read()
119 | if len(ids_pos) == size:
120 | break
121 | self.reset(sub)
122 |
123 | ids_pos = torch.LongTensor(ids_pos)
124 | ids_neg = torch.LongTensor(ids_neg)
125 | if self.opt.cuda:
126 | ids_pos = ids_pos.cuda()
127 | ids_neg = ids_neg.cuda()
128 | return {
129 | 'ids_pos':ids_pos,
130 | 'ids_neg':ids_neg,
131 | 'len_pos':len_pos,
132 | 'len_neg':len_neg,
133 | 'len_cxt':len_cxt,
134 | 'score_pos': score_pos,
135 | 'score_neg': score_neg,
136 | 'rank_pos': rank_pos,
137 | 'rank_neg': rank_neg,
138 | 'hr_gap': hr_gap,
139 | }
--------------------------------------------------------------------------------
/src/generation.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 | import torch, pdb
4 | import numpy as np
5 | from shared import download_model, EOS_token
6 |
7 | class GPT2Generator:
8 |
9 | def __init__(self, path, cuda):
10 | from transformers19 import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
11 | self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
12 | model_config = GPT2Config(n_embd=1024, n_layer=24, n_head=16)
13 | self.model = GPT2LMHeadModel(model_config)
14 | download_model(path)
15 | print('loading from '+path)
16 | weights = torch.load(path)
17 | if "lm_head.decoder.weight" in weights:
18 | weights["lm_head.weight"] = weights["lm_head.decoder.weight"]
19 | weights.pop("lm_head.decoder.weight",None)
20 | self.model.load_state_dict(weights)
21 | self.ix_EOS = 50256
22 | self.model.eval()
23 | self.cuda = cuda
24 | if self.cuda:
25 | self.model.cuda()
26 |
27 |
28 | def tokenize(self, cxt):
29 | turns = cxt.split(EOS_token)
30 | ids = []
31 | for turn in turns:
32 | ids += self.tokenizer.encode(turn.strip()) + [self.ix_EOS]
33 | ids = torch.tensor([ids]).view(1, -1)
34 | if self.cuda:
35 | ids = ids.cuda()
36 | return ids
37 |
38 |
39 | def predict_beam(self, cxt, topk=3, topp=0.8, beam=10, max_t=30):
40 |         """ beam-like search: keep the `beam` best hypotheses, expanding each with its top-k / top-p tokens at every step (a usage sketch follows this file's listing) """
41 |
42 | tokens = self.tokenize(cxt)
43 | len_cxt = tokens.shape[1]
44 | sum_logP = [0]
45 | finished = []
46 |
47 | for _ in range(max_t):
48 | outputs = self.model(tokens)
49 | predictions = outputs[0]
50 | logP = torch.log_softmax(predictions[:, -1, :], dim=-1)
51 | next_logP, next_token = torch.topk(logP, topk)
52 | sumlogP_ij = []
53 | sum_prob = 0
54 | for i in range(tokens.shape[0]):
55 | for j in range(topk):
56 |                     sum_prob += np.exp(next_logP[i, j].item())  # cumulative prob of the top-k candidates, used for the top-p cutoff
57 | if sum_prob > topp:
58 | break
59 | if next_token[i, j] == self.ix_EOS:
60 | seq = torch.cat([tokens[i, len_cxt:], next_token[i, j].view(1)], dim=-1)
61 | if self.cuda:
62 | seq = seq.cpu()
63 | seq = seq.detach().numpy().tolist()
64 | prob = np.exp((sum_logP[i] + next_logP[i, j].item()) / len(seq))
65 | hyp = self.tokenizer.decode(seq[:-1]) # don't include EOS
66 | finished.append((prob, hyp))
67 | else:
68 | sumlogP_ij.append((
69 | sum_logP[i] + next_logP[i, j].item(),
70 | i, j))
71 |
72 | if not sumlogP_ij:
73 | break
74 | sumlogP_ij = sorted(sumlogP_ij, reverse=True)[:min(len(sumlogP_ij), beam)]
75 | new_tokens = []
76 | new_sum_logP = []
77 | for _sum_logP, i, j in sumlogP_ij:
78 | new_tokens.append(
79 | torch.cat([tokens[i,:], next_token[i, j].view(1)], dim=-1).view(1, -1)
80 | )
81 | new_sum_logP.append(_sum_logP)
82 | tokens = torch.cat(new_tokens, dim=0)
83 | sum_logP = new_sum_logP
84 |
85 | return finished
86 |
87 |
88 | def predict_sampling(self, cxt, temperature=1, n_hyp=5, max_t=30):
89 |         """ sample tokens from the temperature-scaled predicted distribution """
90 |
91 | tokens = self.tokenize(cxt)
92 | tokens = tokens.repeat(n_hyp, 1)
93 | len_cxt = tokens.shape[1]
94 | sum_logP = [0] * n_hyp
95 | live = [True] * n_hyp
96 | seqs = [[] for _ in range(n_hyp)]
97 | np.random.seed(2020)
98 | for _ in range(max_t):
99 | outputs = self.model(tokens)
100 | predictions = outputs[0]
101 | prob = torch.softmax(predictions[:, -1, :] / temperature, dim=-1)
102 | if self.cuda:
103 | prob = prob.cpu()
104 | prob = prob.detach().numpy()
105 | vocab = prob.shape[-1]
106 | next_tokens = []
107 | for i in range(n_hyp):
108 | next_token = np.random.choice(vocab, p=prob[i,:])
109 | next_tokens.append(next_token)
110 | if not live[i]:
111 | continue
112 | sum_logP[i] += np.log(prob[i, next_token])
113 | seqs[i].append(next_token)
114 | if next_token == self.ix_EOS:
115 | live[i] = False
116 | continue
117 | next_tokens = torch.LongTensor(next_tokens).view(-1, 1)
118 | if self.cuda:
119 | next_tokens = next_tokens.cuda()
120 | tokens = torch.cat([tokens, next_tokens], dim=-1)
121 |
122 | ret = []
123 | for i in range(n_hyp):
124 | if live[i]: # only return hyp that ends with EOS
125 | continue
126 | prob = np.exp(sum_logP[i] / (len(seqs[i]) + 1))
127 | hyp = self.tokenizer.decode(seqs[i][:-1]) # strip EOS
128 | ret.append((prob, hyp))
129 | return ret
130 |
131 |
132 | def play(self, params):
133 | while True:
134 | cxt = input('\nContext:\t')
135 | if not cxt:
136 | break
137 | ret = self.predict(cxt, **params)
138 | for prob, hyp in sorted(ret, reverse=True):
139 | print('%.3f\t%s'%(prob, hyp))
140 |
141 |
142 | class Integrated:
143 | def __init__(self, generator, ranker):
144 | self.generator = generator
145 | self.ranker = ranker
146 |
147 | def predict(self, cxt, wt_ranker, params):
148 | with torch.no_grad():
149 | prob_hyp = self.generator.predict(cxt, **params)
150 | probs = np.array([prob for prob, _ in prob_hyp])
151 | hyps = [hyp for _, hyp in prob_hyp]
152 | if wt_ranker > 0:
153 | scores_ranker = self.ranker.predict(cxt, hyps)
154 | if isinstance(scores_ranker, dict):
155 | scores_ranker = scores_ranker['final']
156 | scores = wt_ranker * scores_ranker + (1 - wt_ranker) * probs
157 | else:
158 |             scores, scores_ranker = probs, np.zeros(len(hyps))  # ranker unused; keep the tuple below well-formed
159 | ret = []
160 | for i in range(len(hyps)):
161 | ret.append((scores[i], probs[i], scores_ranker[i], hyps[i]))
162 | ret = sorted(ret, reverse=True)
163 | return ret
164 |
165 |
166 | def play(self, wt_ranker, params):
167 | while True:
168 | cxt = input('\nContext:\t')
169 | if not cxt:
170 | break
171 | ret = self.predict(cxt, wt_ranker, params)
172 | for final, prob_gen, score_ranker, hyp in ret:
173 | print('%.3f gen %.3f ranker %.3f\t%s'%(final, prob_gen, score_ranker, hyp))
174 |
175 |
176 | def test(model, path_in, wt_ranker, params, max_n):
177 | lines = []
178 | for i, line in enumerate(open(path_in, encoding='utf-8')):
179 | print('processing %i-th context'%i)
180 | cxt = line.strip('\n').split('\t')[0]
181 |         ret = model.predict(cxt, wt_ranker, params)  # Integrated.predict takes the generation params as a dict
182 | cc = [cxt] + [tup[-1] for tup in ret]
183 | lines.append('\t'.join(cc))
184 | if i == max_n:
185 | break
186 | path_out = path_in + '.hyps'
187 | with open(path_out, 'w', encoding='utf-8') as f:
188 | f.write('\n'.join(lines))
189 | print('saved to '+path_out)
190 |
191 |
192 |
193 | if __name__ == "__main__":
194 | import argparse
195 | parser = argparse.ArgumentParser()
196 | parser.add_argument('task', type=str)
197 | parser.add_argument('--path_generator', '-pg', type=str)
198 | parser.add_argument('--path_ranker', '-pr', type=str)
199 | parser.add_argument('--path_test', type=str)
200 | parser.add_argument('--cpu', action='store_true')
201 | parser.add_argument('--sampling', action='store_true')
202 | parser.add_argument('--topk', type=int, default=3)
203 | parser.add_argument('--beam', type=int, default=3)
204 | parser.add_argument('--wt_ranker', type=float, default=1.)
205 | parser.add_argument('--topp', type=float, default=0.8)
206 | parser.add_argument('--max_n', type=int, default=-1)
207 | parser.add_argument('--temperature', type=float, default=0.5)
208 | parser.add_argument('--n_hyp', type=int, default=10)
209 | args = parser.parse_args()
210 |
211 | cuda = False if args.cpu else torch.cuda.is_available()
212 | generator = GPT2Generator(args.path_generator, cuda)
213 | if args.sampling:
214 | params = {'temperature': args.temperature, 'n_hyp': args.n_hyp}
215 | generator.predict = generator.predict_sampling
216 | else:
217 | params = {'topk': args.topk, 'beam': args.beam, 'topp': args.topp}
218 | generator.predict = generator.predict_beam
219 |
220 | if args.path_ranker is None:
221 | model = generator
222 | else:
223 | from score import get_model
224 | ranker = get_model(args.path_ranker, cuda)
225 | model = Integrated(generator, ranker)
226 |
227 | if args.task == 'play':
228 | model.play(args.wt_ranker, params)
229 | elif args.task == 'test':
230 |         test(model, args.path_test, args.wt_ranker, params, args.max_n)
--------------------------------------------------------------------------------
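
To make the pieces above concrete, here is a minimal usage sketch (run from `src/`; the checkpoint paths follow the download links in `shared.py`, while the context string, `wt_ranker`, and sampling settings are illustrative): the generator proposes hypotheses with `predict_sampling`, and `Integrated.predict` interpolates the ranker score with the generator probability via `wt_ranker`.

    # minimal sketch; assumes the checkpoints can be fetched by shared.download_model (wget)
    import torch
    from generation import GPT2Generator, Integrated
    from score import get_model

    cuda = torch.cuda.is_available()
    generator = GPT2Generator('restore/medium_ft.pkl', cuda)   # DialoGPT generator checkpoint
    generator.predict = generator.predict_sampling             # same binding as in __main__ above
    ranker = get_model('restore/updown.pth', cuda)             # DialogRPT 'updown' ranker

    model = Integrated(generator, ranker)
    params = {'temperature': 0.7, 'n_hyp': 5}
    for final, prob_gen, score_ranker, hyp in model.predict('Can we restart 2020?', wt_ranker=0.5, params=params):
        print('%.3f (gen %.3f, ranker %.3f)\t%s' % (final, prob_gen, score_ranker, hyp))
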
/src/main.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import argparse, torch, time, pdb
5 | import os, socket
6 | from master import Master
7 |
8 |
9 | class Option:
10 |
11 | def __init__(self, args):
12 | if args.cpu or not torch.cuda.is_available():
13 | self.cuda = False
14 | else:
15 | self.cuda = True
16 | self.task = args.task
17 | self.path_load = args.path_load
18 | self.batch = args.batch
19 | self.vali_size = max(self.batch, args.vali_size)
20 | self.vali_print = args.vali_print
21 | self.lr = args.lr
22 | self.max_seq_len = args.max_seq_len
23 | self.min_score_gap = args.min_score_gap
24 | self.min_rank_gap = args.min_rank_gap
25 | self.max_hr_gap = args.max_hr_gap
26 | self.mismatch = args.mismatch
27 | self.fld_data = args.data
28 | if args.task == 'train' or self.path_load is None:
29 | self.fld_out = 'out/%i'%time.time()
30 | else:
31 | self.fld_out = 'out/temp'
32 | os.makedirs(self.fld_out, exist_ok=True)
33 |
34 | self.clip = 1
35 | self.step_max = 1e6
36 | self.step_print = 10
37 | self.step_vali = 100
38 | self.step_save = 500
39 | self.len_acc = self.step_vali
40 |
41 |
42 | def save(self):
43 | d = self.__dict__
44 | lines = []
45 | for k in d:
46 | lines.append('%s\t%s'%(k, d[k]))
47 | with open(self.fld_out + '/opt.tsv', 'w') as f:
48 | f.write('\n'.join(lines))
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument('task', type=str)
54 | parser.add_argument('--data', type=str)
55 | parser.add_argument('--batch', type=int, default=256)
56 | parser.add_argument('--vali_size', type=int, default=1024)
57 | parser.add_argument('--vali_print', type=int, default=10)
58 | parser.add_argument('--lr', type=float, default=3e-5)
59 | parser.add_argument('--path_load','-p', type=str)
60 | parser.add_argument('--cpu', action='store_true')
61 | parser.add_argument('--max_seq_len', type=int, default=50)
62 | parser.add_argument('--mismatch', action='store_true')
63 | parser.add_argument('--min_score_gap', type=int)
64 | parser.add_argument('--min_rank_gap', type=float)
65 | parser.add_argument('--max_hr_gap', type=float, default=1)
66 | args = parser.parse_args()
67 |
68 | opt = Option(args)
69 | master = Master(opt)
70 | if args.task == 'train':
71 | master.train()
72 | elif args.task == 'vali':
73 | master.vali()
--------------------------------------------------------------------------------
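
For reference, a small sketch of how the CLI flags above map onto `Option`; the data path, checkpoint, and gap thresholds below are illustrative, not values prescribed by this file:

    # roughly equivalent to:
    #   python src/main.py train --data data/out/updown -p restore/medium_ft.pkl \
    #       --min_score_gap 20 --min_rank_gap 0.5
    from types import SimpleNamespace
    from main import Option

    args = SimpleNamespace(
        task='train', data='data/out/updown', batch=256, vali_size=1024,
        vali_print=10, lr=3e-5, path_load='restore/medium_ft.pkl', cpu=True,
        max_seq_len=50, mismatch=False, min_score_gap=20, min_rank_gap=0.5,
        max_hr_gap=1.0)
    opt = Option(args)      # creates out/<timestamp>/ and fills in the fixed schedule fields
    print(opt.fld_out, opt.vali_size, opt.step_vali)
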
/src/master.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import torch, os, pdb, time, sys, warnings
5 | import numpy as np
6 | from feeder import Feeder
7 | from model import Scorer, JointScorer
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | class Master:
12 |
13 | def __init__(self, opt):
14 | self.opt = opt
15 | if opt.path_load is not None and (opt.path_load.endswith('.yaml') or opt.path_load.endswith('.yml')):
16 | self._model = JointScorer(opt)
17 | else:
18 | self._model = Scorer(opt)
19 | if opt.path_load is not None:
20 | self._model.load(opt.path_load)
21 | self.parallel()
22 |
23 | if opt.task != 'play':
24 | if opt.fld_data is not None:
25 | self.feeder = Feeder(opt)
26 |
27 | if opt.task == 'train':
28 | opt.save()
29 | os.makedirs(opt.fld_out + '/ckpt', exist_ok=True)
30 | self.path_log = self.opt.fld_out + '/log.txt'
31 | else:
32 | self.path_log = self.opt.fld_out + '/log_infer.txt'
33 |
34 |
35 | def print(self, s=''):
36 | try:
37 | print(s)
38 | except UnicodeEncodeError:
39 | print('[UnicodeEncodeError]')
40 | pass
41 | with open(self.path_log, 'a', encoding='utf-8') as f:
42 | f.write(s+'\n')
43 |
44 |
45 | def parallel(self):
46 | if self.opt.cuda:
47 | self._model = self._model.cuda()
48 | n_gpu = torch.cuda.device_count()
49 | if self.opt.cuda and n_gpu > 1:
50 | print('paralleling on %i GPU'%n_gpu)
51 | self.model = torch.nn.DataParallel(self._model)
52 | # after DataParallel, a warning about RNN weights shows up every batch
53 | warnings.filterwarnings("ignore")
54 |             # after DataParallel, the original model's attributes live on self.model.module
55 | self._model = self.model.module
56 | self.model.core = self.model.module.core
57 | self.model.tokenizer = self._model.tokenizer
58 | else:
59 | self.model = self._model
60 | if self.opt.task == 'train':
61 | self.optimizer = torch.optim.Adam(self._model.parameters(), lr=self.opt.lr)
62 |
63 |
64 | def train(self):
65 | vali_loss, best_acc = self.vali()
66 | best_trained = 0
67 | step = 0
68 | n_trained = 0
69 | t0 = time.time()
70 |
71 | list_trained = [0]
72 | list_train_loss = [np.nan]
73 | list_train_acc = [np.nan]
74 | list_vali_loss = [vali_loss]
75 | list_vali_acc = [best_acc]
76 | acc_history = []
77 |
78 | while step < self.opt.step_max:
79 | self.model.train()
80 | self.optimizer.zero_grad()
81 | batch = self.feeder.get_batch(self.opt.batch)
82 | pred = self.model.forward(batch)
83 | loss = self.loss(pred)
84 | loss = loss.mean() # in case of parallel-training
85 |
86 | loss.backward()
87 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.opt.clip)
88 | self.optimizer.step()
89 |
90 | acc = (pred > 0.5).float().mean().item()
91 | acc_history.append(acc)
92 | if len(acc_history) > self.opt.len_acc:
93 | acc_history.pop(0)
94 | avg_train_acc = np.mean(acc_history)
95 | step += 1
96 | n_trained += self.opt.batch
97 | info = 'step %i trained %.3f best %.2f'%(step, n_trained/1e6, best_acc)
98 |
99 | if step % self.opt.step_print == 0:
100 | speed = (n_trained / 1e6) / ((time.time() - t0) / 3600)
101 |
102 | self.print('%s speed %.2f hr_gap %.2f score_gap %.2f rank_gap %.2f loss %.4f acc %.3f'%(
103 | info,
104 | speed,
105 | np.median(batch['hr_gap']),
106 | (np.array(batch['score_pos']) - np.array(batch['score_neg'])).mean(),
107 | (np.array(batch['rank_pos']) - np.array(batch['rank_neg'])).mean(),
108 | loss,
109 | avg_train_acc,
110 | ))
111 |
112 | if step % self.opt.step_vali == 0:
113 | vali_loss, vali_acc = self.vali(info)
114 | if vali_acc > best_acc:
115 | self.save(self.opt.fld_out + '/ckpt/best.pth')
116 | best_acc = vali_acc
117 | best_trained = n_trained
118 | sys.stdout.flush()
119 |
120 | list_trained.append(n_trained/1e6)
121 | list_train_loss.append(loss.item())
122 | list_train_acc.append(avg_train_acc)
123 | list_vali_loss.append(vali_loss)
124 | list_vali_acc.append(vali_acc)
125 | _, axs = plt.subplots(3, 1, sharex=True)
126 |
127 | axs[0].plot(list_trained, list_train_loss, 'b', label='train')
128 | axs[0].plot(list_trained, list_vali_loss, 'r', label='vali')
129 | axs[0].legend(loc='best')
130 | axs[0].set_ylabel('loss')
131 |
132 | axs[1].plot(list_trained, list_train_acc, 'b', label='train')
133 | axs[1].plot(list_trained, list_vali_acc, 'r', label='vali')
134 | axs[1].plot([best_trained/1e6, n_trained/1e6], [best_acc, best_acc], 'k:')
135 | axs[1].set_ylabel('acc')
136 |
137 | axs[-1].set_xlabel('trained (M)')
138 | axs[0].set_title(self.opt.fld_out + '\n' + self.opt.fld_data + '\nbest_acc = %.4f'%best_acc)
139 | plt.tight_layout()
140 | plt.savefig(self.opt.fld_out + '/log.png')
141 | plt.close()
142 |
143 | if step % self.opt.step_save == 0:
144 | self.save(self.opt.fld_out + '/ckpt/last.pth')
145 |
146 |
147 | def loss(self, pred):
148 |         # pred is the predicted probability that the positive response is ranked above the negative one for the same context (a toy check follows this file's listing)
149 | return - torch.log(pred).mean()
150 |
151 |
152 | def vali(self, info=''):
153 | n_print = min(self.opt.batch, self.opt.vali_print)
154 | self.model.eval()
155 | loss = 0
156 | acc = 0
157 | hr_gap = 0
158 | score_gap = 0
159 | rank_gap = 0
160 | n_batch = int(self.opt.vali_size/self.opt.batch)
161 | self.feeder.reset('vali')
162 |
163 | for _ in range(n_batch):
164 | batch = self.feeder.get_batch(self.opt.batch, sub='vali',
165 | min_score_gap=self.opt.min_score_gap, min_rank_gap=self.opt.min_rank_gap)
166 | with torch.no_grad():
167 | pred = self.model.forward(batch)
168 | loss += self.loss(pred)
169 | acc += (pred > 0.5).float().mean()
170 | score_gap += (np.array(batch['score_pos']) - np.array(batch['score_neg'])).mean()
171 | rank_gap += (np.array(batch['rank_pos']) - np.array(batch['rank_neg'])).mean()
172 | hr_gap += np.median(batch['hr_gap'])
173 |
174 | loss /= n_batch
175 | acc /= n_batch
176 | score_gap /= n_batch
177 | rank_gap /= n_batch
178 | hr_gap /= n_batch
179 | s = '%s hr_gap %.2f score_gap %.2f rank_gap %.2f loss %.4f acc %.3f'%(
180 | info,
181 | hr_gap,
182 | score_gap,
183 | rank_gap,
184 | loss,
185 | acc,
186 | )
187 | s = '[vali] ' + s.strip()
188 | if not n_print:
189 | self.print(s)
190 | return loss.mean().item(), acc
191 |
192 | with torch.no_grad():
193 | pred_pos = self.model.core(batch['ids_pos'], batch['len_pos'])
194 | pred_neg = self.model.core(batch['ids_neg'], batch['len_neg'])
195 |
196 | def to_np(ids):
197 | if self.opt.cuda:
198 | ids = ids.cpu()
199 | return ids.detach().numpy()
200 |
201 | ids_pos = to_np(batch['ids_pos'])
202 | ids_neg = to_np(batch['ids_neg'])
203 |
204 | for j in range(n_print):
205 | l_cxt = batch['len_cxt'][j]
206 | cxt = self.model.tokenizer.decode(ids_pos[j, :l_cxt])
207 |                 pos = self.model.tokenizer.decode(ids_pos[j, l_cxt:]).strip('<|endoftext|>')
208 |                 neg = self.model.tokenizer.decode(ids_neg[j, l_cxt:]).strip('<|endoftext|>')
209 | self.print(cxt)
210 | self.print('hr_gap %s'%batch['hr_gap'][j])
211 | self.print('%s\t%.2f\t%.3f\t%s'%(batch['score_pos'][j], batch['rank_pos'][j], pred_pos[j], pos))
212 | self.print('%s\t%.2f\t%.3f\t%s'%(batch['score_neg'][j], batch['rank_neg'][j], pred_neg[j], neg))
213 | self.print()
214 |
215 | self.print(s)
216 | return loss.mean().item(), acc
217 |
218 |
219 | def save(self, path):
220 | torch.save(self._model.state_dict(), path)
221 | self.print('saved to '+path)
222 |
--------------------------------------------------------------------------------
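
The objective trained above is a pairwise ranking loss: `Master.loss` takes the pairwise probability produced by `ScorerBase.forward` (next file) and minimizes its negative log, which equals softplus(logit_neg - logit_pos). A tiny self-contained check with toy logits (not real model outputs):

    import torch

    logit_pos = torch.tensor([2.0, 0.3])
    logit_neg = torch.tensor([1.0, 0.7])
    pred = torch.exp(logit_pos) / (torch.exp(logit_pos) + torch.exp(logit_neg))  # as in ScorerBase.forward
    loss = -torch.log(pred).mean()                                               # as in Master.loss
    assert torch.allclose(loss, torch.nn.functional.softplus(logit_neg - logit_pos).mean())
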
/src/model.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 |
4 | import torch, os, pdb
5 | import numpy as np
6 | from transformers19 import GPT2Tokenizer, GPT2Model, GPT2Config
7 | from shared import EOS_token
8 |
9 |
10 | class OptionInfer:
11 | def __init__(self, cuda=True):
12 | self.cuda = cuda
13 |
14 |
15 | class ScorerBase(torch.nn.Module):
16 | def __init__(self, opt):
17 | super().__init__()
18 | self.ix_EOS = 50256
19 | self.ix_OMT = 986
20 | self.opt = opt
21 | self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
22 |
23 |
24 | def core(self, ids, l_ids, return_logits=False):
25 | # to be implemented in child class
26 | return 0
27 |
28 |
29 | def predict(self, cxt, hyps, max_cxt_turn=None):
30 | # cxt = str
31 | # hyps = list of str
32 |
33 | self.eval()
34 | cxt_turns = cxt.split(EOS_token)
35 | if max_cxt_turn is not None:
36 | cxt_turns = cxt_turns[-min(max_cxt_turn, len(cxt_turns)):]
37 | ids_cxt = []
38 | for turn in cxt_turns:
39 | ids_cxt += self.tokenizer.encode(turn.strip()) + [self.ix_EOS]
40 | seqs = []
41 | lens = []
42 | for hyp in hyps:
43 | seq = ids_cxt + self.tokenizer.encode(hyp.strip())
44 | lens.append(len(seq))
45 | seqs.append(seq)
46 | max_len = max(lens)
47 | ids = []
48 | for seq in seqs:
49 | ids.append(seq + [self.ix_EOS] * (max_len - len(seq)))
50 | with torch.no_grad():
51 | ids = torch.LongTensor(ids)
52 | if self.opt.cuda:
53 | ids = ids.cuda()
54 | scores = self.core(ids, lens)
55 | if not isinstance(scores, dict):
56 | if self.opt.cuda:
57 | scores = scores.cpu()
58 | return scores.detach().numpy()
59 |
60 | for k in scores:
61 | if self.opt.cuda:
62 | scores[k] = scores[k].cpu()
63 | scores[k] = scores[k].detach().numpy()
64 | return scores
65 |
66 |
67 | def forward(self, batch):
68 | logits_pos = self.core(batch['ids_pos'], batch['len_pos'], return_logits=True)
69 | logits_neg = self.core(batch['ids_neg'], batch['len_neg'], return_logits=True)
70 |         # pairwise softmax over the two logits: the probability of ranking pos above neg, i.e. sigmoid(logits_pos - logits_neg)
71 | return torch.exp(logits_pos) / (torch.exp(logits_pos) + torch.exp(logits_neg))
72 |
73 |
74 |
75 | class Scorer(ScorerBase):
76 | def __init__(self, opt):
77 | super().__init__(opt)
78 | n_embd = 1024
79 | config = GPT2Config(n_embd=n_embd, n_layer=24, n_head=16)
80 | self.transformer = GPT2Model(config)
81 | self.score = torch.nn.Linear(n_embd, 1, bias=False)
82 |
83 |
84 | def core(self, ids, l_ids, return_logits=False):
85 | n = ids.shape[0]
86 | attention_mask = torch.ones_like(ids)
87 | for i in range(n):
88 | attention_mask[i, l_ids[i]:] *= 0
89 | hidden_states, _ = self.transformer(ids, attention_mask=attention_mask)
90 | logits = self.score(hidden_states).squeeze(-1)
91 | logits = torch.stack([logits[i, l_ids[i] - 1] for i in range(n)])
92 | if return_logits:
93 | return logits
94 | else:
95 | return torch.sigmoid(logits)
96 |
97 |
98 | def load(self, path):
99 | from shared import download_model
100 | download_model(path)
101 | print('loading from '+path)
102 | weights = torch.load(path, map_location=torch.device('cpu'))
103 | if path.endswith('.pkl'):
104 | # DialoGPT checkpoint
105 | weights['score.weight'] = weights['lm_head.decoder.weight'][self.ix_EOS: self.ix_EOS+1, :]
106 | del weights['lm_head.decoder.weight']
107 | self.load_state_dict(weights)
108 | if self.opt.cuda:
109 | self.cuda()
110 |
111 |
112 | class JointScorer(ScorerBase):
113 |
114 | def core(self, ids, l_ids, return_logits=False):
115 | assert(not return_logits)
116 | scores = dict()
117 | for k in self.kk['prior'] + self.kk['cond']:
118 | scorer = getattr(self, 'scorer_%s'%k)
119 | scores[k] = scorer.core(ids, l_ids)
120 |
121 | def avg_score(kk):
122 | if not kk:
123 | return 1
124 | sum_score_wt = 0
125 | sum_wt = 0
126 | for k in kk:
127 | sum_score_wt = sum_score_wt + scores[k] * self.wt[k]
128 | sum_wt += self.wt[k]
129 | return sum_score_wt / sum_wt
130 |
131 | prior = avg_score(self.kk['prior'])
132 | cond = avg_score(self.kk['cond'])
133 | scores['final'] = prior * cond
134 | return scores
135 |
136 |
137 | def load(self, path_config):
138 | import yaml
139 | with open(path_config, 'r') as stream:
140 | config = yaml.safe_load(stream)
141 | print(config)
142 |
143 | paths = dict()
144 | self.wt = dict()
145 | self.kk = dict()
146 | for prefix in ['prior', 'cond']:
147 | self.kk[prefix] = []
148 | for d in config[prefix]:
149 | k = d['name']
150 | self.kk[prefix].append(k)
151 | self.wt[k] = d['wt']
152 | paths[k] = d['path']
153 |
154 | for k in paths:
155 | path = paths[k]
156 | print('setting up model `%s`'%k)
157 | scorer = Scorer(OptionInfer(cuda=self.opt.cuda))
158 | scorer.load(path)
159 | if self.opt.cuda:
160 | scorer.cuda()
161 | setattr(self, 'scorer_%s'%k, scorer)
162 |
163 |
164 |
165 |
--------------------------------------------------------------------------------
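
`JointScorer.load` above expects a YAML config with two groups, `prior` and `cond`, each a list of `{name, wt, path}` entries; the final score is the weighted average of the `prior` scorers multiplied by the weighted average of the `cond` scorers. Below is the structure such a config parses into, shown as the Python dict `yaml.safe_load` would return; the names and paths follow the checkpoints listed in `shared.py`, but the weights are illustrative and need not match `restore/ensemble.yml`:

    config = {
        'prior': [
            {'name': 'human_vs_rand',    'wt': 0.5, 'path': 'restore/human_vs_rand.pth'},
            {'name': 'human_vs_machine', 'wt': 0.5, 'path': 'restore/human_vs_machine.pth'},
        ],
        'cond': [
            {'name': 'updown', 'wt': 1.0, 'path': 'restore/updown.pth'},
            {'name': 'depth',  'wt': 0.5, 'path': 'restore/depth.pth'},
            {'name': 'width',  'wt': 0.5, 'path': 'restore/width.pth'},
        ],
    }
    # JointScorer.core then returns a dict with one entry per name above,
    # plus 'final' = weighted_avg(prior scores) * weighted_avg(cond scores)
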
/src/score.py:
--------------------------------------------------------------------------------
1 | import torch, pdb, os, json
2 | from shared import _cat_
3 | import numpy as np
4 | from model import OptionInfer, Scorer
5 | from collections import defaultdict
6 |
7 |
8 | def get_model(path, cuda=True):
9 | opt = OptionInfer(cuda)
10 | if path.endswith('.yaml') or path.endswith('.yml'):
11 | from model import JointScorer
12 | model = JointScorer(opt)
13 | model.load(path)
14 |     elif path.endswith('.pth'):
15 | from model import Scorer
16 | model = Scorer(opt)
17 | model.load(path)
18 | if cuda:
19 | model.cuda()
20 | return model
21 |
22 |
23 | def predict(model, cxt, hyps, max_cxt_turn=None):
24 | # split into smaller batch to avoid OOM
25 | n = len(hyps)
26 | i0 = 0
27 | scores = []
28 | while i0 < n:
29 | i1 = min(i0 + 32, n)
30 | _scores = model.predict(cxt, hyps[i0: i1], max_cxt_turn=max_cxt_turn)
31 | scores.append(_scores)
32 | i0 = i1
33 | if isinstance(_scores, dict):
34 | d_scores = dict()
35 | for k in _scores:
36 |             d_scores[k] = np.concatenate([s[k] for s in scores])
37 | return d_scores
38 | else:
39 | return np.concatenate(scores)
40 |
41 |
42 |
43 | def eval_fake(fld, model, fake, max_n=-1, max_cxt_turn=None):
44 | """
45 |     for a given context, we rank k real and m fake responses together.
46 |     if x real responses appear in the top-k ranked responses, define acc = x/k, where k = # of real responses.
47 |     this can be seen as a generalized version of hits@k:
48 |     for a perfect ranking, x == k and thus acc == 1 (a worked toy example follows this file's listing).
49 | """
50 |
51 | assert(os.path.isdir(fld))
52 | def read_data(path, max_n=-1):
53 | cxts = dict()
54 | rsps = dict()
55 | for i, line in enumerate(open(path, encoding='utf-8')):
56 | ss = line.strip('\n').split('\t')
57 | ss0 = ss[0].split(_cat_)
58 | if len(ss0) == 2:
59 | cxt, cxt_id = ss0
60 | cxt_id = cxt_id.strip()
61 | else:
62 | cxt = ss0[0]
63 | cxt_id = cxt.strip().replace(' ','')
64 | cxts[cxt_id] = cxt.strip()
65 | rsps[cxt_id] = [s.split(_cat_)[0] for s in ss[1:]]
66 | if i == max_n:
67 | break
68 | return cxts, rsps
69 |
70 | print('evaluating %s'%fld)
71 | acc = []
72 | cxts, reals = read_data(fld + '/ref.tsv', max_n=max_n)
73 | _, fakes = read_data(fld + '/%s.tsv'%fake)
74 |
75 | n = 0
76 | for cxt_id in reals:
77 | if cxt_id not in fakes:
78 | print('[WARNING] could not find fake examples for [%s]'%cxt_id)
79 | #pdb.set_trace()
80 | continue
81 | scores = predict(model, cxts[cxt_id], reals[cxt_id] + fakes[cxt_id], max_cxt_turn=max_cxt_turn)
82 | ix_score = sorted([(scores[i], i) for i in range(len(scores))], reverse=True)
83 | k = len(reals[cxt_id])
84 | _acc = np.mean([i < k for _, i in ix_score[:k]])
85 | acc.append(_acc)
86 | n += 1
87 | if n % 10 == 0:
88 | print('evaluated %i, avg acc %.3f'%(n, np.mean(acc)))
89 | if n == max_n:
90 | break
91 |
92 | print('final acc is %.3f based on %i samples'%(np.mean(acc), n))
93 |
94 |
95 |
96 | def eval_feedback(path, model, max_n=-1, max_cxt_turn=None, min_rank_gap=0., min_score_gap=0, max_hr_gap=1):
97 | """
98 | for a given context, we compare two responses,
99 |     predict which one got better feedback (greater updown, depth, or width),
100 |     and return this pairwise accuracy
101 | """
102 | assert(path.endswith('.tsv'))
103 | assert(min_rank_gap is not None)
104 | assert(min_score_gap is not None)
105 |
106 | print('evaluating %s'%path)
107 | acc = []
108 | n = 0
109 | for line in open(path, encoding='utf-8'):
110 | cc = line.strip('\n').split('\t')
111 | if len(cc) != 11:
112 | continue
113 | cxt, pos, neg, _, _, _, hr_gap, pos_score, neg_score, pos_rank, neg_rank = cc
114 | if float(hr_gap) > max_hr_gap:
115 | continue
116 | if float(pos_rank) - float(neg_rank) < min_rank_gap:
117 | continue
118 | if int(pos_score) - int(neg_score) < min_score_gap:
119 | continue
120 |
121 | scores = predict(model, cxt, [pos, neg], max_cxt_turn=max_cxt_turn)
122 | score_pos = scores[0]
123 | score_neg = scores[1]
124 | acc.append(float(score_pos > score_neg))
125 | n += 1
126 | if n % 10 == 0:
127 | print('evaluated %i, avg acc %.3f'%(n, np.mean(acc)))
128 | if n == max_n:
129 | break
130 |
131 | print('final acc is %.3f based on %i samples'%(np.mean(acc), n))
132 |
133 |
134 |
135 | def rank_hyps(path, model, max_n=-1, max_cxt_turn=None):
136 | """
137 |     rank the responses for each given cxt with the model.
138 |     path is the input file: in each line, the 0-th column is the context and the rest are responses.
139 |     outputs a jsonl file, which can be read with the function `read_ranked_jsonl`
140 | """
141 |
142 | print('ranking %s'%path)
143 | lines = []
144 | n = 0
145 | sum_avg_score = 0
146 | sum_top_score = 0
147 | for i, line in enumerate(open(path, encoding='utf-8')):
148 | cc = line.strip('\n').split('\t')
149 | if len(cc) < 2:
150 | print('[WARNING] line %i only has %i columns, ignored'%(i, len(cc)))
151 | continue
152 | cxt = cc[0]
153 | hyps = cc[1:]
154 | scores = predict(model, cxt, hyps, max_cxt_turn=max_cxt_turn)
155 | d = {'line_id':i, 'cxt': cxt}
156 | scored = []
157 | if isinstance(scores, dict):
158 | sum_avg_score += np.mean(scores['final'])
159 | sum_top_score += np.max(scores['final'])
160 | for j, hyp in enumerate(hyps):
161 | tup = (
162 | float(scores['final'][j]),
163 | dict([(k, float(scores[k][j])) for k in scores]),
164 | hyp,
165 | )
166 | scored.append(tup)
167 | else:
168 | sum_avg_score += np.mean(scores)
169 | sum_top_score += np.max(scores)
170 | for j, hyp in enumerate(hyps):
171 | scored.append((float(scores[j]), hyp))
172 | d['hyps'] = list(sorted(scored, reverse=True))
173 | lines.append(json.dumps(d))
174 | n += 1
175 |
176 | if n % 10 == 0:
177 |             print('processed %i lines, avg_hyp_score %.3f, top_hyp_score %.3f'%(
178 | n,
179 | sum_avg_score/n,
180 | sum_top_score/n,
181 | ))
182 | if n == max_n:
183 | break
184 |     print('processed %i lines in total, avg_hyp_score %.3f, top_hyp_score %.3f'%(
185 | n,
186 | sum_avg_score/n,
187 | sum_top_score/n,
188 | ))
189 | path_out = path+'.ranked.jsonl'
190 | with open(path_out, 'w') as f:
191 | f.write('\n'.join(lines))
192 | print('results saved to '+path_out)
193 |
194 |
195 | def read_ranked_jsonl(path):
196 |     """ read the jsonl file output by the function rank_hyps """
197 | data = [json.loads(line) for line in open(path, encoding="utf-8")]
198 | n_hyp = [len(d['hyps']) for d in data]
199 | best = defaultdict(list)
200 | avg = defaultdict(list)
201 | for d in data:
202 | scores = defaultdict(list)
203 | for tup in d['hyps']:
204 | scores['_score'].append(tup[0])
205 | if isinstance(tup[1], dict):
206 | for k in tup[1]:
207 | scores[k].append(tup[1][k])
208 | for k in scores:
209 | best[k].append(max(scores[k]))
210 | avg[k].append(np.mean(scores[k]))
211 |
212 | print()
213 | width = 20
214 | print('\t|'.join([' '*width, 'best', 'avg']))
215 | print('-'*40)
216 | for k in best:
217 | print('%s\t|%.3f\t|%.3f'%(
218 | ' '*(width - len(k)) + k,
219 | np.mean(best[k]),
220 | np.mean(avg[k]),
221 | ))
222 | print('-'*40)
223 | print('n_cxt: %i'%len(data))
224 | print('avg n_hyp per cxt: %.2f'%np.mean(n_hyp))
225 | return data
226 |
227 |
228 |
229 | def play(model, max_cxt_turn=None):
230 | from shared import EOS_token
231 | model.eval()
232 | print('enter empty to stop')
233 |     print('use `%s` to delimit turns in a multi-turn context'%EOS_token)
234 | while True:
235 | print()
236 | cxt = input('Context: ')
237 | if not cxt:
238 | break
239 | hyp = input('Response: ')
240 | if not hyp:
241 | break
242 | score = model.predict(cxt, [hyp], max_cxt_turn=max_cxt_turn)
243 | if isinstance(score, dict):
244 | ss = ['%s = %.3f'%(k, score[k][0]) for k in score]
245 | print(', '.join(ss))
246 | else:
247 | print('score = %.3f'%score[0])
248 |
249 |
250 | if __name__ == "__main__":
251 | import argparse
252 | parser = argparse.ArgumentParser()
253 | parser.add_argument('task', type=str)
254 | parser.add_argument('--data', type=str)
255 | parser.add_argument('--max_cxt_turn', type=int, default=2)
256 | parser.add_argument('--path_pth', '-p', type=str)
257 | parser.add_argument('--cpu', action='store_true')
258 | parser.add_argument('--max_n', type=int, default=5000)
259 | parser.add_argument('--min_score_gap', type=int)
260 | parser.add_argument('--min_rank_gap', type=float)
261 | args = parser.parse_args()
262 |
263 | cuda = False if args.cpu else torch.cuda.is_available()
264 | if args.task != 'stats':
265 | model = get_model(args.path_pth, cuda)
266 |
267 | if args.task in ['eval_human_vs_rand', 'eval_human_vs_machine']:
268 | fake = args.task.split('_')[-1]
269 | eval_fake(args.data, model, fake, max_n=args.max_n, max_cxt_turn=args.max_cxt_turn)
270 |
271 | elif args.task == 'eval_human_feedback':
272 | eval_feedback(args.data, model, max_cxt_turn=args.max_cxt_turn,
273 | min_rank_gap=args.min_rank_gap, max_n=args.max_n, min_score_gap=args.min_score_gap)
274 |
275 | elif args.task == 'test':
276 | rank_hyps(args.data, model, max_n=args.max_n, max_cxt_turn=args.max_cxt_turn)
277 |
278 | elif args.task == 'play':
279 | play(model, max_cxt_turn=args.max_cxt_turn)
280 |
281 | elif args.task == 'stats':
282 | read_ranked_jsonl(args.data)
283 |
284 | else:
285 | raise ValueError
286 |
--------------------------------------------------------------------------------
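
A worked toy example of the generalized hits@k accuracy computed in `eval_fake` above (k = 2 real and m = 3 fake responses; the scores are made up):

    import numpy as np

    # ranker scores for [real_1, real_2, fake_1, fake_2, fake_3]
    scores = np.array([0.9, 0.4, 0.7, 0.2, 0.1])
    k = 2                                               # number of real responses
    ix_score = sorted([(scores[i], i) for i in range(len(scores))], reverse=True)
    acc = np.mean([i < k for _, i in ix_score[:k]])     # real responses sit at indices < k
    print(acc)                                          # 0.5: only real_1 made the top-2
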
/src/shared.py:
--------------------------------------------------------------------------------
1 | # author: Xiang Gao at Microsoft Research AI NLP Group
2 |
3 | _cat_ = ' <-COL-> '
4 | #EOS_token = '_EOS_' # old version, before Nov 8 2020
5 | EOS_token = '<|endoftext|>'
6 |
7 |
8 | def download_model(path):
9 | if path is None:
10 | return
11 | import os, subprocess
12 | if os.path.exists(path):
13 | return
14 | links = dict()
15 | for k in ['updown', 'depth', 'width', 'human_vs_rand', 'human_vs_machine']:
16 | links['restore/%s.pth'%k] = 'https://xiagnlp2.blob.core.windows.net/dialogrpt/%s.pth'%k
17 | links['restore/medium_ft.pkl'] = 'https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl'
18 | if path not in links:
19 | return
20 | cmd = [ 'wget', links[path], '-P', 'restore']
21 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
22 | process.communicate()
23 |
--------------------------------------------------------------------------------
/src/transformers19/__init__.py:
--------------------------------------------------------------------------------
1 | # copied from: https://github.com/huggingface/transformers/commit/4d456542e9d381090f9a00b2bcc5a4cb07f6f3f7
2 |
3 | from .tokenization_gpt2 import GPT2Tokenizer
4 | from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
5 | from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
6 | GPT2LMHeadModel, GPT2DoubleHeadsModel,
7 | #load_tf_weights_in_gpt2,
8 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
--------------------------------------------------------------------------------
/src/transformers19/configuration_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ OpenAI GPT-2 configuration """
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 |
25 | from .configuration_utils import PretrainedConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
30 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
31 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
32 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
33 |
34 | class GPT2Config(PretrainedConfig):
35 | """Configuration class to store the configuration of a `GPT2Model`.
36 |
37 | Args:
38 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
39 | n_positions: Number of positional embeddings.
40 | n_ctx: Size of the causal mask (usually same as n_positions).
41 | n_embd: Dimensionality of the embeddings and hidden states.
42 | n_layer: Number of hidden layers in the Transformer encoder.
43 | n_head: Number of attention heads for each attention layer in
44 | the Transformer encoder.
45 | layer_norm_epsilon: epsilon to use in the layer norm layers
46 |         resid_pdrop: The dropout probability for all fully connected
47 | layers in the embeddings, encoder, and pooler.
48 | attn_pdrop: The dropout ratio for the attention
49 | probabilities.
50 | embd_pdrop: The dropout ratio for the embeddings.
51 |         initializer_range: The stddev of the truncated_normal_initializer for
52 | initializing all weight matrices.
53 | """
54 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
55 |
56 | def __init__(
57 | self,
58 | vocab_size_or_config_json_file=50257,
59 | n_positions=1024,
60 | n_ctx=1024,
61 | n_embd=768,
62 | n_layer=12,
63 | n_head=12,
64 | resid_pdrop=0.1,
65 | embd_pdrop=0.1,
66 | attn_pdrop=0.1,
67 | layer_norm_epsilon=1e-5,
68 | initializer_range=0.02,
69 |
70 | num_labels=1,
71 | summary_type='cls_index',
72 | summary_use_proj=True,
73 | summary_activation=None,
74 | summary_proj_to_labels=True,
75 | summary_first_dropout=0.1,
76 | **kwargs
77 | ):
78 | """Constructs GPT2Config.
79 |
80 | Args:
81 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
82 | n_positions: Number of positional embeddings.
83 | n_ctx: Size of the causal mask (usually same as n_positions).
84 | n_embd: Dimensionality of the embeddings and hidden states.
85 | n_layer: Number of hidden layers in the Transformer encoder.
86 | n_head: Number of attention heads for each attention layer in
87 | the Transformer encoder.
88 | layer_norm_epsilon: epsilon to use in the layer norm layers
89 |             resid_pdrop: The dropout probability for all fully connected
90 | layers in the embeddings, encoder, and pooler.
91 | attn_pdrop: The dropout ratio for the attention
92 | probabilities.
93 | embd_pdrop: The dropout ratio for the embeddings.
94 |             initializer_range: The stddev of the truncated_normal_initializer for
95 | initializing all weight matrices.
96 | """
97 | super(GPT2Config, self).__init__(**kwargs)
98 |
99 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
100 | and isinstance(vocab_size_or_config_json_file, unicode)):
101 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
102 | json_config = json.loads(reader.read())
103 | for key, value in json_config.items():
104 | self.__dict__[key] = value
105 | elif isinstance(vocab_size_or_config_json_file, int):
106 | self.vocab_size = vocab_size_or_config_json_file
107 | self.n_ctx = n_ctx
108 | self.n_positions = n_positions
109 | self.n_embd = n_embd
110 | self.n_layer = n_layer
111 | self.n_head = n_head
112 | self.resid_pdrop = resid_pdrop
113 | self.embd_pdrop = embd_pdrop
114 | self.attn_pdrop = attn_pdrop
115 | self.layer_norm_epsilon = layer_norm_epsilon
116 | self.initializer_range = initializer_range
117 |
118 | self.num_labels = num_labels
119 | self.summary_type = summary_type
120 | self.summary_use_proj = summary_use_proj
121 | self.summary_activation = summary_activation
122 | self.summary_first_dropout = summary_first_dropout
123 | self.summary_proj_to_labels = summary_proj_to_labels
124 | else:
125 | raise ValueError(
126 | "First argument must be either a vocabulary size (int)"
127 | "or the path to a pretrained model config file (str)"
128 | )
129 |
130 | @property
131 | def max_position_embeddings(self):
132 | return self.n_positions
133 |
134 | @property
135 | def hidden_size(self):
136 | return self.n_embd
137 |
138 | @property
139 | def num_attention_heads(self):
140 | return self.n_head
141 |
142 | @property
143 | def num_hidden_layers(self):
144 | return self.n_layer
145 |
--------------------------------------------------------------------------------
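
For reference, the ranking and generation models in this repo instantiate this config with the GPT-2 medium geometry (see `model.py` and `generation.py` above); a quick sanity check of the convenience properties:

    from transformers19 import GPT2Config

    config = GPT2Config(n_embd=1024, n_layer=24, n_head=16)   # geometry used by Scorer / GPT2Generator
    print(config.hidden_size, config.num_hidden_layers, config.num_attention_heads)   # 1024 24 16
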
/src/transformers19/configuration_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Configuration base class and utilities."""
17 |
18 | from __future__ import (absolute_import, division, print_function,
19 | unicode_literals)
20 |
21 | import copy
22 | import json
23 | import logging
24 | import os
25 | from io import open
26 |
27 | from .file_utils import cached_path, CONFIG_NAME
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 | class PretrainedConfig(object):
32 | r""" Base class for all configuration classes.
33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34 |
35 | Note:
36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37 | It only affects the model's configuration.
38 |
39 | Class attributes (overridden by derived classes):
40 |             - ``pretrained_config_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
41 |
42 | Parameters:
43 | ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
44 | ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens)
45 |         ``output_attentions``: boolean, default `False`. Should the model return attention weights.
46 |         ``output_hidden_states``: boolean, default `False`. Should the model return all hidden-states.
47 |         ``torchscript``: boolean, default `False`. Whether the model is used with TorchScript.
48 | """
49 | pretrained_config_archive_map = {}
50 |
51 | def __init__(self, **kwargs):
52 | self.finetuning_task = kwargs.pop('finetuning_task', None)
53 | self.num_labels = kwargs.pop('num_labels', 2)
54 | self.output_attentions = kwargs.pop('output_attentions', False)
55 | self.output_hidden_states = kwargs.pop('output_hidden_states', False)
56 | self.output_past = kwargs.pop('output_past', True) # Not used by all models
57 | self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
58 | self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
59 | self.pruned_heads = kwargs.pop('pruned_heads', {})
60 |
61 | def save_pretrained(self, save_directory):
62 | """ Save a configuration object to the directory `save_directory`, so that it
63 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
64 | """
65 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
66 |
67 | # If we save using the predefined names, we can load using `from_pretrained`
68 | output_config_file = os.path.join(save_directory, CONFIG_NAME)
69 |
70 | self.to_json_file(output_config_file)
71 | logger.info("Configuration saved in {}".format(output_config_file))
72 |
73 | @classmethod
74 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
75 | r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
76 |
77 | Parameters:
78 | pretrained_model_name_or_path: either:
79 |
80 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
81 | - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
82 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
83 |
84 | cache_dir: (`optional`) string:
85 | Path to a directory in which a downloaded pre-trained model
86 | configuration should be cached if the standard cache should not be used.
87 |
88 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading.
89 |
90 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
91 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
92 |
93 | force_download: (`optional`) boolean, default False:
94 |                 Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
95 |
96 | proxies: (`optional`) dict, default None:
97 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
98 | The proxies are used on each request.
99 |
100 | return_unused_kwargs: (`optional`) bool:
101 |
102 | - If False, then this function returns just the final configuration object.
103 | - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
104 |
105 | Examples::
106 |
107 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
108 | # derived class: BertConfig
109 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
110 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
111 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
112 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
113 | assert config.output_attention == True
114 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
115 | foo=False, return_unused_kwargs=True)
116 | assert config.output_attention == True
117 | assert unused_kwargs == {'foo': False}
118 |
119 | """
120 | cache_dir = kwargs.pop('cache_dir', None)
121 | force_download = kwargs.pop('force_download', False)
122 | proxies = kwargs.pop('proxies', None)
123 | return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
124 |
125 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
126 | config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
127 | elif os.path.isdir(pretrained_model_name_or_path):
128 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
129 | else:
130 | config_file = pretrained_model_name_or_path
131 | # redirect to the cache, if necessary
132 | try:
133 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
134 | except EnvironmentError:
135 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
136 | msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
137 | config_file)
138 | else:
139 | msg = "Model name '{}' was not found in model name list ({}). " \
140 | "We assumed '{}' was a path or url to a configuration file named {} or " \
141 | "a directory containing such a file but couldn't find any such file at this path or url.".format(
142 | pretrained_model_name_or_path,
143 | ', '.join(cls.pretrained_config_archive_map.keys()),
144 | config_file, CONFIG_NAME)
145 | raise EnvironmentError(msg)
146 |
147 | if resolved_config_file == config_file:
148 | logger.info("loading configuration file {}".format(config_file))
149 | else:
150 | logger.info("loading configuration file {} from cache at {}".format(
151 | config_file, resolved_config_file))
152 |
153 | # Load config
154 | config = cls.from_json_file(resolved_config_file)
155 |
156 | if hasattr(config, 'pruned_heads'):
157 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
158 |
159 | # Update config with kwargs if needed
160 | to_remove = []
161 | for key, value in kwargs.items():
162 | if hasattr(config, key):
163 | setattr(config, key, value)
164 | to_remove.append(key)
165 | for key in to_remove:
166 | kwargs.pop(key, None)
167 |
168 | logger.info("Model config %s", str(config))
169 | if return_unused_kwargs:
170 | return config, kwargs
171 | else:
172 | return config
173 |
174 | @classmethod
175 | def from_dict(cls, json_object):
176 | """Constructs a `Config` from a Python dictionary of parameters."""
177 | config = cls(vocab_size_or_config_json_file=-1)
178 | for key, value in json_object.items():
179 | setattr(config, key, value)
180 | return config
181 |
182 | @classmethod
183 | def from_json_file(cls, json_file):
184 | """Constructs a `BertConfig` from a json file of parameters."""
185 | with open(json_file, "r", encoding='utf-8') as reader:
186 | text = reader.read()
187 | return cls.from_dict(json.loads(text))
188 |
189 | def __eq__(self, other):
190 | return self.__dict__ == other.__dict__
191 |
192 | def __repr__(self):
193 | return str(self.to_json_string())
194 |
195 | def to_dict(self):
196 | """Serializes this instance to a Python dictionary."""
197 | output = copy.deepcopy(self.__dict__)
198 | return output
199 |
200 | def to_json_string(self):
201 | """Serializes this instance to a JSON string."""
202 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
203 |
204 | def to_json_file(self, json_file_path):
205 | """ Save this instance to a json file."""
206 | with open(json_file_path, "w", encoding='utf-8') as writer:
207 | writer.write(self.to_json_string())
208 |
--------------------------------------------------------------------------------
/src/transformers19/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import sys
9 | import json
10 | import logging
11 | import os
12 | import six
13 | import shutil
14 | import tempfile
15 | import fnmatch
16 | from functools import wraps
17 | from hashlib import sha256
18 | from io import open
19 |
20 | import boto3
21 | from botocore.config import Config
22 | from botocore.exceptions import ClientError
23 | import requests
24 | from tqdm import tqdm
25 |
26 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
27 | _tf_available = False # pylint: disable=invalid-name
28 |
29 | try:
30 | import torch
31 | _torch_available = True # pylint: disable=invalid-name
32 | logger.info("PyTorch version {} available.".format(torch.__version__))
33 | except ImportError:
34 | _torch_available = False # pylint: disable=invalid-name
35 |
36 |
37 | try:
38 | from torch.hub import _get_torch_home
39 | torch_cache_home = _get_torch_home()
40 | except ImportError:
41 | torch_cache_home = os.path.expanduser(
42 | os.getenv('TORCH_HOME', os.path.join(
43 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
44 | default_cache_path = os.path.join(torch_cache_home, 'transformers')
45 |
46 | try:
47 | from urllib.parse import urlparse
48 | except ImportError:
49 | from urlparse import urlparse
50 |
51 | try:
52 | from pathlib import Path
53 | PYTORCH_PRETRAINED_BERT_CACHE = Path(
54 | os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
55 | except (AttributeError, ImportError):
56 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
57 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
58 | default_cache_path))
59 |
60 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
61 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
62 |
63 | WEIGHTS_NAME = "pytorch_model.bin"
64 | TF2_WEIGHTS_NAME = 'tf_model.h5'
65 | TF_WEIGHTS_NAME = 'model.ckpt'
66 | CONFIG_NAME = "config.json"
67 |
68 | def is_torch_available():
69 | return _torch_available
70 |
71 | def is_tf_available():
72 | return _tf_available
73 |
74 | if not six.PY2:
75 | def add_start_docstrings(*docstr):
76 | def docstring_decorator(fn):
77 | fn.__doc__ = ''.join(docstr) + fn.__doc__
78 | return fn
79 | return docstring_decorator
80 |
81 | def add_end_docstrings(*docstr):
82 | def docstring_decorator(fn):
83 | fn.__doc__ = fn.__doc__ + ''.join(docstr)
84 | return fn
85 | return docstring_decorator
86 | else:
87 | # Not possible to update class docstrings on python2
88 | def add_start_docstrings(*docstr):
89 | def docstring_decorator(fn):
90 | return fn
91 | return docstring_decorator
92 |
93 | def add_end_docstrings(*docstr):
94 | def docstring_decorator(fn):
95 | return fn
96 | return docstring_decorator
97 |
98 | def url_to_filename(url, etag=None):
99 | """
100 | Convert `url` into a hashed filename in a repeatable way.
101 | If `etag` is specified, append its hash to the url's, delimited
102 | by a period.
103 |     If the url ends with .h5 (Keras HDF5 weights), adds '.h5' to the name
104 | so that TF 2.0 can identify it as a HDF5 file
105 | (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
106 | """
107 | url_bytes = url.encode('utf-8')
108 | url_hash = sha256(url_bytes)
109 | filename = url_hash.hexdigest()
110 |
111 | if etag:
112 | etag_bytes = etag.encode('utf-8')
113 | etag_hash = sha256(etag_bytes)
114 | filename += '.' + etag_hash.hexdigest()
115 |
116 | if url.endswith('.h5'):
117 | filename += '.h5'
118 |
119 | return filename
120 |
121 |
122 | def filename_to_url(filename, cache_dir=None):
123 | """
124 | Return the url and etag (which may be ``None``) stored for `filename`.
125 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
126 | """
127 | if cache_dir is None:
128 | cache_dir = TRANSFORMERS_CACHE
129 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
130 | cache_dir = str(cache_dir)
131 |
132 | cache_path = os.path.join(cache_dir, filename)
133 | if not os.path.exists(cache_path):
134 | raise EnvironmentError("file {} not found".format(cache_path))
135 |
136 | meta_path = cache_path + '.json'
137 | if not os.path.exists(meta_path):
138 | raise EnvironmentError("file {} not found".format(meta_path))
139 |
140 | with open(meta_path, encoding="utf-8") as meta_file:
141 | metadata = json.load(meta_file)
142 | url = metadata['url']
143 | etag = metadata['etag']
144 |
145 | return url, etag
146 |
147 |
148 | def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
149 | """
150 | Given something that might be a URL (or might be a local path),
151 | determine which. If it's a URL, download the file and cache it, and
152 | return the path to the cached file. If it's already a local path,
153 | make sure the file exists and then return the path.
154 | Args:
155 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
156 |         force_download: if True, re-download the file even if it's already cached in the cache dir.
157 | """
158 | if cache_dir is None:
159 | cache_dir = TRANSFORMERS_CACHE
160 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
161 | url_or_filename = str(url_or_filename)
162 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
163 | cache_dir = str(cache_dir)
164 |
165 | parsed = urlparse(url_or_filename)
166 |
167 | if parsed.scheme in ('http', 'https', 's3'):
168 | # URL, so get it from the cache (downloading if necessary)
169 | return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
170 | elif os.path.exists(url_or_filename):
171 | # File, and it exists.
172 | return url_or_filename
173 | elif parsed.scheme == '':
174 | # File, but it doesn't exist.
175 | raise EnvironmentError("file {} not found".format(url_or_filename))
176 | else:
177 | # Something unknown
178 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
179 |
180 |
181 | def split_s3_path(url):
182 | """Split a full s3 path into the bucket name and path."""
183 | parsed = urlparse(url)
184 | if not parsed.netloc or not parsed.path:
185 | raise ValueError("bad s3 path {}".format(url))
186 | bucket_name = parsed.netloc
187 | s3_path = parsed.path
188 | # Remove '/' at beginning of path.
189 | if s3_path.startswith("/"):
190 | s3_path = s3_path[1:]
191 | return bucket_name, s3_path
192 |
193 |
194 | def s3_request(func):
195 | """
196 | Wrapper function for s3 requests in order to create more helpful error
197 | messages.
198 | """
199 |
200 | @wraps(func)
201 | def wrapper(url, *args, **kwargs):
202 | try:
203 | return func(url, *args, **kwargs)
204 | except ClientError as exc:
205 | if int(exc.response["Error"]["Code"]) == 404:
206 | raise EnvironmentError("file {} not found".format(url))
207 | else:
208 | raise
209 |
210 | return wrapper
211 |
212 |
213 | @s3_request
214 | def s3_etag(url, proxies=None):
215 | """Check ETag on S3 object."""
216 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
217 | bucket_name, s3_path = split_s3_path(url)
218 | s3_object = s3_resource.Object(bucket_name, s3_path)
219 | return s3_object.e_tag
220 |
221 |
222 | @s3_request
223 | def s3_get(url, temp_file, proxies=None):
224 | """Pull a file directly from S3."""
225 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
226 | bucket_name, s3_path = split_s3_path(url)
227 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
228 |
229 |
230 | def http_get(url, temp_file, proxies=None):
231 | req = requests.get(url, stream=True, proxies=proxies)
232 | content_length = req.headers.get('Content-Length')
233 | total = int(content_length) if content_length is not None else None
234 | progress = tqdm(unit="B", total=total)
235 | for chunk in req.iter_content(chunk_size=1024):
236 | if chunk: # filter out keep-alive new chunks
237 | progress.update(len(chunk))
238 | temp_file.write(chunk)
239 | progress.close()
240 |
241 |
242 | def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10):
243 | """
244 | Given a URL, look for the corresponding dataset in the local cache.
245 | If it's not there, download it. Then return the path to the cached file.
246 | """
247 | if cache_dir is None:
248 | cache_dir = TRANSFORMERS_CACHE
249 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
250 | cache_dir = str(cache_dir)
251 | if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
252 | cache_dir = str(cache_dir)
253 |
254 | if not os.path.exists(cache_dir):
255 | os.makedirs(cache_dir, exist_ok=True)
256 |
257 | # Get eTag to add to filename, if it exists.
258 | if url.startswith("s3://"):
259 | etag = s3_etag(url, proxies=proxies)
260 | else:
261 | try:
262 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
263 | if response.status_code != 200:
264 | etag = None
265 | else:
266 | etag = response.headers.get("ETag")
267 | except (EnvironmentError, requests.exceptions.Timeout):
268 | etag = None
269 |
270 | if sys.version_info[0] == 2 and etag is not None:
271 | etag = etag.decode('utf-8')
272 | filename = url_to_filename(url, etag)
273 |
274 | # get cache path to put the file
275 | cache_path = os.path.join(cache_dir, filename)
276 |
277 | # If we don't have a connection (etag is None) and can't identify the file
278 | # try to get the last downloaded one
279 | if not os.path.exists(cache_path) and etag is None:
280 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
281 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
282 | if matching_files:
283 | cache_path = os.path.join(cache_dir, matching_files[-1])
284 |
285 | if not os.path.exists(cache_path) or force_download:
286 | # Download to temporary file, then copy to cache dir once finished.
287 | # Otherwise you get corrupt cache entries if the download gets interrupted.
288 | with tempfile.NamedTemporaryFile() as temp_file:
289 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
290 |
291 | # GET file object
292 | if url.startswith("s3://"):
293 | s3_get(url, temp_file, proxies=proxies)
294 | else:
295 | http_get(url, temp_file, proxies=proxies)
296 |
297 | # we are copying the file before closing it, so flush to avoid truncation
298 | temp_file.flush()
299 | # shutil.copyfileobj() starts at the current position, so go to the start
300 | temp_file.seek(0)
301 |
302 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
303 | with open(cache_path, 'wb') as cache_file:
304 | shutil.copyfileobj(temp_file, cache_file)
305 |
306 | logger.info("creating metadata file for %s", cache_path)
307 | meta = {'url': url, 'etag': etag}
308 | meta_path = cache_path + '.json'
309 | with open(meta_path, 'w') as meta_file:
310 | output_string = json.dumps(meta)
311 | if sys.version_info[0] == 2 and isinstance(output_string, str):
312 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2
313 | meta_file.write(output_string)
314 |
315 | logger.info("removing temp file %s", temp_file.name)
316 |
317 | return cache_path
318 |
--------------------------------------------------------------------------------
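A minimal sketch of how the caching helpers in `file_utils.py` above fit together, assuming `src/` is on the import path, network access, and the default `TRANSFORMERS_CACHE` directory; `cached_path` is the entry point that `from_pretrained` in `modeling_utils.py` calls to resolve a weights URL to a local file:

    import os
    from transformers19.file_utils import cached_path, filename_to_url

    url = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"

    # First call downloads to a temp file, then copies it into the cache under
    # sha256(url) + '.' + sha256(etag); later calls reuse the cached copy.
    local_path = cached_path(url)

    # get_from_cache() also writes a '<filename>.json' sidecar with the url/etag,
    # which filename_to_url() reads back.
    original_url, etag = filename_to_url(os.path.basename(local_path))
    assert original_url == url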
/src/transformers19/modeling_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch OpenAI GPT-2 model."""
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import collections
21 | import json
22 | import logging
23 | import math
24 | import os
25 | import sys
26 | from io import open
27 |
28 | import torch
29 | import torch.nn as nn
30 | from torch.nn import CrossEntropyLoss
31 | from torch.nn.parameter import Parameter
32 |
33 | from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
34 | from .configuration_gpt2 import GPT2Config
35 | from .file_utils import add_start_docstrings
36 |
37 | logger = logging.getLogger(__name__)
38 |
39 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
40 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
41 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
42 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",}
43 |
44 |
45 | def gelu(x):
46 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
47 |
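# The gelu() above is the tanh approximation of the Gaussian Error Linear Unit:
# gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))), which
# approximates x * Phi(x), with Phi the standard normal CDF; e.g. gelu(1.0) is roughly 0.841.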
48 |
49 | class Attention(nn.Module):
50 | def __init__(self, nx, n_ctx, config, scale=False):
51 | super(Attention, self).__init__()
52 | self.output_attentions = config.output_attentions
53 |
54 | n_state = nx # in Attention: n_state=768 (nx=n_embd)
55 | # [switch nx => n_state from Block to Attention to keep identical to TF implem]
56 | assert n_state % config.n_head == 0
57 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
58 | self.n_head = config.n_head
59 | self.split_size = n_state
60 | self.scale = scale
61 |
62 | self.c_attn = Conv1D(n_state * 3, nx)
63 | self.c_proj = Conv1D(n_state, nx)
64 | self.attn_dropout = nn.Dropout(config.attn_pdrop)
65 | self.resid_dropout = nn.Dropout(config.resid_pdrop)
66 | self.pruned_heads = set()
67 |
68 | def prune_heads(self, heads):
69 | if len(heads) == 0:
70 | return
71 | mask = torch.ones(self.n_head, self.split_size // self.n_head)
72 |         heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
73 | for head in heads:
74 | # Compute how many pruned heads are before the head and move the index accordingly
75 | head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
76 | mask[head] = 0
77 | mask = mask.view(-1).contiguous().eq(1)
78 | index = torch.arange(len(mask))[mask].long()
79 | index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
80 |
81 | # Prune conv1d layers
82 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
83 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
84 |
85 | # Update hyper params
86 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
87 | self.n_head = self.n_head - len(heads)
88 | self.pruned_heads = self.pruned_heads.union(heads)
89 |
90 | def _attn(self, q, k, v, attention_mask=None, head_mask=None):
91 | w = torch.matmul(q, k)
92 | if self.scale:
93 | w = w / math.sqrt(v.size(-1))
94 | nd, ns = w.size(-2), w.size(-1)
95 | b = self.bias[:, :, ns-nd:ns, :ns]
96 | w = w * b - 1e4 * (1 - b)
97 |
98 | if attention_mask is not None:
99 | # Apply the attention mask
100 | w = w + attention_mask
101 |
102 | w = nn.Softmax(dim=-1)(w)
103 | w = self.attn_dropout(w)
104 |
105 | # Mask heads if we want to
106 | if head_mask is not None:
107 | w = w * head_mask
108 |
109 | outputs = [torch.matmul(w, v)]
110 | if self.output_attentions:
111 | outputs.append(w)
112 | return outputs
113 |
114 | def merge_heads(self, x):
115 | x = x.permute(0, 2, 1, 3).contiguous()
116 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
117 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
118 |
119 | def split_heads(self, x, k=False):
120 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
121 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
122 | if k:
123 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
124 | else:
125 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
126 |
127 | def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
128 | x = self.c_attn(x)
129 | query, key, value = x.split(self.split_size, dim=2)
130 | query = self.split_heads(query)
131 | key = self.split_heads(key, k=True)
132 | value = self.split_heads(value)
133 | if layer_past is not None:
134 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
135 | key = torch.cat((past_key, key), dim=-1)
136 | value = torch.cat((past_value, value), dim=-2)
137 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
138 |
139 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
140 | a = attn_outputs[0]
141 |
142 | a = self.merge_heads(a)
143 | a = self.c_proj(a)
144 | a = self.resid_dropout(a)
145 |
146 | outputs = [a, present] + attn_outputs[1:]
147 | return outputs # a, present, (attentions)
148 |
149 |
150 | class MLP(nn.Module):
151 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
152 | super(MLP, self).__init__()
153 | nx = config.n_embd
154 | self.c_fc = Conv1D(n_state, nx)
155 | self.c_proj = Conv1D(nx, n_state)
156 | self.act = gelu
157 | self.dropout = nn.Dropout(config.resid_pdrop)
158 |
159 | def forward(self, x):
160 | h = self.act(self.c_fc(x))
161 | h2 = self.c_proj(h)
162 | return self.dropout(h2)
163 |
164 |
165 | class Block(nn.Module):
166 | def __init__(self, n_ctx, config, scale=False):
167 | super(Block, self).__init__()
168 | nx = config.n_embd
169 | self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
170 | self.attn = Attention(nx, n_ctx, config, scale)
171 | self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
172 | self.mlp = MLP(4 * nx, config)
173 |
174 | def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
175 | output_attn = self.attn(self.ln_1(x),
176 | layer_past=layer_past,
177 | attention_mask=attention_mask,
178 | head_mask=head_mask)
179 | a = output_attn[0] # output_attn: a, present, (attentions)
180 |
181 | x = x + a
182 | m = self.mlp(self.ln_2(x))
183 | x = x + m
184 |
185 | outputs = [x] + output_attn[1:]
186 | return outputs # x, present, (attentions)
187 |
188 |
189 | class GPT2PreTrainedModel(PreTrainedModel):
190 | """ An abstract class to handle weights initialization and
191 |         a simple interface for downloading and loading pretrained models.
192 | """
193 | config_class = GPT2Config
194 | pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
195 | #load_tf_weights = load_tf_weights_in_gpt2
196 | base_model_prefix = "transformer"
197 |
198 | def __init__(self, *inputs, **kwargs):
199 | super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
200 |
201 | def _init_weights(self, module):
202 | """ Initialize the weights.
203 | """
204 | if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
205 | # Slightly different from the TF version which uses truncated_normal for initialization
206 | # cf https://github.com/pytorch/pytorch/pull/5617
207 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
208 | if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
209 | module.bias.data.zero_()
210 | elif isinstance(module, nn.LayerNorm):
211 | module.bias.data.zero_()
212 | module.weight.data.fill_(1.0)
213 |
214 |
215 | GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
216 | `Language Models are Unsupervised Multitask Learners`_
217 | by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
218 | It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
219 | corpus of ~40 GB of text data.
220 |
221 | This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
222 | refer to the PyTorch documentation for all matter related to general usage and behavior.
223 |
224 | .. _`Language Models are Unsupervised Multitask Learners`:
225 | https://openai.com/blog/better-language-models/
226 |
227 | .. _`torch.nn.Module`:
228 | https://pytorch.org/docs/stable/nn.html#module
229 |
230 | Parameters:
231 | config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
232 | Initializing with a config file does not load the weights associated with the model, only the configuration.
233 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
234 | """
235 |
236 | GPT2_INPUTS_DOCSTRING = r""" Inputs:
237 | **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
238 | Indices of input sequence tokens in the vocabulary.
239 | GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
240 | the right rather than the left.
241 | Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
242 | See :func:`transformers.PreTrainedTokenizer.encode` and
243 | :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
244 | **past**:
245 | list of ``torch.FloatTensor`` (one for each layer):
246 | that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
247 | (see `past` output below). Can be used to speed up sequential decoding.
248 | **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
249 | Mask to avoid performing attention on padding token indices.
250 | Mask values selected in ``[0, 1]``:
251 | ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
252 | **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
253 | A parallel sequence of tokens (can be used to indicate various portions of the inputs).
254 | The embeddings from these tokens will be summed with the respective token embeddings.
255 | Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
256 | **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
257 | Indices of positions of each input sequence tokens in the position embeddings.
258 | Selected in the range ``[0, config.max_position_embeddings - 1]``.
259 | **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
260 | Mask to nullify selected heads of the self-attention modules.
261 | Mask values selected in ``[0, 1]``:
262 | ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
263 | """
264 |
265 | @add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
266 | GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
267 | class GPT2Model(GPT2PreTrainedModel):
268 | r"""
269 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
270 | **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
271 | Sequence of hidden-states at the last layer of the model.
272 | **past**:
273 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
274 | that contains pre-computed hidden-states (key and values in the attention blocks).
275 | Can be used (see `past` input) to speed up sequential decoding.
276 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
277 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
278 | of shape ``(batch_size, sequence_length, hidden_size)``:
279 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
280 | **attentions**: (`optional`, returned when ``config.output_attentions=True``)
281 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
282 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
283 |
284 | Examples::
285 |
286 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
287 | model = GPT2Model.from_pretrained('gpt2')
288 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
289 | outputs = model(input_ids)
290 | last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
291 |
292 | """
293 | def __init__(self, config):
294 | super(GPT2Model, self).__init__(config)
295 | self.output_hidden_states = config.output_hidden_states
296 | self.output_attentions = config.output_attentions
297 | self.output_past = config.output_past
298 |
299 | self.wte = nn.Embedding(config.vocab_size, config.n_embd)
300 | self.wpe = nn.Embedding(config.n_positions, config.n_embd)
301 | self.drop = nn.Dropout(config.embd_pdrop)
302 | self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
303 | self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
304 |
305 | self.init_weights()
306 |
307 | def _resize_token_embeddings(self, new_num_tokens):
308 | self.wte = self._get_resized_embeddings(self.wte, new_num_tokens)
309 | return self.wte
310 |
311 | def _prune_heads(self, heads_to_prune):
312 | """ Prunes heads of the model.
313 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
314 | """
315 | for layer, heads in heads_to_prune.items():
316 | self.h[layer].attn.prune_heads(heads)
317 |
318 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
319 | input_shape = input_ids.size()
320 | input_ids = input_ids.view(-1, input_shape[-1])
321 | if token_type_ids is not None:
322 | token_type_ids = token_type_ids.view(-1, input_shape[-1])
323 | if position_ids is not None:
324 | position_ids = position_ids.view(-1, input_shape[-1])
325 |
326 | if past is None:
327 | past_length = 0
328 | past = [None] * len(self.h)
329 | else:
330 | past_length = past[0][0].size(-2)
331 | if position_ids is None:
332 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
333 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
334 |
335 | # Attention mask.
336 | if attention_mask is not None:
337 | attention_mask = attention_mask.view(-1, input_shape[-1])
338 | # We create a 3D attention mask from a 2D tensor mask.
339 | # Sizes are [batch_size, 1, 1, to_seq_length]
340 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
341 | # this attention mask is more simple than the triangular masking of causal attention
342 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
343 | attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
344 |
345 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
346 | # masked positions, this operation will create a tensor which is 0.0 for
347 | # positions we want to attend and -10000.0 for masked positions.
348 | # Since we are adding it to the raw scores before the softmax, this is
349 | # effectively the same as removing these entirely.
350 | attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
351 | attention_mask = (1.0 - attention_mask) * -10000.0
352 |
353 | # Prepare head mask if needed
354 | # 1.0 in head_mask indicate we keep the head
355 | # attention_probs has shape bsz x n_heads x N x N
356 | # head_mask has shape n_layer x batch x n_heads x N x N
357 | if head_mask is not None:
358 | if head_mask.dim() == 1:
359 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
360 | head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
361 | elif head_mask.dim() == 2:
362 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
363 |                 head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
364 | else:
365 | head_mask = [None] * self.config.n_layer
366 |
367 | inputs_embeds = self.wte(input_ids)
368 | position_embeds = self.wpe(position_ids)
369 | if token_type_ids is not None:
370 | token_type_embeds = self.wte(token_type_ids)
371 | else:
372 | token_type_embeds = 0
373 | hidden_states = inputs_embeds + position_embeds + token_type_embeds
374 | hidden_states = self.drop(hidden_states)
375 |
376 | output_shape = input_shape + (hidden_states.size(-1),)
377 |
378 | presents = ()
379 | all_attentions = []
380 | all_hidden_states = ()
381 | for i, (block, layer_past) in enumerate(zip(self.h, past)):
382 | if self.output_hidden_states:
383 | all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
384 |
385 | outputs = block(hidden_states,
386 | layer_past=layer_past,
387 | attention_mask=attention_mask,
388 | head_mask=head_mask[i])
389 |
390 | hidden_states, present = outputs[:2]
391 | if self.output_past:
392 | presents = presents + (present,)
393 |
394 | if self.output_attentions:
395 | all_attentions.append(outputs[2])
396 |
397 | hidden_states = self.ln_f(hidden_states)
398 |
399 | hidden_states = hidden_states.view(*output_shape)
400 | # Add last hidden state
401 | if self.output_hidden_states:
402 | all_hidden_states = all_hidden_states + (hidden_states,)
403 |
404 | outputs = (hidden_states,)
405 | if self.output_past:
406 | outputs = outputs + (presents,)
407 | if self.output_hidden_states:
408 | outputs = outputs + (all_hidden_states,)
409 | if self.output_attentions:
410 | # let the number of heads free (-1) so we can extract attention even after head pruning
411 | attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
412 | all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
413 | outputs = outputs + (all_attentions,)
414 | return outputs # last hidden state, (presents), (all hidden_states), (attentions)
415 |
416 |
417 | @add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
418 | (linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
419 | class GPT2LMHeadModel(GPT2PreTrainedModel):
420 | r"""
421 | **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
422 | Labels for language modeling.
423 | Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
424 | Indices are selected in ``[-1, 0, ..., config.vocab_size]``
425 | All labels set to ``-1`` are ignored (masked), the loss is only
426 | computed for labels in ``[0, ..., config.vocab_size]``
427 |
428 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
429 | **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
430 | Language modeling loss.
431 | **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
432 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
433 | **past**:
434 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
435 | that contains pre-computed hidden-states (key and values in the attention blocks).
436 | Can be used (see `past` input) to speed up sequential decoding.
437 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
438 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
439 | of shape ``(batch_size, sequence_length, hidden_size)``:
440 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
441 | **attentions**: (`optional`, returned when ``config.output_attentions=True``)
442 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
443 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
444 |
445 | Examples::
446 |
447 | import torch
448 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
449 |
450 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
451 | model = GPT2LMHeadModel.from_pretrained('gpt2')
452 |
453 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
454 | outputs = model(input_ids, labels=input_ids)
455 | loss, logits = outputs[:2]
456 |
457 | """
458 | def __init__(self, config):
459 | super(GPT2LMHeadModel, self).__init__(config)
460 | self.transformer = GPT2Model(config)
461 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
462 |
463 | self.init_weights()
464 | self.tie_weights()
465 |
466 | def tie_weights(self):
467 | """ Make sure we are sharing the input and output embeddings.
468 | Export to TorchScript can't handle parameter sharing so we are cloning them instead.
469 | """
470 | self._tie_or_clone_weights(self.lm_head,
471 | self.transformer.wte)
472 |
473 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
474 | labels=None):
475 | transformer_outputs = self.transformer(input_ids,
476 | past=past,
477 | attention_mask=attention_mask,
478 | token_type_ids=token_type_ids,
479 | position_ids=position_ids,
480 | head_mask=head_mask)
481 | hidden_states = transformer_outputs[0]
482 |
483 | lm_logits = self.lm_head(hidden_states)
484 |
485 | outputs = (lm_logits,) + transformer_outputs[1:]
486 | if labels is not None:
487 | # Shift so that tokens < n predict n
488 | shift_logits = lm_logits[..., :-1, :].contiguous()
489 | shift_labels = labels[..., 1:].contiguous()
490 | # Flatten the tokens
491 | loss_fct = CrossEntropyLoss(ignore_index=-1)
492 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
493 | shift_labels.view(-1))
494 | outputs = (loss,) + outputs
495 |
496 | return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
497 |
498 |
499 | @add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
500 | head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
501 | The language modeling head has its weights tied to the input embeddings,
502 | the classification head takes as input the input of a specified classification token index in the input sequence).
503 | """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
504 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
505 | r"""
506 |         **mc_token_ids**: (`optional`, defaults to the index of the last token of the input) ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
507 | Index of the classification token in each input sequence.
508 | Selected in the range ``[0, input_ids.size(-1) - 1[``.
509 | **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
510 | Labels for language modeling.
511 | Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
512 | Indices are selected in ``[-1, 0, ..., config.vocab_size]``
513 | All labels set to ``-1`` are ignored (masked), the loss is only
514 | computed for labels in ``[0, ..., config.vocab_size]``
515 | **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
516 | Labels for computing the multiple choice classification loss.
517 | Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
518 | of the input tensors. (see `input_ids` above)
519 |
520 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
521 | **lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
522 | Language modeling loss.
523 | **mc_loss**: (`optional`, returned when ``multiple_choice_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
524 | Multiple choice classification loss.
525 | **lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
526 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
527 | **mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
528 |             Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
529 | **past**:
530 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
531 | that contains pre-computed hidden-states (key and values in the attention blocks).
532 | Can be used (see `past` input) to speed up sequential decoding.
533 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
534 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
535 | of shape ``(batch_size, sequence_length, hidden_size)``:
536 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
537 | **attentions**: (`optional`, returned when ``config.output_attentions=True``)
538 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
539 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
540 |
541 | Examples::
542 |
543 | import torch
544 | from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
545 |
546 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
547 | model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
548 |
549 | # Add a [CLS] to the vocabulary (we should train it also!)
550 | tokenizer.add_special_tokens({'cls_token': '[CLS]'})
551 | model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
552 |         print(tokenizer.cls_token_id, len(tokenizer)) # The newly added token is the last token of the vocabulary
553 |
554 | choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
555 | encoded_choices = [tokenizer.encode(s) for s in choices]
556 | cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
557 |
558 | input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
559 | mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
560 |
561 | outputs = model(input_ids, mc_token_ids=mc_token_ids)
562 | lm_prediction_scores, mc_prediction_scores = outputs[:2]
563 |
564 | """
565 | def __init__(self, config):
566 | super(GPT2DoubleHeadsModel, self).__init__(config)
567 | self.transformer = GPT2Model(config)
568 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
569 | self.multiple_choice_head = SequenceSummary(config)
570 |
571 | self.init_weights()
572 | self.tie_weights()
573 |
574 | def tie_weights(self):
575 | """ Make sure we are sharing the input and output embeddings.
576 | Export to TorchScript can't handle parameter sharing so we are cloning them instead.
577 | """
578 | self._tie_or_clone_weights(self.lm_head,
579 | self.transformer.wte)
580 |
581 | def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
582 | mc_token_ids=None, lm_labels=None, mc_labels=None):
583 | transformer_outputs = self.transformer(input_ids,
584 | past=past,
585 | attention_mask=attention_mask,
586 | token_type_ids=token_type_ids,
587 | position_ids=position_ids,
588 | head_mask=head_mask)
589 |
590 | hidden_states = transformer_outputs[0]
591 |
592 | lm_logits = self.lm_head(hidden_states)
593 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
594 |
595 | outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
596 | if mc_labels is not None:
597 | loss_fct = CrossEntropyLoss()
598 | loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
599 | mc_labels.view(-1))
600 | outputs = (loss,) + outputs
601 | if lm_labels is not None:
602 | shift_logits = lm_logits[..., :-1, :].contiguous()
603 | shift_labels = lm_labels[..., 1:].contiguous()
604 | loss_fct = CrossEntropyLoss(ignore_index=-1)
605 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
606 | shift_labels.view(-1))
607 | outputs = (loss,) + outputs
608 |
609 | return outputs # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
610 |
--------------------------------------------------------------------------------
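A minimal usage sketch for the GPT-2 classes defined in `modeling_gpt2.py` above, assuming `src/` is on the import path and that `GPT2Tokenizer` and `GPT2LMHeadModel` are re-exported by `transformers19/__init__.py` (an assumption here); it simply scores a sentence with the language-modeling loss to show the forward API:

    import torch
    from transformers19 import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')  # weights resolved via cached_path() on first use
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("Hello, how are you? I am fine, thank you.")])
    with torch.no_grad():
        loss, logits = model(input_ids, labels=input_ids)[:2]  # labels are shifted inside the model
    print(loss.item())  # average per-token cross-entropy under GPT-2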
/src/transformers19/modeling_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch BERT model."""
17 |
18 | from __future__ import (absolute_import, division, print_function,
19 | unicode_literals)
20 |
21 | import copy
22 | import json
23 | import logging
24 | import os
25 | from io import open
26 |
27 | import six
28 | import torch
29 | from torch import nn
30 | from torch.nn import CrossEntropyLoss
31 | from torch.nn import functional as F
32 |
33 | from .configuration_utils import PretrainedConfig
34 | from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | try:
40 | from torch.nn import Identity
41 | except ImportError:
42 | # Older PyTorch compatibility
43 | class Identity(nn.Module):
44 | r"""A placeholder identity operator that is argument-insensitive.
45 | """
46 | def __init__(self, *args, **kwargs):
47 | super(Identity, self).__init__()
48 |
49 | def forward(self, input):
50 | return input
51 |
52 | class PreTrainedModel(nn.Module):
53 | r""" Base class for all models.
54 |
55 | :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
56 |         as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention modules.
57 |
58 | Class attributes (overridden by derived classes):
59 | - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
60 |             - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
61 | - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
62 |
63 | - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
64 | - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
65 | - ``path``: a path (string) to the TensorFlow checkpoint.
66 |
67 | - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
68 | """
69 | config_class = None
70 | pretrained_model_archive_map = {}
71 | load_tf_weights = lambda model, config, path: None
72 | base_model_prefix = ""
73 |
74 | def __init__(self, config, *inputs, **kwargs):
75 | super(PreTrainedModel, self).__init__()
76 | if not isinstance(config, PretrainedConfig):
77 | raise ValueError(
78 | "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
79 | "To create a model from a pretrained model use "
80 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
81 | self.__class__.__name__, self.__class__.__name__
82 | ))
83 | # Save config in model
84 | self.config = config
85 |
86 | def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
87 | """ Build a resized Embedding Module from a provided token Embedding Module.
88 | Increasing the size will add newly initialized vectors at the end
89 | Reducing the size will remove vectors from the end
90 |
91 | Args:
92 | new_num_tokens: (`optional`) int
93 | New number of tokens in the embedding matrix.
94 | Increasing the size will add newly initialized vectors at the end
95 | Reducing the size will remove vectors from the end
96 | If not provided or None: return the provided token Embedding Module.
97 | Return: ``torch.nn.Embeddings``
98 | Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
99 | """
100 | if new_num_tokens is None:
101 | return old_embeddings
102 |
103 | old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
104 | if old_num_tokens == new_num_tokens:
105 | return old_embeddings
106 |
107 | # Build new embeddings
108 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
109 | new_embeddings.to(old_embeddings.weight.device)
110 |
111 | # initialize all new embeddings (in particular added tokens)
112 | self._init_weights(new_embeddings)
113 |
114 | # Copy word embeddings from the previous weights
115 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
116 | new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
117 |
118 | return new_embeddings
119 |
120 | def _tie_or_clone_weights(self, first_module, second_module):
121 |         """ Tie or clone module weights depending on whether we are using TorchScript or not
122 | """
123 | if self.config.torchscript:
124 | first_module.weight = nn.Parameter(second_module.weight.clone())
125 | else:
126 | first_module.weight = second_module.weight
127 |
128 | if hasattr(first_module, 'bias') and first_module.bias is not None:
129 | first_module.bias.data = torch.nn.functional.pad(
130 | first_module.bias.data,
131 | (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
132 | 'constant',
133 | 0
134 | )
135 |
136 | def resize_token_embeddings(self, new_num_tokens=None):
137 | """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
138 | Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
139 |
140 | Arguments:
141 |
142 | new_num_tokens: (`optional`) int:
143 | New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
144 | If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
145 |
146 | Return: ``torch.nn.Embeddings``
147 | Pointer to the input tokens Embeddings Module of the model
148 | """
149 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
150 | model_embeds = base_model._resize_token_embeddings(new_num_tokens)
151 | if new_num_tokens is None:
152 | return model_embeds
153 |
154 | # Update base model and current model config
155 | self.config.vocab_size = new_num_tokens
156 | base_model.vocab_size = new_num_tokens
157 |
158 | # Tie weights again if needed
159 | if hasattr(self, 'tie_weights'):
160 | self.tie_weights()
161 |
162 | return model_embeds
163 |
164 | def init_weights(self):
165 | """ Initialize and prunes weights if needed. """
166 | # Initialize weights
167 | self.apply(self._init_weights)
168 |
169 | # Prune heads if needed
170 | if self.config.pruned_heads:
171 | self.prune_heads(self.config.pruned_heads)
172 |
173 | def prune_heads(self, heads_to_prune):
174 | """ Prunes heads of the base model.
175 |
176 | Arguments:
177 |
178 | heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
179 | E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
180 | """
181 | base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
182 |
183 | # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
184 | for layer, heads in heads_to_prune.items():
185 | union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
186 | self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
187 |
188 | base_model._prune_heads(heads_to_prune)
189 |
190 | def save_pretrained(self, save_directory):
191 | """ Save a model and its configuration file to a directory, so that it
192 |         can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
193 | """
194 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
195 |
196 |         # Only save the model itself if we are using distributed training
197 | model_to_save = self.module if hasattr(self, 'module') else self
198 |
199 | # Save configuration file
200 | model_to_save.config.save_pretrained(save_directory)
201 |
202 | # If we save using the predefined names, we can load using `from_pretrained`
203 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
204 | torch.save(model_to_save.state_dict(), output_model_file)
205 | logger.info("Model weights saved in {}".format(output_model_file))
206 |
207 | @classmethod
208 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
209 | r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
210 |
211 | The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
212 | To train the model, you should first set it back in training mode with ``model.train()``
213 |
214 | The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
215 | It is up to you to train those weights with a downstream fine-tuning task.
216 |
217 | The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
218 |
219 | Parameters:
220 | pretrained_model_name_or_path: either:
221 |
222 | - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
223 | - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
224 | - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
225 | - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
226 |
227 | model_args: (`optional`) Sequence of positional arguments:
228 |                 All remaining positional arguments will be passed to the underlying model's ``__init__`` method
229 |
230 | config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
231 |                 Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
232 |
233 | - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
234 |                     - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
235 |                     - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
236 |
237 | state_dict: (`optional`) dict:
238 |                 an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
239 | This option can be used if you want to create a model from a pretrained configuration but load your own weights.
240 | In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
241 |
242 | cache_dir: (`optional`) string:
243 | Path to a directory in which a downloaded pre-trained model
244 | configuration should be cached if the standard cache should not be used.
245 |
246 | force_download: (`optional`) boolean, default False:
247 | Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
248 |
249 | proxies: (`optional`) dict, default None:
250 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
251 | The proxies are used on each request.
252 |
253 | output_loading_info: (`optional`) boolean:
254 |                 Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
255 |
256 | kwargs: (`optional`) Remaining dictionary of keyword arguments:
257 |                 Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:
258 |
259 | - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
260 | - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
261 |
262 | Examples::
263 |
264 | model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
265 | model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
266 | model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
267 | assert model.config.output_attention == True
268 | # Loading from a TF checkpoint file instead of a PyTorch model (slower)
269 | config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
270 | model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
271 |
272 | """
273 | config = kwargs.pop('config', None)
274 | state_dict = kwargs.pop('state_dict', None)
275 | cache_dir = kwargs.pop('cache_dir', None)
276 | from_tf = kwargs.pop('from_tf', False)
277 | force_download = kwargs.pop('force_download', False)
278 | proxies = kwargs.pop('proxies', None)
279 | output_loading_info = kwargs.pop('output_loading_info', False)
280 |
281 | # Load config
282 | if config is None:
283 | config, model_kwargs = cls.config_class.from_pretrained(
284 | pretrained_model_name_or_path, *model_args,
285 | cache_dir=cache_dir, return_unused_kwargs=True,
286 | force_download=force_download,
287 | **kwargs
288 | )
289 | else:
290 | model_kwargs = kwargs
291 |
292 | # Load model
293 | if pretrained_model_name_or_path is not None:
294 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
295 | archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
296 | elif os.path.isdir(pretrained_model_name_or_path):
297 | if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
298 | # Load from a TF 1.0 checkpoint
299 | archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
300 | elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
301 | # Load from a TF 2.0 checkpoint
302 | archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
303 | elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
304 | # Load from a PyTorch checkpoint
305 | archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
306 | else:
307 | raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
308 | [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
309 | pretrained_model_name_or_path))
310 | elif os.path.isfile(pretrained_model_name_or_path):
311 | archive_file = pretrained_model_name_or_path
312 | else:
313 | assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
314 | archive_file = pretrained_model_name_or_path + ".index"
315 |
316 | # redirect to the cache, if necessary
317 | try:
318 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
319 | except EnvironmentError:
320 | if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
321 | msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
322 | archive_file)
323 | else:
324 | msg = "Model name '{}' was not found in model name list ({}). " \
325 | "We assumed '{}' was a path or url to model weight files named one of {} but " \
326 | "couldn't find any such file at this path or url.".format(
327 | pretrained_model_name_or_path,
328 | ', '.join(cls.pretrained_model_archive_map.keys()),
329 | archive_file,
330 | [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
331 | raise EnvironmentError(msg)
332 |
333 | if resolved_archive_file == archive_file:
334 | logger.info("loading weights file {}".format(archive_file))
335 | else:
336 | logger.info("loading weights file {} from cache at {}".format(
337 | archive_file, resolved_archive_file))
338 | else:
339 | resolved_archive_file = None
340 |
341 | # Instantiate model.
342 | model = cls(config, *model_args, **model_kwargs)
343 |
344 | if state_dict is None and not from_tf:
345 | state_dict = torch.load(resolved_archive_file, map_location='cpu')
346 |
347 | missing_keys = []
348 | unexpected_keys = []
349 | error_msgs = []
350 |
351 | if from_tf:
352 | if resolved_archive_file.endswith('.index'):
353 | # Load from a TensorFlow 1.X checkpoint - provided by original authors
354 | model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
355 | else:
356 | # Load from our TensorFlow 2.0 checkpoints
357 | try:
358 | from transformers import load_tf2_checkpoint_in_pytorch_model
359 | model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
360 | except ImportError as e:
361 |                     logger.error("Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
362 | "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
363 | raise e
364 | else:
365 | # Convert old format to new format if needed from a PyTorch state_dict
366 | old_keys = []
367 | new_keys = []
368 | for key in state_dict.keys():
369 | new_key = None
370 | if 'gamma' in key:
371 | new_key = key.replace('gamma', 'weight')
372 | if 'beta' in key:
373 | new_key = key.replace('beta', 'bias')
374 | if new_key:
375 | old_keys.append(key)
376 | new_keys.append(new_key)
377 | for old_key, new_key in zip(old_keys, new_keys):
378 | state_dict[new_key] = state_dict.pop(old_key)
379 |
380 | # copy state_dict so _load_from_state_dict can modify it
381 | metadata = getattr(state_dict, '_metadata', None)
382 | state_dict = state_dict.copy()
383 | if metadata is not None:
384 | state_dict._metadata = metadata
385 |
386 | def load(module, prefix=''):
387 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
388 | module._load_from_state_dict(
389 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
390 | for name, child in module._modules.items():
391 | if child is not None:
392 | load(child, prefix + name + '.')
393 |
394 | # Make sure we are able to load base models as well as derived models (with heads)
395 | start_prefix = ''
396 | model_to_load = model
397 | if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
398 | start_prefix = cls.base_model_prefix + '.'
399 | if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
400 | model_to_load = getattr(model, cls.base_model_prefix)
401 |
402 | load(model_to_load, prefix=start_prefix)
403 | if len(missing_keys) > 0:
404 | logger.info("Weights of {} not initialized from pretrained model: {}".format(
405 | model.__class__.__name__, missing_keys))
406 | if len(unexpected_keys) > 0:
407 | logger.info("Weights from pretrained model not used in {}: {}".format(
408 | model.__class__.__name__, unexpected_keys))
409 | if len(error_msgs) > 0:
410 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
411 | model.__class__.__name__, "\n\t".join(error_msgs)))
412 |
413 | if hasattr(model, 'tie_weights'):
414 | model.tie_weights() # make sure word embedding weights are still tied
415 |
416 |         # Set model in evaluation mode to deactivate DropOut modules by default
417 | model.eval()
418 |
419 | if output_loading_info:
420 | loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
421 | return model, loading_info
422 |
423 | return model
424 |
425 |
426 | class Conv1D(nn.Module):
427 | def __init__(self, nf, nx):
428 | """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
429 | Basically works like a Linear layer but the weights are transposed
430 | """
431 | super(Conv1D, self).__init__()
432 | self.nf = nf
433 | w = torch.empty(nx, nf)
434 | nn.init.normal_(w, std=0.02)
435 | self.weight = nn.Parameter(w)
436 | self.bias = nn.Parameter(torch.zeros(nf))
437 |
438 | def forward(self, x):
439 | size_out = x.size()[:-1] + (self.nf,)
440 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
441 | x = x.view(*size_out)
442 | return x
443 |
444 |
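# Note on the Conv1D layer above (illustrative): it behaves like a transposed nn.Linear.
# Its weight has shape (nx, nf) and forward() computes x @ weight + bias, so a Conv1D(nf, nx)
# whose weight equals linear.weight.t() gives the same output as nn.Linear(nx, nf), e.g.:
#     lin = nn.Linear(3, 5)
#     conv = Conv1D(5, 3)
#     conv.weight.data.copy_(lin.weight.data.t()); conv.bias.data.copy_(lin.bias.data)
#     x = torch.randn(2, 3)
#     assert torch.allclose(conv(x), lin(x), atol=1e-6)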
445 | class PoolerStartLogits(nn.Module):
446 | """ Compute SQuAD start_logits from sequence hidden states. """
447 | def __init__(self, config):
448 | super(PoolerStartLogits, self).__init__()
449 | self.dense = nn.Linear(config.hidden_size, 1)
450 |
451 | def forward(self, hidden_states, p_mask=None):
452 | """ Args:
453 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
454 | Mask of invalid positions such as query and special symbols (PAD, SEP, CLS);
455 | 1.0 means the token should be masked.
456 | """
457 | x = self.dense(hidden_states).squeeze(-1)
458 |
459 | if p_mask is not None:
460 | if next(self.parameters()).dtype == torch.float16:
461 | x = x * (1 - p_mask) - 65500 * p_mask
462 | else:
463 | x = x * (1 - p_mask) - 1e30 * p_mask
464 |
465 | return x
466 |
467 |
468 | class PoolerEndLogits(nn.Module):
469 | """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
470 | """
471 | def __init__(self, config):
472 | super(PoolerEndLogits, self).__init__()
473 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
474 | self.activation = nn.Tanh()
475 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476 | self.dense_1 = nn.Linear(config.hidden_size, 1)
477 |
478 | def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
479 | """ Args:
480 | One of ``start_states``, ``start_positions`` should be not None.
481 | If both are set, ``start_positions`` overrides ``start_states``.
482 |
483 | **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``.
484 | hidden states of the first tokens for the labeled span.
485 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
486 | position of the first token for the labeled span.
487 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
488 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
489 | 1.0 means token should be masked.
490 | """
491 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
492 | if start_positions is not None:
493 | slen, hsz = hidden_states.shape[-2:]
494 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
495 | start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
496 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
497 |
498 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
499 | x = self.activation(x)
500 | x = self.LayerNorm(x)
501 | x = self.dense_1(x).squeeze(-1)
502 |
503 | if p_mask is not None:
504 | if next(self.parameters()).dtype == torch.float16:
505 | x = x * (1 - p_mask) - 65500 * p_mask
506 | else:
507 | x = x * (1 - p_mask) - 1e30 * p_mask
508 |
509 | return x
510 |
511 |
512 | class PoolerAnswerClass(nn.Module):
513 | """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
514 | def __init__(self, config):
515 | super(PoolerAnswerClass, self).__init__()
516 | self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
517 | self.activation = nn.Tanh()
518 | self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
519 |
520 | def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
521 | """
522 | Args:
523 | One of ``start_states``, ``start_positions`` should be not None.
524 | If both are set, ``start_positions`` overrides ``start_states``.
525 |
526 | **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``.
527 | hidden states of the first tokens for the labeled span.
528 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
529 | position of the first token for the labeled span.
530 | **cls_index**: torch.LongTensor of shape ``(batch_size,)``
531 | position of the CLS token. If None, take the last token.
532 |
533 | note(Original repo):
534 | no dependency on end_feature so that we can obtain one single `cls_logits`
535 | for each sample
536 | """
537 | hsz = hidden_states.shape[-1]
538 | assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
539 | if start_positions is not None:
540 | start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
541 | start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
542 |
543 | if cls_index is not None:
544 | cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
545 | cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
546 | else:
547 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
548 |
549 | x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
550 | x = self.activation(x)
551 | x = self.dense_1(x).squeeze(-1)
552 |
553 | return x
554 |
555 |
556 | class SQuADHead(nn.Module):
557 | r""" A SQuAD head inspired by XLNet.
558 |
559 | Parameters:
560 | config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
561 |
562 | Inputs:
563 | **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
564 | hidden states of sequence tokens
565 | **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
566 | position of the first token for the labeled span.
567 | **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
568 | position of the last token for the labeled span.
569 | **cls_index**: torch.LongTensor of shape ``(batch_size,)``
570 | position of the CLS token. If None, take the last token.
571 | **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
572 | Whether the question has a possible answer in the paragraph or not.
573 | **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
574 | Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
575 | 1.0 means token should be masked.
576 |
577 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
578 | **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
579 | Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
580 | **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
581 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
582 | Log probabilities for the top config.start_n_top start token possibilities (beam-search).
583 | **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
584 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
585 | Indices for the top config.start_n_top start token possibilities (beam-search).
586 | **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
587 | ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
588 | Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
589 | **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
590 | ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
591 | Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
592 | **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
593 | ``torch.FloatTensor`` of shape ``(batch_size,)``
594 | Log probabilities for the ``is_impossible`` label of the answers.
595 | """
596 | def __init__(self, config):
597 | super(SQuADHead, self).__init__()
598 | self.start_n_top = config.start_n_top
599 | self.end_n_top = config.end_n_top
600 |
601 | self.start_logits = PoolerStartLogits(config)
602 | self.end_logits = PoolerEndLogits(config)
603 | self.answer_class = PoolerAnswerClass(config)
604 |
605 | def forward(self, hidden_states, start_positions=None, end_positions=None,
606 | cls_index=None, is_impossible=None, p_mask=None):
607 | outputs = ()
608 |
609 | start_logits = self.start_logits(hidden_states, p_mask=p_mask)
610 |
611 | if start_positions is not None and end_positions is not None:
612 | # If we are on multi-GPU, let's remove the dimension added by batch splitting
613 | for x in (start_positions, end_positions, cls_index, is_impossible):
614 | if x is not None and x.dim() > 1:
615 | x.squeeze_(-1)
616 |
617 | # during training, compute the end logits based on the ground truth of the start position
618 | end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
619 |
620 | loss_fct = CrossEntropyLoss()
621 | start_loss = loss_fct(start_logits, start_positions)
622 | end_loss = loss_fct(end_logits, end_positions)
623 | total_loss = (start_loss + end_loss) / 2
624 |
625 | if cls_index is not None and is_impossible is not None:
626 | # Predict answerability from the representation of CLS and START
627 | cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
628 | loss_fct_cls = nn.BCEWithLogitsLoss()
629 | cls_loss = loss_fct_cls(cls_logits, is_impossible)
630 |
631 | # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
632 | total_loss += cls_loss * 0.5
633 |
634 | outputs = (total_loss,) + outputs
635 |
636 | else:
637 | # during inference, compute the end logits based on beam search
638 | bsz, slen, hsz = hidden_states.size()
639 | start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
640 |
641 | start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
642 | start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
643 | start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
644 | start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
645 |
646 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
647 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
648 | end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
649 | end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
650 |
651 | end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
652 | end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
653 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
654 |
655 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
656 | cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
657 |
658 | outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
659 |
660 | # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
661 | # or (if labels are provided) (total_loss,)
662 | return outputs
663 |
664 |
665 | class SequenceSummary(nn.Module):
666 | r""" Compute a single vector summary of a sequence's hidden states according to various possibilities:
667 | Args of the config class:
668 | summary_type:
669 | - 'last' => [default] take the last token hidden state (like XLNet)
670 | - 'first' => take the first token hidden state (like Bert)
671 | - 'mean' => take the mean of all tokens hidden states
672 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
673 | - 'attn' => Not implemented now, use multi-head attention
674 | summary_use_proj: Add a projection after the vector extraction
675 | summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
676 | summary_activation: 'tanh' => add a tanh activation to the output; anything else => no activation (default).
677 | summary_first_dropout: Add a dropout before the projection and activation
678 | summary_last_dropout: Add a dropout after the projection and activation
679 | """
680 | def __init__(self, config):
681 | super(SequenceSummary, self).__init__()
682 |
683 | self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
684 | if self.summary_type == 'attn':
685 | # We should use a standard multi-head attention module with absolute positional embedding for that.
686 | # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
687 | # We can probably just use the multi-head attention module of PyTorch >=1.1.0
688 | raise NotImplementedError
689 |
690 | self.summary = Identity()
691 | if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
692 | if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
693 | num_classes = config.num_labels
694 | else:
695 | num_classes = config.hidden_size
696 | self.summary = nn.Linear(config.hidden_size, num_classes)
697 |
698 | self.activation = Identity()
699 | if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
700 | self.activation = nn.Tanh()
701 |
702 | self.first_dropout = Identity()
703 | if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
704 | self.first_dropout = nn.Dropout(config.summary_first_dropout)
705 |
706 | self.last_dropout = Identity()
707 | if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
708 | self.last_dropout = nn.Dropout(config.summary_last_dropout)
709 |
710 | def forward(self, hidden_states, cls_index=None):
711 | """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
712 | cls_index: [optional] position of the classification token if summary_type == 'cls_index',
713 | shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
714 | if summary_type == 'cls_index' and cls_index is None:
715 | we take the last token of the sequence as classification token
716 | """
717 | if self.summary_type == 'last':
718 | output = hidden_states[:, -1]
719 | elif self.summary_type == 'first':
720 | output = hidden_states[:, 0]
721 | elif self.summary_type == 'mean':
722 | output = hidden_states.mean(dim=1)
723 | elif self.summary_type == 'cls_index':
724 | if cls_index is None:
725 | cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
726 | else:
727 | cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
728 | cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
729 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
730 | output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
731 | elif self.summary_type == 'attn':
732 | raise NotImplementedError
733 |
734 | output = self.first_dropout(output)
735 | output = self.summary(output)
736 | output = self.activation(output)
737 | output = self.last_dropout(output)
738 |
739 | return output
740 |
741 |
742 | def prune_linear_layer(layer, index, dim=0):
743 | """ Prune a linear layer (a model parameter) to keep only entries in index.
744 | Return the pruned layer as a new layer with requires_grad=True.
745 | Used to remove heads.
746 | """
747 | index = index.to(layer.weight.device)
748 | W = layer.weight.index_select(dim, index).clone().detach()
749 | if layer.bias is not None:
750 | if dim == 1:
751 | b = layer.bias.clone().detach()
752 | else:
753 | b = layer.bias[index].clone().detach()
754 | new_size = list(layer.weight.size())
755 | new_size[dim] = len(index)
756 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
757 | new_layer.weight.requires_grad = False
758 | new_layer.weight.copy_(W.contiguous())
759 | new_layer.weight.requires_grad = True
760 | if layer.bias is not None:
761 | new_layer.bias.requires_grad = False
762 | new_layer.bias.copy_(b.contiguous())
763 | new_layer.bias.requires_grad = True
764 | return new_layer
765 |
766 |
767 | def prune_conv1d_layer(layer, index, dim=1):
768 | """ Prune a Conv1D layer (a model parameter) to keep only entries in index.
769 | A Conv1D works like a Linear layer (as used e.g. in BERT) but the weights are transposed.
770 | Return the pruned layer as a new layer with requires_grad=True.
771 | Used to remove heads.
772 | """
773 | index = index.to(layer.weight.device)
774 | W = layer.weight.index_select(dim, index).clone().detach()
775 | if dim == 0:
776 | b = layer.bias.clone().detach()
777 | else:
778 | b = layer.bias[index].clone().detach()
779 | new_size = list(layer.weight.size())
780 | new_size[dim] = len(index)
781 | new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
782 | new_layer.weight.requires_grad = False
783 | new_layer.weight.copy_(W.contiguous())
784 | new_layer.weight.requires_grad = True
785 | new_layer.bias.requires_grad = False
786 | new_layer.bias.copy_(b.contiguous())
787 | new_layer.bias.requires_grad = True
788 | return new_layer
789 |
790 |
791 | def prune_layer(layer, index, dim=None):
792 | """ Prune a Conv1D or nn.Linear layer (a model parameter) to keep only entries in index.
793 | Return the pruned layer as a new layer with requires_grad=True.
794 | Used to remove heads.
795 | """
796 | if isinstance(layer, nn.Linear):
797 | return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
798 | elif isinstance(layer, Conv1D):
799 | return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
800 | else:
801 | raise ValueError("Can't prune layer of class {}".format(layer.__class__))
802 |
--------------------------------------------------------------------------------
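Note: the `Conv1D` and head-pruning helpers in `modeling_utils.py` above behave like a transposed `nn.Linear`; pruning along `dim=1` keeps a subset of output features while reproducing exactly the corresponding columns of the full layer's output. A minimal sketch of that behavior, assuming PyTorch is installed and the vendored package is importable as `transformers19` (e.g. by running from this repo's `src/` directory):

```python
import torch
from transformers19.modeling_utils import Conv1D, prune_conv1d_layer

layer = Conv1D(nf=4, nx=3)           # maps 3 input features -> 4 output features
x = torch.randn(2, 3)
full = layer(x)                      # shape (2, 4)

keep = torch.tensor([0, 2])          # indices of the output features to keep
pruned = prune_conv1d_layer(layer, keep, dim=1)
small = pruned(x)                    # shape (2, 2)

# the pruned layer reproduces the kept columns of the full output
assert torch.allclose(full[:, keep], small)
```
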
/src/transformers19/tokenization_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import sys
20 | import json
21 | import logging
22 | import os
23 | import regex as re
24 | from io import open
25 |
26 | try:
27 | from functools import lru_cache
28 | except ImportError:
29 | # Just a dummy decorator to get the checks to run on python2
30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
31 | def lru_cache():
32 | return lambda func: func
33 |
34 | from .tokenization_utils import PreTrainedTokenizer
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 | VOCAB_FILES_NAMES = {
39 | 'vocab_file': 'vocab.json',
40 | 'merges_file': 'merges.txt',
41 | }
42 |
43 | PRETRAINED_VOCAB_FILES_MAP = {
44 | 'vocab_file':
45 | {
46 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
47 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
48 | 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
49 | 'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
50 | },
51 | 'merges_file':
52 | {
53 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
54 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
55 | 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
56 | 'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
57 | },
58 | }
59 |
60 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
61 | 'gpt2': 1024,
62 | 'gpt2-medium': 1024,
63 | 'gpt2-large': 1024,
64 | 'distilgpt2': 1024,
65 | }
66 |
67 | @lru_cache()
68 | def bytes_to_unicode():
69 | """
70 | Returns a mapping from utf-8 bytes to unicode strings.
71 | We specifically avoid mapping to whitespace/control characters that the bpe code barfs on.
72 |
73 | The reversible bpe codes work on unicode strings.
74 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
75 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
76 | This is a significant percentage of your normal, say, 32K bpe vocab.
77 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
78 | """
79 | _chr = unichr if sys.version_info[0] == 2 else chr
80 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
81 | cs = bs[:]
82 | n = 0
83 | for b in range(2**8):
84 | if b not in bs:
85 | bs.append(b)
86 | cs.append(2**8+n)
87 | n += 1
88 | cs = [_chr(n) for n in cs]
89 | return dict(zip(bs, cs))
90 |
91 | def get_pairs(word):
92 | """Return set of symbol pairs in a word.
93 |
94 | Word is represented as tuple of symbols (symbols being variable-length strings).
95 | """
96 | pairs = set()
97 | prev_char = word[0]
98 | for char in word[1:]:
99 | pairs.add((prev_char, char))
100 | prev_char = char
101 | return pairs
102 |
103 | class GPT2Tokenizer(PreTrainedTokenizer):
104 | """
105 | GPT-2 BPE tokenizer. Peculiarities:
106 | - Byte-level Byte-Pair-Encoding
107 | - Requires a space to start the input string => the encoding methods should be called with the
108 | ``add_prefix_space`` flag set to ``True``.
109 | Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not preserve
110 | the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
111 | """
112 | vocab_files_names = VOCAB_FILES_NAMES
113 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
114 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
115 |
116 | def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>",
117 | bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
118 | super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
119 | self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
120 | self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
121 |
122 | self.encoder = json.load(open(vocab_file, encoding="utf-8"))
123 | self.decoder = {v: k for k, v in self.encoder.items()}
124 | self.errors = errors # how to handle errors in decoding
125 | self.byte_encoder = bytes_to_unicode()
126 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
127 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
128 | bpe_merges = [tuple(merge.split()) for merge in bpe_data]
129 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
130 | self.cache = {}
131 |
132 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
133 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
134 |
135 | @property
136 | def vocab_size(self):
137 | return len(self.encoder)
138 |
139 | def bpe(self, token):
140 | if token in self.cache:
141 | return self.cache[token]
142 | word = tuple(token)
143 | pairs = get_pairs(word)
144 |
145 | if not pairs:
146 | return token
147 |
148 | while True:
149 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
150 | if bigram not in self.bpe_ranks:
151 | break
152 | first, second = bigram
153 | new_word = []
154 | i = 0
155 | while i < len(word):
156 | try:
157 | j = word.index(first, i)
158 | new_word.extend(word[i:j])
159 | i = j
160 | except ValueError:
161 | new_word.extend(word[i:])
162 | break
163 |
164 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
165 | new_word.append(first+second)
166 | i += 2
167 | else:
168 | new_word.append(word[i])
169 | i += 1
170 | new_word = tuple(new_word)
171 | word = new_word
172 | if len(word) == 1:
173 | break
174 | else:
175 | pairs = get_pairs(word)
176 | word = ' '.join(word)
177 | self.cache[token] = word
178 | return word
179 |
180 | def _tokenize(self, text, add_prefix_space=False):
181 | """ Tokenize a string.
182 | Args:
183 | - add_prefix_space (boolean, default False):
184 | Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
185 | """
186 | if add_prefix_space:
187 | text = ' ' + text
188 |
189 | bpe_tokens = []
190 | for token in re.findall(self.pat, text):
191 | if sys.version_info[0] == 2:
192 | token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
193 | else:
194 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
195 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
196 | return bpe_tokens
197 |
198 | def _convert_token_to_id(self, token):
199 | """ Converts a token (str/unicode) to an id using the vocab. """
200 | return self.encoder.get(token, self.encoder.get(self.unk_token))
201 |
202 | def _convert_id_to_token(self, index):
203 | """Converts an index (integer) to a token (string/unicode) using the vocab."""
204 | return self.decoder.get(index)
205 |
206 | def convert_tokens_to_string(self, tokens):
207 | """ Converts a sequence of tokens (string) into a single string. """
208 | text = ''.join(tokens)
209 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
210 | return text
211 |
212 | def save_vocabulary(self, save_directory):
213 | """Save the tokenizer vocabulary and merge files to a directory."""
214 | if not os.path.isdir(save_directory):
215 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
216 | return
217 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
218 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
219 |
220 | with open(vocab_file, 'w', encoding='utf-8') as f:
221 | f.write(json.dumps(self.encoder, ensure_ascii=False))
222 |
223 | index = 0
224 | with open(merge_file, "w", encoding="utf-8") as writer:
225 | writer.write(u'#version: 0.2\n')
226 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
227 | if index != token_index:
228 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
229 | " Please check that the tokenizer is not corrupted!".format(merge_file))
230 | index = token_index
231 | writer.write(' '.join(bpe_tokens) + u'\n')
232 | index += 1
233 |
234 | return vocab_file, merge_file
--------------------------------------------------------------------------------
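For reference, a minimal round-trip with the byte-level BPE tokenizer above. This is a sketch, not part of the repo: the `vocab.json` / `merges.txt` paths are hypothetical local files (they can be fetched from the URLs listed in `PRETRAINED_VOCAB_FILES_MAP`), the import path assumes you run from this repo's `src/` directory, and it assumes `tokenize` forwards `add_prefix_space` to `_tokenize` as in the upstream library.

```python
from transformers19.tokenization_gpt2 import GPT2Tokenizer

# hypothetical local copies of the GPT-2 vocabulary and BPE merges
tok = GPT2Tokenizer(vocab_file='vocab.json', merges_file='merges.txt')

tokens = tok.tokenize("Hello world", add_prefix_space=True)   # byte-level BPE pieces
ids = tok.convert_tokens_to_ids(tokens)
print(tok.decode(ids))   # expected: " Hello world" (leading space, per the class docstring)
```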