├── examples
│   ├── test
│   ├── __init__.py
│   ├── save_models
│   │   ├── tt
│   │   ├── TBEHRT_Test__CUT0.bin
│   │   ├── TBEHRT_Test__CUT1.bin
│   │   ├── TBEHRT_Test__CUT2.bin
│   │   ├── TBEHRT_Test__CUT3.bin
│   │   └── TBEHRT_Test__CUT4.bin
│   ├── test.parquet
│   ├── TBEHRT_Test__CUT0.npz
│   ├── TBEHRT_Test__CUT1.npz
│   ├── TBEHRT_Test__CUT2.npz
│   ├── TBEHRT_Test__CUT3.npz
│   ├── TBEHRT_Test__CUT4.npz
│   └── CVTMLE_example.ipynb
├── screenshot.png
├── .idea
│   ├── misc.xml
│   ├── vcs.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── Targeted_BEHRT.iml
├── LICENSE
├── README.md
├── pytorch_pretrained_bert
│   ├── __init__.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── convert_gpt2_checkpoint_to_pytorch.py
│   ├── convert_openai_checkpoint_to_pytorch.py
│   ├── optimizer.py
│   ├── module.py
│   ├── __main__.py
│   ├── convert_transfo_xl_checkpoint_to_pytorch.py
│   ├── optimization_openai.py
│   ├── optimization.py
│   ├── tokenization_gpt2.py
│   ├── file_utils.py
│   ├── tokenization_openai.py
│   ├── modeling_transfo_xl_utilities.py
│   ├── tokenization.py
│   └── tokenization_transfo_xl.py
└── src
    ├── data.py
    ├── vae.py
    ├── CV_TMLE.py
    ├── utils.py
    └── model.py
/examples/test:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/save_models/tt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/screenshot.png
--------------------------------------------------------------------------------
/examples/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/test.parquet
--------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT0.npz
--------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT1.npz
--------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT2.npz
--------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT3.npz
--------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT4.npz
--------------------------------------------------------------------------------
/examples/save_models/TBEHRT_Test__CUT0.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT0.bin
--------------------------------------------------------------------------------
/examples/save_models/TBEHRT_Test__CUT1.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT1.bin
--------------------------------------------------------------------------------
/examples/save_models/TBEHRT_Test__CUT2.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT2.bin
--------------------------------------------------------------------------------
/examples/save_models/TBEHRT_Test__CUT3.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT3.bin
--------------------------------------------------------------------------------
/examples/save_models/TBEHRT_Test__CUT4.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT4.bin
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/Targeted_BEHRT.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Shishir Rao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Targeted BEHRT
2 | Repository for publication: Targeted-BEHRT: Deep Learning for Observational Causal Inference on Longitudinal Electronic Health Records
3 | IEEE Transactions on Neural Networks and Learning Systems; Special Issue on Causality
4 | https://ieeexplore.ieee.org/document/9804397/
5 | DOI: 10.1109/TNNLS.2022.3183864.
6 |
7 | 
8 |
9 | How to use:
10 | In "examples" folder, run the "run_TBEHRT.ipynb" file. A test.csv file is provided to test/play and demonstrate how the vocabulary/year/age/etc function (please read full paper linked above for further methodological details).
11 | Furthermoree, in the examples folder to run the CV-TMLE estimator, run the "CVTMLE_example.ipynb" file. A host of fake fold data is provided to test/play and demonstrate how the CV-TMLE algorithm works (please read methods publication of CV-TMLE for further details).
12 |
13 | The files in the "src" folder contain model and data handling packages in addition to other necessary VAE relevant files and helper functions.
14 |
15 | Requirements:
16 | torch >1.6.0
17 | numpy 1.19.2
18 | sklearn 0.23.2
19 | pandas 1.1.3
20 |
21 |
--------------------------------------------------------------------------------
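
For a quick sanity check before opening the notebooks, the bundled example data can be inspected directly. A minimal sketch, assuming pandas with a parquet engine is installed and the repo root is the working directory:

    import pandas as pd

    # Load the demo data shipped in examples/ and inspect the columns the
    # tutorial refers to (codes, ages, years, exposure, outcome, etc.).
    df = pd.read_parquet('examples/test.parquet')
    print(df.columns.tolist())
    print(df.head())
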
/pytorch_pretrained_bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.1"
2 | # from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | # from .tokenization_openai import OpenAIGPTTokenizer
4 | # from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
5 | # from .tokenization_gpt2 import GPT2Tokenizer
6 |
7 | from .modeling import (BertConfig, BertModel, BertForPreTraining,
8 | BertForMaskedLM, BertForNextSentencePrediction,
9 | BertForSequenceClassification, BertForMultipleChoice,
10 | BertForTokenClassification, BertForQuestionAnswering,
11 | load_tf_weights_in_bert)
12 |
13 |
14 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
15 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
16 | load_tf_weights_in_openai_gpt)
17 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
18 | load_tf_weights_in_transfo_xl)
19 | from .modeling_gpt2 import (GPT2Config, GPT2Model,
20 | GPT2LMHeadModel, GPT2DoubleHeadsModel,
21 | load_tf_weights_in_gpt2)
22 |
23 | from .optimization import BertAdam
24 | from .optimization_openai import OpenAIAdam
25 |
26 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
27 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import re
23 | import argparse
24 | import tensorflow as tf
25 | import torch
26 | import numpy as np
27 |
28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
29 |
30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
31 | # Initialise PyTorch model
32 | config = BertConfig.from_json_file(bert_config_file)
33 | print("Building PyTorch model from configuration: {}".format(str(config)))
34 | model = BertForPreTraining(config)
35 |
36 | # Load weights from tf checkpoint
37 | load_tf_weights_in_bert(model, tf_checkpoint_path)
38 |
39 | # Save pytorch-model
40 | print("Save PyTorch model to {}".format(pytorch_dump_path))
41 | torch.save(model.state_dict(), pytorch_dump_path)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | ## Required parameters
47 | parser.add_argument("--tf_checkpoint_path",
48 | default = None,
49 | type = str,
50 | required = True,
51 |                     help = "Path to the TensorFlow checkpoint.")
52 | parser.add_argument("--bert_config_file",
53 | default = None,
54 | type = str,
55 | required = True,
56 | help = "The config json file corresponding to the pre-trained BERT model. \n"
57 | "This specifies the model architecture.")
58 | parser.add_argument("--pytorch_dump_path",
59 | default = None,
60 | type = str,
61 | required = True,
62 | help = "Path to the output PyTorch model.")
63 | args = parser.parse_args()
64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
65 | args.bert_config_file,
66 | args.pytorch_dump_path)
67 |
--------------------------------------------------------------------------------
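
The converter above can also be driven programmatically rather than through argparse; a sketch, where the three paths are hypothetical placeholders and TensorFlow must be installed:

    from pytorch_pretrained_bert.convert_tf_checkpoint_to_pytorch import (
        convert_tf_checkpoint_to_pytorch)

    # Placeholder paths: a TF checkpoint, its BERT config, and the output file.
    convert_tf_checkpoint_to_pytorch('bert_model.ckpt', 'bert_config.json',
                                     'pytorch_model.bin')
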
/examples/CVTMLE_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "import os\n",
11 | "sys.path.insert(0, '/home/rnshishir/deepmed/TBEHRT_pl/')\n",
12 | "\n",
13 | "import scipy\n",
14 | "import pandas as pd\n",
15 | "import numpy as np\n",
16 | "from src.CV_TMLE import *"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "# CV TMLE tutorial"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 4,
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "name": "stdout",
33 | "output_type": "stream",
34 | "text": [
35 | "running CV-TMLE for binary outcomes...\n"
36 | ]
37 | }
38 | ],
39 | "source": [
40 | "# folds in the npz format\n",
41 | "foldNPZ = ['TBEHRT_Test__CUT0.npz', 'TBEHRT_Test__CUT1.npz', 'TBEHRT_Test__CUT2.npz', 'TBEHRT_Test__CUT3.npz', 'TBEHRT_Test__CUT4.npz' ]\n",
42 | "\n",
43 | "# cvtmle runner \n",
44 | "TMLErun = CVTMLE(fromFolds=foldNPZ,truncate_level=0.03 )\n",
45 | "\n",
46 | "# estiamte the risk ratio for binary outcome\n",
47 | "est = TMLErun.run_tmle_binary()"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 5,
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "data": {
57 | "text/plain": [
58 | "[0.10878099048487283, 5.2854239704810925e-08, 223885.61366048577]"
59 | ]
60 | },
61 | "execution_count": 5,
62 | "metadata": {},
63 | "output_type": "execute_result"
64 | }
65 | ],
66 | "source": [
67 | "est\n",
68 | "# prints estimate and lower and upper conf interval bounds"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 7,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "data = pd.read_parquet('test.parquet')"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 11,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "# raw\n",
87 | "# data[data.explabel ==1].label.mean()/data[data.explabel ==0].label.mean()"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": []
96 | }
97 | ],
98 | "metadata": {
99 | "kernelspec": {
100 | "display_name": "real3",
101 | "language": "python",
102 | "name": "py3"
103 | },
104 | "language_info": {
105 | "codemirror_mode": {
106 | "name": "ipython",
107 | "version": 3
108 | },
109 | "file_extension": ".py",
110 | "mimetype": "text/x-python",
111 | "name": "python",
112 | "nbconvert_exporter": "python",
113 | "pygments_lexer": "ipython3",
114 | "version": "3.6.8"
115 | }
116 | },
117 | "nbformat": 4,
118 | "nbformat_minor": 4
119 | }
120 |
--------------------------------------------------------------------------------
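
The notebook's workflow, condensed into a plain script for reference; a sketch that assumes the five fold files sit in the working directory:

    from src.CV_TMLE import CVTMLE

    # The five cross-validation folds shipped in examples/, in npz format.
    folds = ['TBEHRT_Test__CUT{}.npz'.format(i) for i in range(5)]

    # truncate_level=0.03 keeps patients with propensity estimates in [0.03, 0.97].
    tmle = CVTMLE(fromFolds=folds, truncate_level=0.03)
    rr, lower, upper = tmle.run_tmle_binary()  # risk ratio with 95% CI bounds
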
/pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
25 | GPT2Config,
26 | GPT2Model,
27 | load_tf_weights_in_gpt2)
28 |
29 |
30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
31 | # Construct model
32 | if gpt2_config_file == "":
33 | config = GPT2Config()
34 | else:
35 | config = GPT2Config(gpt2_config_file)
36 | model = GPT2Model(config)
37 |
38 | # Load weights from numpy
39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
40 |
41 | # Save pytorch-model
42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
45 | torch.save(model.state_dict(), pytorch_weights_dump_path)
46 | print("Save configuration file to {}".format(pytorch_config_dump_path))
47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
48 | f.write(config.to_json_string())
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | ## Required parameters
54 | parser.add_argument("--gpt2_checkpoint_path",
55 | default = None,
56 | type = str,
57 | required = True,
58 |                     help = "Path to the TensorFlow checkpoint.")
59 | parser.add_argument("--pytorch_dump_folder_path",
60 | default = None,
61 | type = str,
62 | required = True,
63 | help = "Path to the output PyTorch model.")
64 | parser.add_argument("--gpt2_config_file",
65 | default = "",
66 | type = str,
67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.")
69 | args = parser.parse_args()
70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
71 | args.gpt2_config_file,
72 | args.pytorch_dump_folder_path)
73 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | from io import open
21 |
22 | import torch
23 |
24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
25 | OpenAIGPTConfig,
26 | OpenAIGPTModel,
27 | load_tf_weights_in_openai_gpt)
28 |
29 |
30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
31 | # Construct model
32 | if openai_config_file == "":
33 | config = OpenAIGPTConfig()
34 | else:
35 | config = OpenAIGPTConfig(openai_config_file)
36 | model = OpenAIGPTModel(config)
37 |
38 | # Load weights from numpy
39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
40 |
41 | # Save pytorch-model
42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
45 | torch.save(model.state_dict(), pytorch_weights_dump_path)
46 | print("Save configuration file to {}".format(pytorch_config_dump_path))
47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
48 | f.write(config.to_json_string())
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | ## Required parameters
54 | parser.add_argument("--openai_checkpoint_folder_path",
55 | default = None,
56 | type = str,
57 | required = True,
58 |                     help = "Path to the OpenAI checkpoint folder.")
59 | parser.add_argument("--pytorch_dump_folder_path",
60 | default = None,
61 | type = str,
62 | required = True,
63 | help = "Path to the output PyTorch model.")
64 | parser.add_argument("--openai_config_file",
65 | default = "",
66 | type = str,
67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.")
69 | args = parser.parse_args()
70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
71 | args.openai_config_file,
72 | args.pytorch_dump_folder_path)
73 |
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/optimizer.py:
--------------------------------------------------------------------------------
1 | import pytorch_pretrained_bert as Bert
2 |
3 | def VAEadam(params, config=None):
4 | if config is None:
5 | config = {
6 | 'lr': 3e-5,
7 | 'warmup_proportion': 0.1
8 | }
9 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'Eps']
10 | vae = ['VAE']
11 | no_decayfull = no_decay + vae
12 | # print( {'params': [n for n, p in params if not any(nd in n for nd in no_decayfull)], 'weight_decay': 0.01, 'lr': config['lr']},
13 | # {'params': [n for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': config['lr']},
14 | # {'params': [n for n, p in params if any(nd in n for nd in vae)], 'weight_decay': 0.0, 'lr':1e-3 }
15 | # )
16 | optimizer_grouped_parameters = [
17 | {'params': [p for n, p in params if not any(nd in n for nd in no_decayfull)], 'weight_decay': 0.01, 'lr': config['lr']},
18 | {'params': [p for n, p in params if (any(nd in n for nd in no_decay) and 'VAE' not in n)], 'weight_decay': 0.0, 'lr': config['lr']},
19 | {'params': [p for n, p in params if any(nd in n for nd in vae)], 'weight_decay': 0.0, 'lr': 1e-3}
20 |
21 | ]
22 |
23 | optim = Bert.optimization.BertAdam(optimizer_grouped_parameters,
24 | warmup=config['warmup_proportion'])
25 | return optim
26 |
27 |
28 | def adam(params, config=None):
29 | if config is None:
30 | config = {
31 | 'lr': 3e-5,
32 | 'warmup_proportion': 0.1
33 | }
34 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'Eps','VAE']
35 | optimizer_grouped_parameters = [
36 | {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
37 | {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
38 | ]
39 |
40 | optim = Bert.optimization.BertAdam(optimizer_grouped_parameters,
41 | lr=config['lr'],
42 | warmup=config['warmup_proportion'])
43 | return optim
44 |
45 | def GPadam(params, gpLR, config=None):
46 | if config is None:
47 | config = {
48 | 'lr': 3e-5,
49 | 'warmup_proportion': 0.1
50 | }
51 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight','Eps']
52 | gp = ['GP']
53 |
54 |
55 |
56 | optimizer_grouped_parameters = [
57 | {'params': [p for n, p in params if not any(nd in n for nd in no_decay) and not any(nd in n for nd in gp)], 'weight_decay': 0.01 , 'lr': config['lr'], 'warmup_proportion': 0.1},
58 | {'params': [p for n, p in params if any(nd in n for nd in no_decay) and not any(nd in n for nd in gp)], 'weight_decay': 0.0, 'lr': config['lr'], 'warmup_proportion': 0.1},
59 | {'params': [p for n, p in params if any(nd in n for nd in gp)], 'lr': gpLR}
60 | ]
61 |
62 | print([
63 | {'params': [n for n, p in params if not any(nd in n for nd in no_decay) and not any(nd in n for nd in gp)], 'weight_decay': 0.01 , 'lr': config['lr'], 'warmup_proportion': 0.1},
64 | {'params': [n for n, p in params if any(nd in n for nd in no_decay) and not any(nd in n for nd in gp)], 'weight_decay': 0.0, 'lr': config['lr'], 'warmup_proportion': 0.1},
65 | {'params': [n for n, p in params if any(nd in n for nd in gp)], 'lr': gpLR}
66 | ])
67 | optim = Bert.optimization.BertAdam(optimizer_grouped_parameters)
68 | return optim
69 |
--------------------------------------------------------------------------------
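
A usage sketch for the grouped-parameter factories above. Each one iterates `params` in several list comprehensions, so pass a concrete list of (name, parameter) pairs rather than the raw generator from `model.named_parameters()`; `model` here stands in for any nn.Module:

    from pytorch_pretrained_bert.optimizer import adam

    # Materialise the generator first: a second pass over an exhausted
    # generator would silently yield empty parameter groups.
    named_params = list(model.named_parameters())
    optim = adam(named_params, config={'lr': 3e-5, 'warmup_proportion': 0.1})
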
/pytorch_pretrained_bert/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_pretrained_bert as Bert
4 |
5 |
6 | def sequence_mask(sequence_length, max_len=None, device=None):
 7 | sequence_length = torch.tensor(sequence_length)
 8 | if max_len is None:
 9 |     max_len = sequence_length.data.max()
10 | max_len = torch.tensor(max_len)  # tensorify only after the None check; torch.tensor(None) raises
11 | batch_size = sequence_length.size(0)
12 | seq_range = torch.arange(0, max_len).long()
13 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
14 |
15 | if sequence_length.is_cuda:
16 | seq_range_expand = seq_range_expand.to(device)
17 | seq_length_expand = (sequence_length.unsqueeze(1).expand_as(seq_range_expand))
18 | mask= seq_range_expand < seq_length_expand
19 | return mask.detach().long()
20 |
21 |
22 | class BertEmbeddings(nn.Module):
23 | """Construct the embeddings from word, segment, age
24 | """
25 |
26 | def __init__(self, config):
27 | super(BertEmbeddings, self).__init__()
28 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
29 | self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)
30 | self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)
31 |
32 | self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
33 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
34 |
35 | def forward(self, word_ids, age_ids=None, seg_ids=None):
36 | if seg_ids is None:
37 | seg_ids = torch.zeros_like(word_ids)
38 | if age_ids is None:
39 | age_ids = torch.zeros_like(word_ids)
40 |
41 | word_embed = self.word_embeddings(word_ids)
42 | segment_embed = self.segment_embeddings(seg_ids)
43 | age_embed = self.age_embeddings(age_ids)
44 |
45 | embeddings = word_embed + segment_embed + age_embed
46 | embeddings = self.LayerNorm(embeddings)
47 | embeddings = self.dropout(embeddings)
48 | return embeddings
49 |
50 |
51 | class BertModel(Bert.modeling.BertPreTrainedModel):
52 | def __init__(self, config):
53 | super(BertModel, self).__init__(config)
54 | self.embeddings = BertEmbeddings(config=config)
55 | self.encoder = Bert.modeling.BertEncoder(config=config)
56 | self.pooler = Bert.modeling.BertPooler(config)
57 | self.apply(self.init_bert_weights)
58 |
59 | def forward(self, input_ids, age_ids=None, seg_ids=None, attention_mask=None, output_all_encoded_layers=True):
60 | if attention_mask is None:
61 | attention_mask = torch.ones_like(input_ids)
62 | if age_ids is None:
63 | age_ids = torch.zeros_like(input_ids)
64 | if seg_ids is None:
65 | seg_ids = torch.zeros_like(input_ids)
66 |
67 | # We create a 3D attention mask from a 2D tensor mask.
68 | # Sizes are [batch_size, 1, 1, to_seq_length]
69 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
70 | # this attention mask is simpler than the triangular masking of causal attention
71 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
72 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
73 |
74 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
75 | # masked positions, this operation will create a tensor which is 0.0 for
76 | # positions we want to attend and -10000.0 for masked positions.
77 | # Since we are adding it to the raw scores before the softmax, this is
78 | # effectively the same as removing these entirely.
79 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
80 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
81 |
82 | embedding_output = self.embeddings(input_ids, age_ids, seg_ids)
83 | encoded_layers = self.encoder(embedding_output,
84 | extended_attention_mask,
85 | output_all_encoded_layers=output_all_encoded_layers)
86 | sequence_output = encoded_layers[-1]
87 | pooled_output = self.pooler(sequence_output)
88 | if not output_all_encoded_layers:
89 | encoded_layers = encoded_layers[-1]
90 | return encoded_layers, pooled_output
--------------------------------------------------------------------------------
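
A minimal sketch of how BertEmbeddings sums its three embedding streams (word, segment, age). The config values are illustrative stand-ins, and the package's modeling module is assumed importable:

    import torch
    from types import SimpleNamespace
    from pytorch_pretrained_bert.module import BertEmbeddings

    # Hypothetical sizes: 100 medical codes, 2 segments, 110 age bins.
    config = SimpleNamespace(vocab_size=100, seg_vocab_size=2, age_vocab_size=110,
                             hidden_size=24, hidden_dropout_prob=0.1)
    emb = BertEmbeddings(config)
    word_ids = torch.randint(0, 100, (2, 16))  # batch of 2 sequences of length 16
    out = emb(word_ids)                        # age/segment ids default to zeros
    print(out.shape)                           # torch.Size([2, 16, 24])
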
/pytorch_pretrained_bert/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 | import sys
4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
5 | "convert_tf_checkpoint_to_pytorch",
6 | "convert_openai_checkpoint",
7 | "convert_transfo_xl_checkpoint",
8 | "convert_gpt2_checkpoint",
9 | ]:
10 | print(
11 | "Should be used as one of: \n"
12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
16 | else:
17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
18 | try:
19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
20 | except ImportError:
21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
22 | "In that case, it requires TensorFlow to be installed. Please see "
23 | "https://www.tensorflow.org/install/ for installation instructions.")
24 | raise
25 |
26 | if len(sys.argv) != 5:
27 | # pylint: disable=line-too-long
28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
29 | else:
30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop()
31 | TF_CONFIG = sys.argv.pop()
32 | TF_CHECKPOINT = sys.argv.pop()
33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
34 | elif sys.argv[1] == "convert_openai_checkpoint":
35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
37 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
38 | if len(sys.argv) == 5:
39 | OPENAI_GPT_CONFIG = sys.argv[4]
40 | else:
41 | OPENAI_GPT_CONFIG = ""
42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
43 | OPENAI_GPT_CONFIG,
44 | PYTORCH_DUMP_OUTPUT)
45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint":
46 | try:
47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
48 | except ImportError:
49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
50 | "In that case, it requires TensorFlow to be installed. Please see "
51 | "https://www.tensorflow.org/install/ for installation instructions.")
52 | raise
53 |
54 | if 'ckpt' in sys.argv[2].lower():
55 | TF_CHECKPOINT = sys.argv[2]
56 | TF_DATASET_FILE = ""
57 | else:
58 | TF_DATASET_FILE = sys.argv[2]
59 | TF_CHECKPOINT = ""
60 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
61 | if len(sys.argv) == 5:
62 | TF_CONFIG = sys.argv[4]
63 | else:
64 | TF_CONFIG = ""
65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
66 | else:
67 | try:
68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
69 | except ImportError:
70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
71 | "In that case, it requires TensorFlow to be installed. Please see "
72 | "https://www.tensorflow.org/install/ for installation instructions.")
73 | raise
74 |
75 | TF_CHECKPOINT = sys.argv[2]
76 | PYTORCH_DUMP_OUTPUT = sys.argv[3]
77 | if len(sys.argv) == 5:
78 | TF_CONFIG = sys.argv[4]
79 | else:
80 | TF_CONFIG = ""
81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
82 | if __name__ == '__main__':
83 | main()
84 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import Dataset
2 | import torch
3 | from src.utils import *
4 | # data handling for variational autoencoder (VAE) unsupervised learning with Targeted BEHRT
5 |
6 |
7 | class TBEHRT_data_formation(Dataset):
8 |     def __init__(self, token2idx, dataframe, code='code', age='age', year='year', static='static', max_len=1000, expColumn='explabel', outcomeColumn='label', max_age=110, yvocab=None, list2avoid=None, MEM=True):
9 | """
10 | The dataset class for the PyTorch implementation of Targeted BEHRT.
11 |
12 | token2idx - dict mapping EHR tokens to integer indices
13 | dataframe - pandas dataframe holding the code, age, year, and any static columns
14 | code - name of the code column
15 | age - name of the age column
16 | year - name of the year column
17 | static - name of the static column
18 | max_len - maximum sequence length
19 | yvocab - year vocabulary for the year-based sequence of variables
20 | expColumn - name of the exposure column in the dataframe
21 | outcomeColumn - name of the outcome column
22 | MEM - masked EHR modelling (MEM) flag for unsupervised learning
23 | list2avoid - list of tokens/diseases to exclude from the MEM masking procedure
24 |
25 | """
26 |
27 | if list2avoid is None:
28 | self.acceptableVoc = token2idx
29 | else:
30 | self.acceptableVoc = {x: y for x, y in token2idx.items() if x not in list2avoid}
31 | print("old Vocab size: ", len(token2idx), ", and new Vocab size: ", len(self.acceptableVoc))
32 | self.vocab = token2idx
33 | self.max_len = max_len
34 | self.code = dataframe[code]
35 | self.age = dataframe[age]
36 | self.year = dataframe[year]
37 | if outcomeColumn is None:
38 | self.label = dataframe.deathLabel
39 | else:
40 | self.label = dataframe[outcomeColumn]
41 | self.age2idx, _ = age_vocab(110, year, symbol=None)
42 |
43 | if expColumn is None:
44 | self.treatmentLabel = dataframe.diseaseLabel
45 | else:
46 | self.treatmentLabel = dataframe[expColumn]
47 | self.year2idx = yvocab
48 | self.codeS = dataframe[static]
49 | self.MEM = MEM
50 | def __getitem__(self, index):
51 | """
52 | return: age, code, position, segmentation, mask, label
53 | """
54 |
55 | # extract data
56 |
57 | age = self.age[index]
58 |
59 | code = self.code[index]
60 | year = self.year[index]
61 |
62 | age = age[(-self.max_len + 1):]
63 | code = code[(-self.max_len + 1):]
64 | year = year[(-self.max_len + 1):]
65 |
66 |
67 | treatmentOutcome = torch.LongTensor([self.treatmentLabel[index]])
68 |
69 | # avoid a data cut whose first element is 'SEP'
70 | labelOutcome = self.label[index]
71 |
72 |
73 | # moved CLS to end as opposed to beginning.
74 | code[-1] = 'CLS'
75 |
76 | mask = np.ones(self.max_len)
77 | mask[:-len(code)] = 0
78 | mask = np.append(np.array([1]), mask)
79 |
80 |
81 | tokensReal, code2 = code2index(code, self.vocab)
82 | # pad age sequence and code sequence
83 | year = seq_padding_reverse(year, self.max_len, token2idx=self.year2idx)
84 |
85 | age = seq_padding_reverse(age, self.max_len, token2idx=self.age2idx)
86 |
87 | if self.MEM == False:
88 | tokens, codeMLM, labelMLM = nonMASK(code, self.vocab)
89 | else:
90 | tokens, codeMLM, labelMLM = randommaskreal(code, self.acceptableVoc)
91 |
92 | # get position code and segment code
93 | tokens = seq_padding_reverse(tokens, self.max_len)
94 | position = position_idx(tokens)
95 | segment = index_seg(tokens)
96 |
97 | code2 = seq_padding_reverse(code2, self.max_len, symbol=self.vocab['PAD'])
98 |
99 | codeMLM = seq_padding_reverse(codeMLM, self.max_len, symbol=self.vocab['PAD'])
100 | labelMLM = seq_padding_reverse(labelMLM, self.max_len, symbol=-1)
101 |
102 | outCodeS = [int(xx) for xx in self.codeS[index]]
103 | fixedcovar = np.array(outCodeS )
104 | labelcovar = np.array(([-1] * len(outCodeS)) + [-1, -1])
105 | if self.MEM == True:
106 | fixedcovar, labelcovar = covarUnsupMaker(fixedcovar)
107 | code2 = np.append(fixedcovar, code2)
108 | codeMLM = np.append(fixedcovar, codeMLM)
109 |
110 |
111 |
112 | # code2 carries the fixed static covariates while codeMLM carries the longitudinal ones
113 | return torch.LongTensor(age), torch.LongTensor(code2), torch.LongTensor(codeMLM), torch.LongTensor(
114 | position), torch.LongTensor(segment), torch.LongTensor(year), \
115 | torch.LongTensor(mask), torch.LongTensor(labelMLM), torch.LongTensor(
116 | [labelOutcome]), treatmentOutcome, torch.LongTensor(labelcovar)
117 |
118 |
119 | def __len__(self):
120 | return len(self.code)
121 |
122 |
--------------------------------------------------------------------------------
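
A hedged usage sketch for TBEHRT_data_formation. The toy vocabularies below are illustrative assumptions; the default column names (code/age/year/static, explabel, label) match those used by the example notebook's test.parquet:

    import pandas as pd
    from torch.utils.data import DataLoader
    from src.data import TBEHRT_data_formation

    # Toy vocabularies; real ones come from the repo's preprocessing utilities.
    token2idx = {'PAD': 0, 'CLS': 1, 'SEP': 2, 'MASK': 3, 'UNK': 4}
    yvocab = {str(y): i for i, y in enumerate(range(1990, 2021))}

    df = pd.read_parquet('examples/test.parquet')
    dataset = TBEHRT_data_formation(token2idx, df, max_len=250, yvocab=yvocab, MEM=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
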
/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert Transformer XL checkpoint and datasets."""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import argparse
20 | import os
21 | import sys
22 | from io import open
23 |
24 | import torch
25 |
26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
28 | WEIGHTS_NAME,
29 | TransfoXLConfig,
30 | TransfoXLLMHeadModel,
31 | load_tf_weights_in_transfo_xl)
32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
33 | VOCAB_NAME)
34 |
35 | if sys.version_info[0] == 2:
36 | import cPickle as pickle
37 | else:
38 | import pickle
39 |
40 | # We do this to be able to load python 2 datasets pickles
41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
42 | data_utils.Vocab = data_utils.TransfoXLTokenizer
43 | data_utils.Corpus = data_utils.TransfoXLCorpus
44 | sys.modules['data_utils'] = data_utils
45 | sys.modules['vocabulary'] = data_utils
46 |
47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
48 | transfo_xl_config_file,
49 | pytorch_dump_folder_path,
50 | transfo_xl_dataset_file):
51 | if transfo_xl_dataset_file:
52 | # Convert a pre-processed corpus (see original TensorFlow repo)
53 | with open(transfo_xl_dataset_file, "rb") as fp:
54 | corpus = pickle.load(fp, encoding="latin1")
55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
58 | corpus_vocab_dict = corpus.vocab.__dict__
59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
60 |
61 | corpus_dict_no_vocab = corpus.__dict__
62 | corpus_dict_no_vocab.pop('vocab', None)
63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
64 | print("Save dataset to {}".format(pytorch_dataset_dump_path))
65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
66 |
67 | if tf_checkpoint_path:
68 | # Convert a pre-trained TensorFlow model
69 | config_path = os.path.abspath(transfo_xl_config_file)
70 | tf_path = os.path.abspath(tf_checkpoint_path)
71 |
72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
73 | # Initialise PyTorch model
74 | if transfo_xl_config_file == "":
75 | config = TransfoXLConfig()
76 | else:
77 | config = TransfoXLConfig(transfo_xl_config_file)
78 | print("Building PyTorch model from configuration: {}".format(str(config)))
79 | model = TransfoXLLMHeadModel(config)
80 |
81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path)
82 | # Save pytorch-model
83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
86 | torch.save(model.state_dict(), pytorch_weights_dump_path)
87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
89 | f.write(config.to_json_string())
90 |
91 |
92 | if __name__ == "__main__":
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--pytorch_dump_folder_path",
95 | default = None,
96 | type = str,
97 | required = True,
98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.")
99 | parser.add_argument("--tf_checkpoint_path",
100 | default = "",
101 | type = str,
102 | help = "An optional path to a TensorFlow checkpoint path to be converted.")
103 | parser.add_argument("--transfo_xl_config_file",
104 | default = "",
105 | type = str,
106 |                     help = "An optional config json file corresponding to the pre-trained Transformer-XL model. \n"
107 | "This specifies the model architecture.")
108 | parser.add_argument("--transfo_xl_dataset_file",
109 | default = "",
110 | type = str,
111 | help = "An optional dataset file to be converted in a vocabulary.")
112 | args = parser.parse_args()
113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
114 | args.transfo_xl_config_file,
115 | args.pytorch_dump_folder_path,
116 | args.transfo_xl_dataset_file)
117 |
--------------------------------------------------------------------------------
/src/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_pretrained_bert as Bert
4 |
5 | # borrowed boilerplate vae code from https://github.com/AntixK/PyTorch-VAE
6 |
7 |
8 | class VAE(Bert.modeling.BertPreTrainedModel):
9 | def __init__(self, config):
10 | super(VAE, self).__init__(config)
11 |
12 | self.unsuplist = config.unsupSize
13 |
14 |
15 |
16 |
17 | vaelatentdim = config.vaelatentdim
18 | vaeinchannels = config.vaeinchannels
19 |
20 | modules = []
21 | vaehidden = [config.poolingSize]
22 | self.linearFC = nn.Linear(config.hidden_size, config.poolingSize)
23 | self.activ = nn.ReLU()
24 |
25 | # Build Encoder
26 | self.fc_mu = nn.Linear(vaehidden[-1], vaelatentdim)
27 | self.fc_var = nn.Linear(vaehidden[-1], vaelatentdim)
28 |
29 |
30 | # Build Decoder
31 | modules = []
32 |
33 | self.decoder1 = nn.Linear(vaelatentdim, vaehidden[-1])
34 |         self.decoder2 = nn.Linear(vaehidden[-1], int(vaehidden[-1]))
35 |
36 |         self.logSoftmax = nn.LogSoftmax(dim=1)
37 |
38 |         self.linearOut = nn.ModuleList([nn.Linear(int(vaehidden[-1]), el[0]) for el in self.unsuplist])
39 | self.BetaD = config.BetaD
40 |
41 | self.apply(self.init_bert_weights)
42 |
43 |     def encode(self, input: torch.Tensor):
44 | """
45 | Encodes the input by passing through the encoder network
46 | and returns the latent codes.
47 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
48 | :return: (Tensor) List of latent codes
49 | """
50 | # result = self.activ (self.linearFC(input))
51 |
52 | mu = self.fc_mu(input)
53 | log_var = self.fc_var(input)
54 |
55 | return [mu, log_var]
56 |
57 | def decode(self, z: torch.Tensor) -> torch.Tensor:
58 | """
59 | Maps the given latent codes
60 | onto the image space.
61 | :param z: (Tensor) [B x D]
62 | :return: (Tensor) [B x C x H x W]
63 | """
64 | result = self.activ(self.decoder1(z))
65 | result = self.activ(self.decoder2(result))
66 | outs = []
67 |
68 |
69 | for outputiter , linoutnetwork in enumerate(self.linearOut):
70 | resout = self.logSoftmax(linoutnetwork(result))
71 | outs.append(resout)
72 |
73 | outs = torch.cat((outs), dim=1)
74 | return outs
75 |
76 | def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
77 | """
78 | Reparameterization trick to sample from N(mu, var) from
79 | N(0,1).
80 | :param mu: (Tensor) Mean of the latent Gaussian [B x D]
81 | :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
82 | :return: (Tensor) [B x D]
83 | """
84 | std = torch.exp(0.5 * logvar)
85 | eps = torch.randn_like(std)
86 | return eps * std + mu
87 |
88 |     def forward(self, input: torch.Tensor, label: torch.Tensor):
89 |         # the BetaD and non-BetaD branches were identical, so a single path suffices
90 |         mu, log_var = self.encode(input)
91 |         z = self.reparameterize(mu, log_var)
92 |         return [self.decode(z), label, mu, log_var]
98 | def loss_function(self,dictout) -> dict:
99 | """
100 | Computes the VAE loss function.
101 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
102 | :param args:
103 | :param kwargs:
104 | :return:
105 | """
106 | recons = dictout[0].transpose(1,0)
107 | input = dictout[1].transpose(1,0)
108 |
109 | mu = dictout[2]
110 | log_var = dictout[3]
111 | if self.BetaD==False:
112 |
113 | kld_weight = self.config.klpar # Account for the minibatch samples from the dataset
114 | reconsloss = 0
115 | startindx = 0
116 |
117 | outs = []
118 | labs = []
119 | for outputiter , output in enumerate(self.unsuplist):
120 | elementssize = output[0]
121 | chunkrecons = recons[startindx:startindx+elementssize].transpose(1,0)
122 | labels= input[outputiter]
123 | lossF = nn.NLLLoss(reduction='none', ignore_index=-1)
124 | temploss = lossF(chunkrecons,labels).sum()
125 | reconsloss =reconsloss+ temploss
126 |
127 | outs.append(chunkrecons)
128 | labs.append(labels)
129 | startindx = startindx+elementssize
130 |
131 |
132 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
133 |
134 | loss = (reconsloss + kld_weight * kld_loss)/len(dictout[0])
135 |
136 | if self.config.klpar<1:
137 | self.config.klpar = self.config.klpar + 1e-5
138 |
139 | return {'loss': loss, 'Reconstruction_Loss':reconsloss, 'KLD':-kld_loss, 'outs':outs, 'labs':labs}
140 | else:
141 |
142 |
143 | return 0
144 |
145 | def sample(self,
146 | num_samples:int,
147 | current_device: int, **kwargs) -> torch.Tensor:
148 | """
149 | Samples from the latent space and return the corresponding
150 | image space map.
151 | :param num_samples: (Int) Number of samples
152 | :param current_device: (Int) Device to run the model
153 | :return: (Tensor)
154 | """
155 |         z = torch.randn(num_samples,
156 |                         self.config.vaelatentdim)  # the latent dim lives on config; self.vaelatentdim is never set in __init__
157 |
158 | z = z.to(current_device)
159 |
160 | samples = self.decode(z)
161 | return samples
162 |
163 | def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
164 | """
165 | Given an input image x, returns the reconstructed image
166 | :param x: (Tensor) [B x C x H x W]
167 | :return: (Tensor) [B x C x H x W]
168 | """
169 |
170 |         return self.forward(x, None)[0]  # forward() requires a label argument but only echoes it back, so None is safe
171 |
172 |
173 |
174 |
--------------------------------------------------------------------------------
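
A standalone illustration of the reparameterization trick used in VAE.reparameterize: z = mu + sigma * eps with sigma = exp(0.5 * log_var) and eps ~ N(0, I), which keeps the sampling step differentiable with respect to mu and log_var:

    import torch

    mu = torch.zeros(4, 8)         # [B x D] latent means
    log_var = torch.zeros(4, 8)    # [B x D] latent log-variances
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)    # noise drawn outside the computation graph
    z = eps * std + mu             # differentiable sample from N(mu, sigma^2)
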
/pytorch_pretrained_bert/optimization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for OpenAI GPT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | def warmup_cosine(x, warmup=0.002):
27 | if x < warmup:
28 | return x/warmup
29 | x_ = (x - warmup) / (1 - warmup) # progress after warmup
30 | return 0.5 * (1. + math.cos(math.pi * x_))
31 |
32 | def warmup_constant(x, warmup=0.002):
33 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
34 | Learning rate is 1. afterwards. """
35 | if x < warmup:
36 | return x/warmup
37 | return 1.0
38 |
39 | def warmup_linear(x, warmup=0.002):
40 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
41 | After `t_total`-th training step, learning rate is zero. """
42 | if x < warmup:
43 | return x/warmup
44 | return max((x-1.)/(warmup-1.), 0)
45 |
46 | SCHEDULES = {
47 | 'warmup_cosine':warmup_cosine,
48 | 'warmup_constant':warmup_constant,
49 | 'warmup_linear':warmup_linear,
50 | }
51 |
52 |
53 | class OpenAIAdam(Optimizer):
54 | """Implements Open AI version of Adam algorithm with weight decay fix.
55 | """
56 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
57 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
58 | vector_l2=False, max_grad_norm=-1, **kwargs):
59 | if lr is not required and lr < 0.0:
60 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
61 | if schedule not in SCHEDULES:
62 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
63 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
64 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
65 | if not 0.0 <= b1 < 1.0:
66 | raise ValueError("Invalid b1 parameter: {}".format(b1))
67 | if not 0.0 <= b2 < 1.0:
68 | raise ValueError("Invalid b2 parameter: {}".format(b2))
69 | if not e >= 0.0:
70 | raise ValueError("Invalid epsilon value: {}".format(e))
71 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
72 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
73 | max_grad_norm=max_grad_norm)
74 | super(OpenAIAdam, self).__init__(params, defaults)
75 |
76 | def get_lr(self):
77 | lr = []
78 | for group in self.param_groups:
79 | for p in group['params']:
80 | state = self.state[p]
81 | if len(state) == 0:
82 | return [0]
83 | if group['t_total'] != -1:
84 | schedule_fct = SCHEDULES[group['schedule']]
85 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
86 | else:
87 | lr_scheduled = group['lr']
88 | lr.append(lr_scheduled)
89 | return lr
90 |
91 | def step(self, closure=None):
92 | """Performs a single optimization step.
93 |
94 | Arguments:
95 | closure (callable, optional): A closure that reevaluates the model
96 | and returns the loss.
97 | """
98 | loss = None
99 | if closure is not None:
100 | loss = closure()
101 |
102 | warned_for_t_total = False
103 |
104 | for group in self.param_groups:
105 | for p in group['params']:
106 | if p.grad is None:
107 | continue
108 | grad = p.grad.data
109 | if grad.is_sparse:
110 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
111 |
112 | state = self.state[p]
113 |
114 | # State initialization
115 | if len(state) == 0:
116 | state['step'] = 0
117 | # Exponential moving average of gradient values
118 | state['exp_avg'] = torch.zeros_like(p.data)
119 | # Exponential moving average of squared gradient values
120 | state['exp_avg_sq'] = torch.zeros_like(p.data)
121 |
122 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
123 | beta1, beta2 = group['b1'], group['b2']
124 |
125 | state['step'] += 1
126 |
127 | # Add grad clipping
128 | if group['max_grad_norm'] > 0:
129 | clip_grad_norm_(p, group['max_grad_norm'])
130 |
131 | # Decay the first and second moment running average coefficient
132 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
133 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
134 | denom = exp_avg_sq.sqrt().add_(group['e'])
135 |
136 | bias_correction1 = 1 - beta1 ** state['step']
137 | bias_correction2 = 1 - beta2 ** state['step']
138 |
139 | if group['t_total'] != -1:
140 | schedule_fct = SCHEDULES[group['schedule']]
141 | progress = state['step']/group['t_total']
142 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
143 | # warning for exceeding t_total (only active with warmup_linear
144 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
145 | logger.warning(
146 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
147 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
148 | warned_for_t_total = True
149 | # end warning
150 | else:
151 | lr_scheduled = group['lr']
152 |
153 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
154 |
155 | p.data.addcdiv_(-step_size, exp_avg, denom)
156 |
157 | # Add weight decay at the end (fixed version)
158 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
159 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data)
160 |
161 | return loss
162 |
--------------------------------------------------------------------------------
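
The three warmup schedules above are pure functions of training progress x in [0, 1], so they can be inspected directly; a small sketch:

    from pytorch_pretrained_bert.optimization_openai import (
        warmup_cosine, warmup_constant, warmup_linear)

    # Below the warmup fraction (default 0.002) all three ramp linearly from 0
    # to 1; afterwards they decay (cosine), hold at 1 (constant), or fall to 0
    # at the end of training (linear).
    for x in (0.001, 0.002, 0.5, 1.0):
        print(x, warmup_cosine(x), warmup_constant(x), warmup_linear(x))
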
/src/CV_TMLE.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import logit, expit
3 | from scipy.optimize import minimize
4 |
5 | import itertools
6 | import sklearn.linear_model as lm
7 | from sklearn.feature_extraction.text import CountVectorizer
8 |
9 | np.random.seed(0)
14 |
15 |
16 | class CVTMLE:
17 | def __init__(self, q_t0=None, q_t1=None, g=None, t=None, y=None, fromFolds=None, est_keys=None,
18 | truncate_level=0.05):
19 | """
20 | CV-TMLE as conceived by Levy (2018):
21 | Levy, Jonathan. "An easy implementation of CV-TMLE." arXiv preprint arXiv:1811.04573 (2018).
22 |
23 | :param q_t0: initial outcome estimate under the control exposure
24 | :param q_t1: initial outcome estimate under the treatment/non-control exposure
25 | :param g: predicted propensity score
26 | :param t: treatment label
27 | :param y: factual outcome
28 | :param fromFolds: list of per-fold npz files holding the estimates; if provided, the first five parameters are not needed
29 | :param est_keys: keys for extracting the estimates (i.e., the first five parameters of this init) once the npz files are read
30 | :param truncate_level: truncation level for propensity scores (the default 0.05 keeps only patients with propensity estimates between 0.05 and 0.95)
31 | """
32 |
33 |
34 | self.q_t0 = q_t0
35 | self.q_t1 = q_t1
36 | self.g = g
37 | self.t = t
38 | self.y = y
39 | self.est_keys = est_keys
40 | self.truncate_level = truncate_level
41 | if fromFolds is not None:
42 | self.q_t0, self.q_t1, self.y, self.g, self.t = self.collateFromFolds(fromFolds)
43 |
44 | def _perturbed_model_bin_outcome(self, q_t0, q_t1, g, t, eps):
45 | """
46 |         Helper for run_tmle_binary.
47 |
48 | Returns q_\eps (t,x) and the h term
49 |         (i.e., the value of the perturbed predictor at t, eps, x, where q_t0, q_t1, g are all evaluated at x)
50 | """
51 | h = t * (1. / g) - (1. - t) / (1. - g)
52 | full_lq = (1. - t) * logit(q_t0) + t * logit(q_t1) # logit predictions from unperturbed model
53 | logit_perturb = full_lq + eps * h
54 | return expit(logit_perturb), h
55 |
56 | def run_tmle_binary(self):
57 | """
58 |         CV-TMLE for binary outcomes, yielding a risk ratio with a 95% CI. See Levy (2018) for methodological details.
59 |         Influence curves follow Gruber S, van der Laan MJ (2011).
60 |
61 | """
62 |
63 | print('running CV-TMLE for binary outcomes...')
64 | q_t0, q_t1, g, t, y, truncatel = np.copy(self.q_t0), np.copy(self.q_t1), np.copy(self.g), np.copy(
65 | self.t), np.copy(self.y), np.copy(self.truncate_level)
66 | q_t0, q_t1, g, t, y = self.truncate_all_by_g(q_t0, q_t1, g, t, y, truncatel)
67 |
68 | eps_hat = minimize(
69 | lambda eps: self.cross_entropy(y, self._perturbed_model_bin_outcome(q_t0, q_t1, g, t, eps)[0]), 0.,
70 | method='Nelder-Mead')
71 | eps_hat = eps_hat.x[0]
72 |
73 | def q1(t_cf):
74 | return self._perturbed_model_bin_outcome(q_t0, q_t1, g, t_cf, eps_hat)
75 |
76 | qall = ((1. - t) * (q_t0)) + (t * (q_t1)) # full predictions from unperturbed model
77 |
78 | qq1, h1 = q1(np.ones_like(t))
79 | qq0, h0 = q1(np.zeros_like(t))
80 | rr = np.mean(qq1) / np.mean(qq0)
81 |
82 | ic = (1 / np.mean(qq1) * (h1 * (y - qall) + qq1 - np.mean(qq1)) -
83 | (1 / np.mean(qq0)) * (-1 * h0 * (y - qall) + qq0 - np.mean(qq0)))
84 | psi_tmle_std = 1.96 * np.sqrt(np.var(ic) / (t.shape[0]))
85 |
86 | return [rr, np.exp(np.log(rr) - psi_tmle_std), np.exp(np.log(rr) + psi_tmle_std)]
87 |
88 | def run_tmle_continuous(self):
89 | """
90 |         CV-TMLE for continuous outcomes, yielding the ATE/mean difference with a 95% CI. See Levy (2018) for methodological details.
91 |         Influence curves follow Gruber S, van der Laan MJ (2011).
92 |
93 | """
94 | print('running CV-TMLE for continuous outcomes...')
95 |
96 | q_t0, q_t1, g, t, y, truncatel = np.copy(self.q_t0), np.copy(self.q_t1), np.copy(self.g), np.copy(
97 | self.t), np.copy(self.y), np.copy(self.truncate_level)
98 | q_t0, q_t1, g, t, y = self.truncate_all_by_g(q_t0, q_t1, g, t, y, truncatel)
99 |
100 | h = t * (1.0 / g) - (1.0 - t) / (1.0 - g)
101 | full_q = (1.0 - t) * q_t0 + t * q_t1
102 | eps_hat = np.sum(h * (y - full_q)) / np.sum(np.square(h))
103 |
104 | def q1(t_cf):
105 | h_cf = t_cf * (1.0 / g) - (1.0 - t_cf) / (1.0 - g)
106 | full_q = ((1.0 - t_cf) * q_t0) + (t_cf * q_t1)
107 | return full_q + eps_hat * h_cf, h_cf
108 |
109 | qq1, h_cf1 = q1(np.ones_like(t))
110 | qq0, h_cf0 = q1(np.zeros_like(t))
111 | haw = h_cf0 + h_cf1
112 |
113 | rd = np.mean(qq1 - qq0)
114 | ic = (haw * (y - full_q)) + (qq1 - qq0) - rd
115 | psi_tmle_std = 1.96 * np.sqrt(np.var(ic) / (t.shape[0]))
116 |
117 | return [rd, rd - psi_tmle_std, rd + psi_tmle_std]
118 |
119 | def truncate_by_g(self, attribute, g, level=0.1):
120 | keep_these = np.logical_and(g >= level, g <= 1. - level)
121 | return attribute[keep_these]
122 |
123 | def truncate_all_by_g(self, q_t0, q_t1, g, t, y, truncate_level=0.05):
124 | """
125 | Helper function to clean up nuisance parameter estimates.
126 | """
127 | orig_g = np.copy(g)
128 | q_t0 = self.truncate_by_g(np.copy(q_t0), orig_g, truncate_level)
129 | q_t1 = self.truncate_by_g(np.copy(q_t1), orig_g, truncate_level)
130 | g = self.truncate_by_g(np.copy(g), orig_g, truncate_level)
131 | t = self.truncate_by_g(np.copy(t), orig_g, truncate_level)
132 | y = self.truncate_by_g(np.copy(y), orig_g, truncate_level)
133 | return q_t0, q_t1, g, t, y
134 |
135 | def cross_entropy(self, y, p):
136 | return -np.mean((y * np.log(p) + (1. - y) * np.log(1. - p)))
137 |
138 | def collateFromFolds(self, foldNPZ):
139 | """
140 | FYI: keys can be provided but default is below
141 | est_keys = {
142 | 'treatment_label_key' : 'treatment_label',
143 | 'outcome_key' : 'outcome',
144 | 'treatment_pred_key' : 'treatment',
145 | 'outcome_label_key' : 'outcome_label'}
146 |
147 | """
148 | if self.est_keys is None:
149 | self.est_keys = {
150 | 'treatment_label_key': 'treatment_label',
151 | 'outcome_key': 'outcome',
152 | 'treatment_pred_key': 'treatment',
153 | 'outcome_label_key': 'outcome_label'}
154 | t_all = []
155 | q1_all = []
156 | q0_all = []
157 | g_all = []
158 | y_all = []
159 | for fold in foldNPZ:
160 | ld = np.load(fold)
161 | t_all.append(ld[self.est_keys['treatment_label_key']])
162 | q0_all.append(ld[self.est_keys['outcome_key']][:, 0])
163 | q1_all.append(ld[self.est_keys['outcome_key']][:, 1])
164 | y_all.append(ld[self.est_keys['outcome_label_key']])
165 | g_all.append(ld[self.est_keys['treatment_pred_key']][:, 1])
166 | t_all = np.array(list(itertools.chain(*t_all))).flatten()
167 | g_all = np.array(list(itertools.chain(*g_all))).flatten()
168 | q0_all = np.array(list(itertools.chain(*q0_all))).flatten()
169 | q1_all = np.array(list(itertools.chain(*q1_all))).flatten()
170 | y_all = np.array(list(itertools.chain(*y_all))).flatten()
171 | return q0_all, q1_all, y_all, g_all, t_all
172 |
--------------------------------------------------------------------------------
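
A minimal usage sketch for the class above, assuming five per-fold .npz prediction files with the default keys documented in collateFromFolds (the paths are illustrative; any such list of files works):

    from src.CV_TMLE import CVTMLE

    folds = ['TBEHRT_Test__CUT{}.npz'.format(i) for i in range(5)]  # per-fold estimates

    tmle = CVTMLE(fromFolds=folds)       # collates q_t0, q_t1, y, g, t across folds
    rr, lo, hi = tmle.run_tmle_binary()  # risk ratio with 95% CI for a binary outcome
    print(rr, lo, hi)
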
/pytorch_pretrained_bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | def warmup_cosine(x, warmup=0.002):
27 | if x < warmup:
28 | return x/warmup
29 | x_ = (x - warmup) / (1 - warmup) # progress after warmup -
30 | return 0.5 * (1. + math.cos(math.pi * x_))
31 |
32 | def warmup_constant(x, warmup=0.002):
33 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
34 | Learning rate is 1. afterwards. """
35 | if x < warmup:
36 | return x/warmup
37 | return 1.0
38 |
39 | def warmup_linear(x, warmup=0.002):
40 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
41 | After `t_total`-th training step, learning rate is zero. """
42 | if x < warmup:
43 | return x/warmup
44 | return max((x-1.)/(warmup-1.), 0)
45 |
46 | SCHEDULES = {
47 | 'warmup_cosine': warmup_cosine,
48 | 'warmup_constant': warmup_constant,
49 | 'warmup_linear': warmup_linear,
50 | }
51 |
52 |
53 | class BertAdam(Optimizer):
54 | """Implements BERT version of Adam algorithm with weight decay fix.
55 | Params:
56 | lr: learning rate
57 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
58 | t_total: total number of training steps for the learning
59 | rate schedule, -1 means constant learning rate. Default: -1
60 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
61 | b1: Adams b1. Default: 0.9
62 | b2: Adams b2. Default: 0.999
63 | e: Adams epsilon. Default: 1e-6
64 | weight_decay: Weight decay. Default: 0.01
65 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
66 | """
67 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
68 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
69 | max_grad_norm=1.0):
70 | if lr is not required and lr < 0.0:
71 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
72 | if schedule not in SCHEDULES:
73 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
74 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
75 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
76 | if not 0.0 <= b1 < 1.0:
77 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
78 | if not 0.0 <= b2 < 1.0:
79 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
80 | if not e >= 0.0:
81 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
82 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
83 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
84 | max_grad_norm=max_grad_norm)
85 | super(BertAdam, self).__init__(params, defaults)
86 |
87 | def get_lr(self):
88 | lr = []
89 | for group in self.param_groups:
90 | for p in group['params']:
91 | state = self.state[p]
92 | if len(state) == 0:
93 | return [0]
94 | if group['t_total'] != -1:
95 | schedule_fct = SCHEDULES[group['schedule']]
96 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
97 | else:
98 | lr_scheduled = group['lr']
99 | lr.append(lr_scheduled)
100 | return lr
101 |
102 | def step(self, closure=None):
103 | """Performs a single optimization step.
104 |
105 | Arguments:
106 | closure (callable, optional): A closure that reevaluates the model
107 | and returns the loss.
108 | """
109 | loss = None
110 | if closure is not None:
111 | loss = closure()
112 |
113 | warned_for_t_total = False
114 |
115 | for group in self.param_groups:
116 | for p in group['params']:
117 | if p.grad is None:
118 | continue
119 | grad = p.grad.data
120 | if grad.is_sparse:
121 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
122 |
123 | state = self.state[p]
124 |
125 | # State initialization
126 | if len(state) == 0:
127 | state['step'] = 0
128 | # Exponential moving average of gradient values
129 | state['next_m'] = torch.zeros_like(p.data)
130 | # Exponential moving average of squared gradient values
131 | state['next_v'] = torch.zeros_like(p.data)
132 |
133 | next_m, next_v = state['next_m'], state['next_v']
134 | beta1, beta2 = group['b1'], group['b2']
135 |
136 | # Add grad clipping
137 | if group['max_grad_norm'] > 0:
138 | clip_grad_norm_(p, group['max_grad_norm'])
139 |
140 | # Decay the first and second moment running average coefficient
141 | # In-place operations to update the averages at the same time
142 | next_m.mul_(beta1).add_(1 - beta1, grad)
143 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
144 | update = next_m / (next_v.sqrt() + group['e'])
145 |
146 | # Just adding the square of the weights to the loss function is *not*
147 | # the correct way of using L2 regularization/weight decay with Adam,
148 | # since that will interact with the m and v parameters in strange ways.
149 | #
150 | # Instead we want to decay the weights in a manner that doesn't interact
151 | # with the m/v parameters. This is equivalent to adding the square
152 | # of the weights to the loss with plain (non-momentum) SGD.
153 | if group['weight_decay'] > 0.0:
154 | update += group['weight_decay'] * p.data
155 |
156 | if group['t_total'] != -1:
157 | schedule_fct = SCHEDULES[group['schedule']]
158 | progress = state['step']/group['t_total']
159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
160 |                     # warning for exceeding t_total (only active with warmup_linear)
161 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
162 | logger.warning(
163 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
164 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
165 | warned_for_t_total = True
166 | # end warning
167 | else:
168 | lr_scheduled = group['lr']
169 |
170 | update_with_lr = lr_scheduled * update
171 | p.data.add_(-update_with_lr)
172 |
173 | state['step'] += 1
174 |
175 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
176 | # No bias correction
177 | # bias_correction1 = 1 - beta1 ** state['step']
178 | # bias_correction2 = 1 - beta2 ** state['step']
179 |
180 | return loss
181 |
--------------------------------------------------------------------------------
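
A minimal sketch of constructing this optimizer with the linear warmup schedule; the model and step count are placeholders:

    import torch.nn as nn
    from pytorch_pretrained_bert.optimization import BertAdam

    model = nn.Linear(10, 2)        # placeholder model
    num_train_steps = 1000

    optimizer = BertAdam(model.parameters(),
                         lr=3e-5,
                         warmup=0.1,               # ramp 0 -> lr over the first 10% of steps
                         t_total=num_train_steps,  # after which warmup_linear decays lr to 0
                         schedule='warmup_linear')
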
/pytorch_pretrained_bert/tokenization_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import json
20 | import logging
21 | import os
22 | import regex as re
23 | from io import open
24 |
25 | try:
26 | from functools import lru_cache
27 | except ImportError:
28 | # Just a dummy decorator to get the checks to run on python2
29 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
30 | def lru_cache():
31 | return lambda func: func
32 |
33 | from .file_utils import cached_path
34 |
35 | logger = logging.getLogger(__name__)
36 |
37 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
38 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
39 | }
40 | PRETRAINED_MERGES_ARCHIVE_MAP = {
41 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
42 | }
43 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
44 | 'gpt2': 1024,
45 | }
46 | VOCAB_NAME = 'vocab.json'
47 | MERGES_NAME = 'merges.txt'
48 |
49 | @lru_cache()
50 | def bytes_to_unicode():
51 | """
52 |     Returns a dict mapping utf-8 bytes to corresponding unicode strings.
53 | The reversible bpe codes work on unicode strings.
54 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
55 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
56 |     This is a significant percentage of your normal, say, 32K bpe vocab.
57 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
58 |     These tables also avoid mapping to whitespace/control characters that the bpe code barfs on.
59 | """
60 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
61 | cs = bs[:]
62 | n = 0
63 | for b in range(2**8):
64 | if b not in bs:
65 | bs.append(b)
66 | cs.append(2**8+n)
67 | n += 1
68 | cs = [chr(n) for n in cs]
69 | return dict(zip(bs, cs))
70 |
71 | def get_pairs(word):
72 | """Return set of symbol pairs in a word.
73 |
74 | Word is represented as tuple of symbols (symbols being variable-length strings).
75 | """
76 | pairs = set()
77 | prev_char = word[0]
78 | for char in word[1:]:
79 | pairs.add((prev_char, char))
80 | prev_char = char
81 | return pairs
82 |
83 | class GPT2Tokenizer(object):
84 | """
85 | GPT-2 BPE tokenizer. Peculiarities:
86 | - Byte-level BPE
87 | """
88 | @classmethod
89 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
90 | """
91 | Instantiate a PreTrainedBertModel from a pre-trained model file.
92 | Download and cache the pre-trained model file if needed.
93 | """
94 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
95 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
96 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
97 | else:
98 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
99 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
100 | # redirect to the cache, if necessary
101 | try:
102 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
103 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
104 | except EnvironmentError:
105 | logger.error(
106 | "Model name '{}' was not found in model name list ({}). "
107 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
108 | "at this path or url.".format(
109 | pretrained_model_name_or_path,
110 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
111 | pretrained_model_name_or_path,
112 | vocab_file, merges_file))
113 | return None
114 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
115 | logger.info("loading vocabulary file {}".format(vocab_file))
116 | logger.info("loading merges file {}".format(merges_file))
117 | else:
118 | logger.info("loading vocabulary file {} from cache at {}".format(
119 | vocab_file, resolved_vocab_file))
120 | logger.info("loading merges file {} from cache at {}".format(
121 | merges_file, resolved_merges_file))
122 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
123 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
124 | # than the number of positional embeddings
125 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
126 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
127 | # Instantiate tokenizer.
128 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
129 | return tokenizer
130 |
131 | def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
132 | self.max_len = max_len if max_len is not None else int(1e12)
133 | self.encoder = json.load(open(vocab_file))
134 | self.decoder = {v:k for k,v in self.encoder.items()}
135 | self.errors = errors # how to handle errors in decoding
136 | self.byte_encoder = bytes_to_unicode()
137 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
138 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
139 | bpe_merges = [tuple(merge.split()) for merge in bpe_data]
140 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
141 | self.cache = {}
142 |
143 |         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
144 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
145 |
146 | def __len__(self):
147 | return len(self.encoder)
148 |
149 | def bpe(self, token):
150 | if token in self.cache:
151 | return self.cache[token]
152 | word = tuple(token)
153 | pairs = get_pairs(word)
154 |
155 | if not pairs:
156 | return token
157 |
158 | while True:
159 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
160 | if bigram not in self.bpe_ranks:
161 | break
162 | first, second = bigram
163 | new_word = []
164 | i = 0
165 | while i < len(word):
166 | try:
167 | j = word.index(first, i)
168 | new_word.extend(word[i:j])
169 | i = j
170 |                 except ValueError:  # first not found in the remainder of word
171 | new_word.extend(word[i:])
172 | break
173 |
174 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
175 | new_word.append(first+second)
176 | i += 2
177 | else:
178 | new_word.append(word[i])
179 | i += 1
180 | new_word = tuple(new_word)
181 | word = new_word
182 | if len(word) == 1:
183 | break
184 | else:
185 | pairs = get_pairs(word)
186 | word = ' '.join(word)
187 | self.cache[token] = word
188 | return word
189 |
190 | def encode(self, text):
191 | bpe_tokens = []
192 | for token in re.findall(self.pat, text):
193 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
194 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
195 | if len(bpe_tokens) > self.max_len:
196 | logger.warning(
197 | "Token indices sequence length is longer than the specified maximum "
198 | " sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
199 | " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
200 | )
201 | return bpe_tokens
202 |
203 | def decode(self, tokens):
204 | text = ''.join([self.decoder[token] for token in tokens])
205 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
206 | return text
207 |
--------------------------------------------------------------------------------
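
A round-trip sketch for the tokenizer above, assuming the pretrained 'gpt2' vocab and merges files are reachable (they are downloaded and cached on first use):

    from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')  # fetches vocab.json / merges.txt
    ids = tokenizer.encode("Hello world")              # byte-level BPE ids
    assert tokenizer.decode(ids) == "Hello world"      # byte-level BPE is lossless
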
/pytorch_pretrained_bert/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import tempfile
13 | from functools import wraps
14 | from hashlib import sha256
15 | import sys
16 | from io import open
17 |
18 | # import boto3
19 | # import requests
20 | # from botocore.exceptions import ClientError  # NB: the s3/http download paths below require re-enabling these three imports
21 | from tqdm import tqdm
22 |
23 | try:
24 | from urllib.parse import urlparse
25 | except ImportError:
26 | from urlparse import urlparse
27 |
28 | try:
29 | from pathlib import Path
30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
31 | Path.home() / '.pytorch_pretrained_bert'))
32 | except (AttributeError, ImportError):
33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
35 |
36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37 |
38 |
39 | def url_to_filename(url, etag=None):
40 | """
41 | Convert `url` into a hashed filename in a repeatable way.
42 | If `etag` is specified, append its hash to the url's, delimited
43 | by a period.
44 | """
45 | url_bytes = url.encode('utf-8')
46 | url_hash = sha256(url_bytes)
47 | filename = url_hash.hexdigest()
48 |
49 | if etag:
50 | etag_bytes = etag.encode('utf-8')
51 | etag_hash = sha256(etag_bytes)
52 | filename += '.' + etag_hash.hexdigest()
53 |
54 | return filename
55 |
56 |
57 | def filename_to_url(filename, cache_dir=None):
58 | """
59 | Return the url and etag (which may be ``None``) stored for `filename`.
60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61 | """
62 | if cache_dir is None:
63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65 | cache_dir = str(cache_dir)
66 |
67 | cache_path = os.path.join(cache_dir, filename)
68 | if not os.path.exists(cache_path):
69 | raise EnvironmentError("file {} not found".format(cache_path))
70 |
71 | meta_path = cache_path + '.json'
72 | if not os.path.exists(meta_path):
73 | raise EnvironmentError("file {} not found".format(meta_path))
74 |
75 | with open(meta_path, encoding="utf-8") as meta_file:
76 | metadata = json.load(meta_file)
77 | url = metadata['url']
78 | etag = metadata['etag']
79 |
80 | return url, etag
81 |
82 |
83 | def cached_path(url_or_filename, cache_dir=None):
84 | """
85 | Given something that might be a URL (or might be a local path),
86 | determine which. If it's a URL, download the file and cache it, and
87 | return the path to the cached file. If it's already a local path,
88 | make sure the file exists and then return the path.
89 | """
90 | if cache_dir is None:
91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93 | url_or_filename = str(url_or_filename)
94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95 | cache_dir = str(cache_dir)
96 |
97 | parsed = urlparse(url_or_filename)
98 |
99 | if parsed.scheme in ('http', 'https', 's3'):
100 | # URL, so get it from the cache (downloading if necessary)
101 | return get_from_cache(url_or_filename, cache_dir)
102 | elif os.path.exists(url_or_filename):
103 | # File, and it exists.
104 | return url_or_filename
105 | elif parsed.scheme == '':
106 | # File, but it doesn't exist.
107 | raise EnvironmentError("file {} not found".format(url_or_filename))
108 | else:
109 | # Something unknown
110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
111 |
112 |
113 | def split_s3_path(url):
114 | """Split a full s3 path into the bucket name and path."""
115 | parsed = urlparse(url)
116 | if not parsed.netloc or not parsed.path:
117 | raise ValueError("bad s3 path {}".format(url))
118 | bucket_name = parsed.netloc
119 | s3_path = parsed.path
120 | # Remove '/' at beginning of path.
121 | if s3_path.startswith("/"):
122 | s3_path = s3_path[1:]
123 | return bucket_name, s3_path
124 |
125 |
126 | def s3_request(func):
127 | """
128 | Wrapper function for s3 requests in order to create more helpful error
129 | messages.
130 | """
131 |
132 | @wraps(func)
133 | def wrapper(url, *args, **kwargs):
134 | try:
135 | return func(url, *args, **kwargs)
136 | except ClientError as exc:
137 | if int(exc.response["Error"]["Code"]) == 404:
138 | raise EnvironmentError("file {} not found".format(url))
139 | else:
140 | raise
141 |
142 | return wrapper
143 |
144 |
145 | @s3_request
146 | def s3_etag(url):
147 | """Check ETag on S3 object."""
148 | s3_resource = boto3.resource("s3")
149 | bucket_name, s3_path = split_s3_path(url)
150 | s3_object = s3_resource.Object(bucket_name, s3_path)
151 | return s3_object.e_tag
152 |
153 |
154 | @s3_request
155 | def s3_get(url, temp_file):
156 | """Pull a file directly from S3."""
157 | s3_resource = boto3.resource("s3")
158 | bucket_name, s3_path = split_s3_path(url)
159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
160 |
161 |
162 | def http_get(url, temp_file):
163 | req = requests.get(url, stream=True)
164 | content_length = req.headers.get('Content-Length')
165 | total = int(content_length) if content_length is not None else None
166 | progress = tqdm(unit="B", total=total)
167 | for chunk in req.iter_content(chunk_size=1024):
168 | if chunk: # filter out keep-alive new chunks
169 | progress.update(len(chunk))
170 | temp_file.write(chunk)
171 | progress.close()
172 |
173 |
174 | def get_from_cache(url, cache_dir=None):
175 | """
176 | Given a URL, look for the corresponding dataset in the local cache.
177 | If it's not there, download it. Then return the path to the cached file.
178 | """
179 | if cache_dir is None:
180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
182 | cache_dir = str(cache_dir)
183 |
184 | if not os.path.exists(cache_dir):
185 | os.makedirs(cache_dir)
186 |
187 | # Get eTag to add to filename, if it exists.
188 | if url.startswith("s3://"):
189 | etag = s3_etag(url)
190 | else:
191 | response = requests.head(url, allow_redirects=True)
192 | if response.status_code != 200:
193 | raise IOError("HEAD request failed for url {} with status code {}"
194 | .format(url, response.status_code))
195 | etag = response.headers.get("ETag")
196 |
197 | filename = url_to_filename(url, etag)
198 |
199 | # get cache path to put the file
200 | cache_path = os.path.join(cache_dir, filename)
201 |
202 | if not os.path.exists(cache_path):
203 | # Download to temporary file, then copy to cache dir once finished.
204 | # Otherwise you get corrupt cache entries if the download gets interrupted.
205 | with tempfile.NamedTemporaryFile() as temp_file:
206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207 |
208 | # GET file object
209 | if url.startswith("s3://"):
210 | s3_get(url, temp_file)
211 | else:
212 | http_get(url, temp_file)
213 |
214 | # we are copying the file before closing it, so flush to avoid truncation
215 | temp_file.flush()
216 | # shutil.copyfileobj() starts at the current position, so go to the start
217 | temp_file.seek(0)
218 |
219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220 | with open(cache_path, 'wb') as cache_file:
221 | shutil.copyfileobj(temp_file, cache_file)
222 |
223 | logger.info("creating metadata file for %s", cache_path)
224 | meta = {'url': url, 'etag': etag}
225 | meta_path = cache_path + '.json'
226 | with open(meta_path, 'w', encoding="utf-8") as meta_file:
227 | json.dump(meta, meta_file)
228 |
229 | logger.info("removing temp file %s", temp_file.name)
230 |
231 | return cache_path
232 |
233 |
234 | def read_set_from_file(filename):
235 | '''
236 | Extract a de-duped collection (set) of text from a file.
237 | Expected file format is one item per line.
238 | '''
239 | collection = set()
240 | with open(filename, 'r', encoding='utf-8') as file_:
241 | for line in file_:
242 | collection.add(line.rstrip())
243 | return collection
244 |
245 |
246 | def get_file_extension(path, dot=True, lower=True):
247 | ext = os.path.splitext(path)[1]
248 | ext = ext if dot else ext[1:]
249 | return ext.lower() if lower else ext
250 |
--------------------------------------------------------------------------------
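
A small sketch of the caching helpers above, run from the repo root (the URL is illustrative; the actual http/s3 download branches in this copy require re-enabling the commented-out requests/boto3 imports):

    from pytorch_pretrained_bert.file_utils import cached_path, url_to_filename

    # an existing local path is returned unchanged; a missing one raises EnvironmentError
    path = cached_path('examples/test.parquet')

    # URLs hash to stable cache filenames; the ETag hash, if any, is appended
    print(url_to_filename('https://example.com/model.bin', etag='abc123'))
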
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import _pickle as pickle
3 | import random
4 | import torch.nn as nn
5 | import torch
6 | import os
7 | import sklearn
8 | import sklearn.metrics as skm
9 | import warnings
10 |
11 |
12 | def nonMASK(tokens, token2idx):
13 | output_label = []
14 | output_token = []
15 | for i, token in enumerate(tokens):
16 | prob = random.random()
17 |         # masking disabled in this variant: the threshold is 0 (not 0.15), so the branch below never fires
18 | if prob < 0:
19 | prob /= 0.15
20 |
21 | # 80% randomly change token to mask token
22 | if prob < 0.8:
23 | output_token.append(token2idx["MASK"])
24 |
25 | # 10% randomly change token to random token
26 | elif prob < 0.9:
27 | output_token.append(random.choice(list(token2idx.values())))
28 |
29 | # -> rest 10% randomly keep current token
30 |
31 |                 # append current token to output (we will predict these later)
32 | output_label.append(token2idx.get(token, token2idx['UNK']))
33 | else:
34 | # no masking token (will be ignored by loss function later)
35 | output_label.append(-1)
36 | output_token.append(token2idx.get(token, token2idx['UNK']))
37 |
38 | return tokens, output_token, output_label
39 |
40 |
41 | # static var masking
42 | def covarUnsupMaker(covar, covarprobb=0.4):
43 | inputcovar = []
44 | labelcovar = []
45 | for i,x in enumerate(covar):
46 | prob = random.random()
47 | if x != 0:
48 |             if prob < covarprobb:
49 |                 # mask this covariate: 0 is used as the mask value here (an assumption,
50 |                 # mirroring the x != 0 guard above) and the true value becomes the label
51 |                 inputcovar.append(0)
52 |                 labelcovar.append(x)
53 |             else:
54 |                 inputcovar.append(x)
55 |                 labelcovar.append(-1)
56 |         else:
57 |             inputcovar.append(x)
58 |             labelcovar.append(-1)
59 |     return inputcovar, labelcovar
60 |
61 |
62 | # standard MLM masking (active counterpart of nonMASK above)
63 | def random_mask(tokens, token2idx):
64 |     output_label = []
65 |     output_token = []
66 |     for i, token in enumerate(tokens):
67 |         prob = random.random()
68 |         # mask token with 15% probability
69 |         if prob < 0.15:
70 |             prob /= 0.15
71 |
72 |             # 80% randomly change token to mask token
73 |             if prob < 0.8:
74 |                 output_token.append(token2idx["MASK"])
75 |
76 |                 # append current token to output (we will predict these later)
77 |                 output_label.append(token2idx.get(token, token2idx['UNK']))
78 |
79 |             # 10% randomly change token to random token
80 |             elif prob < 0.9:
81 |                 output_token.append(random.choice(list(token2idx.values())))
82 |
83 |                 # append current token to output (we will predict these later)
84 |                 output_label.append(token2idx.get(token, token2idx['UNK']))
85 |             # -> rest 10% randomly keep current token
86 | else:
87 | output_label.append(-1)
88 |
89 |                 # keep the current token in the output (label -1, so it is not predicted)
90 | output_token.append(token2idx.get(token, token2idx['UNK']))
91 |
92 |
93 |
94 | else:
95 | # no masking token (will be ignored by loss function later)
96 | output_label.append(-1)
97 | output_token.append(token2idx.get(token, token2idx['UNK']))
98 |
99 | return tokens, output_token, output_label
100 |
101 |
102 |
103 | def save_obj(obj, name):
104 | with open(name + '.pkl', 'wb') as f:
105 | pickle.dump(obj, f)
106 |
107 |
108 | def load_obj(name):
109 | with open(name + '.pkl', 'rb') as f:
110 | return pickle.load(f)
111 |
112 |
113 | def code2index(tokens, token2idx):
114 | output_tokens = []
115 | for i, token in enumerate(tokens):
116 | output_tokens.append(token2idx.get(token, token2idx['UNK']))
117 | return tokens, output_tokens
118 |
119 |
120 |
121 |
122 | def index_seg(tokens, symbol='SEP'):
123 | flag = 0
124 | seg = []
125 |
126 | for token in tokens:
127 | if token == symbol:
128 | seg.append(flag)
129 | if flag == 0:
130 | flag = 1
131 | else:
132 | flag = 0
133 | else:
134 | seg.append(flag)
135 | return seg
136 |
137 |
138 | def position_idx(tokens, symbol='SEP'):
139 | pos = []
140 | flag = 0
141 |
142 | for token in tokens:
143 | if token == symbol:
144 | pos.append(flag)
145 | flag += 1
146 | else:
147 | pos.append(flag)
148 | return pos
149 |
150 |
151 | def age_vocab(max_age, year=False, symbol=None):
152 | age2idx = {}
153 | idx2age = {}
154 | if symbol is None:
155 | symbol = ['PAD', 'UNK']
156 |
157 | for i in range(len(symbol)):
158 | age2idx[str(symbol[i])] = i
159 | idx2age[i] = str(symbol[i])
160 |
161 | if year:
162 | for i in range(max_age):
163 | age2idx[str(i)] = len(symbol) + i
164 | idx2age[len(symbol) + i] = str(i)
165 | else:
166 | for i in range(max_age * 12):
167 | age2idx[str(i)] = len(symbol) + i
168 | idx2age[len(symbol) + i] = str(i)
169 |
170 | return age2idx, idx2age
171 |
172 |
173 | def seq_padding(tokens, max_len, token2idx=None, symbol=None):
174 | if symbol is None:
175 | symbol = 'PAD'
176 |
177 | seq = []
178 | token_len = len(tokens)
179 | for i in range(max_len):
180 | if token2idx is None:
181 | if i < token_len:
182 | seq.append(tokens[i])
183 | else:
184 | seq.append(symbol)
185 | else:
186 | if i < token_len:
187 |                 # index 1 indicates UNK
188 | seq.append(token2idx.get(tokens[i], token2idx['UNK']))
189 | else:
190 | seq.append(token2idx.get(symbol))
191 | return seq
192 |
193 |
194 | def seq_padding_reverse(tokens, max_len, token2idx=None, symbol=None):
195 | if symbol is None:
196 | symbol = 'PAD'
197 |
198 | seq = []
199 | token_len = len(tokens)
200 | tokens = tokens[::-1]
201 | for i in range(max_len):
202 | if token2idx is None:
203 | if i < token_len:
204 | seq.append(tokens[i])
205 | else:
206 | seq.append(symbol)
207 | else:
208 | if i < token_len:
209 |                 # index 1 indicates UNK
210 | seq.append(token2idx.get(tokens[i], token2idx['UNK']))
211 | else:
212 | seq.append(token2idx.get(symbol))
213 | return seq[::-1]
214 |
215 |
216 | def age_seq_padding(tokens, max_len, token2idx=None, symbol=None):
217 | if symbol is None:
218 | symbol = 'PAD'
219 |
220 | seq = []
221 | token_len = len(tokens)
222 | for i in range(max_len):
223 | if token2idx is None:
224 | if i < token_len:
225 | seq.append(tokens[i])
226 | else:
227 | seq.append(symbol)
228 | else:
229 | if i < token_len:
230 |                 # index 1 indicates UNK
231 | seq.append(token2idx[tokens[i]])
232 | else:
233 | seq.append(token2idx[symbol])
234 | return seq
235 |
236 |
237 |
238 | def cal_acc(label, pred, logS=True):
239 |     logs = nn.LogSoftmax(dim=-1)
240 | label = label.cpu().numpy()
241 | ind = np.where(label != -1)[0]
242 | truepred = pred.detach().cpu().numpy()
243 | truepred = truepred[ind]
244 | truelabel = label[ind]
245 | if logS == True:
246 | truepred = logs(torch.tensor(truepred))
247 | else:
248 | truepred = torch.tensor(truepred)
249 | outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]
250 | precision = skm.precision_score(truelabel, outs, average='micro')
251 |
252 | return precision
253 |
269 |
270 | def partition(values, indices):
271 | idx = 0
272 | for index in indices:
273 | sublist = []
274 | idxfill = []
275 | while idx < len(values) and values[idx] <= index:
276 | # sublist.append(values[idx])
277 | idxfill.append(idx)
278 |
279 | idx += 1
280 | if idxfill:
281 | yield idxfill
282 |
283 |
284 | def toLoad(model, filepath, custom=None):
285 | pre_bert = filepath
286 |
287 | pretrained_dict = torch.load(pre_bert, map_location='cpu')
288 | modeld = model.state_dict()
289 | # 1. filter out unnecessary keys
290 | if custom == None:
291 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in modeld}
292 | else:
293 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in modeld and k not in custom}
294 |
295 | modeld.update(pretrained_dict)
296 | # 3. load the new state dict
297 | model.load_state_dict(modeld)
298 | return model
299 |
300 |
301 |
302 |
303 | def OutcomePrecision(logits, label, sig=True):
304 |     sigm = nn.Sigmoid()
305 |     if sig == True:
306 |         output = sigm(logits)
307 | else:
308 | output = logits
309 | label, output = label.cpu(), output.detach().cpu()
310 | tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy())
311 | return tempprc, output, label
312 |
313 |
314 | def set_requires_grad(model, requires_grad=True):
315 | for param in model.parameters():
316 | param.requires_grad = requires_grad
317 |
318 |
319 | def precision_test(logits, label, sig=True):
320 | sigm = nn.Sigmoid()
321 | if sig == True:
322 | output = sigm(logits)
323 | else:
324 | output = logits
325 | label, output = label.cpu(), output.detach().cpu()
326 |
327 | tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy())
328 | return tempprc, output, label
329 |
330 |
331 | def roc_auc(logits, label, sig=True):
332 | sigm = nn.Sigmoid()
333 | if sig == True:
334 | output = sigm(logits)
335 | else:
336 | output = logits
337 | label, output = label.cpu(), output.detach().cpu()
338 |
339 | tempprc = sklearn.metrics.roc_auc_score(label.numpy(), output.numpy())
340 | return tempprc, output, label
341 |
342 |
343 | # global function
344 | def create_folder(path):
345 | if not os.path.exists(path):
346 | os.mkdir(path)
347 |
348 |
--------------------------------------------------------------------------------
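
A small sketch of how the SEP-delimited visit structure maps to segment and position indices via the helpers above, assuming the repo root is on sys.path (the token strings are illustrative):

    from src.utils import index_seg, position_idx

    tokens = ['CLS', 'diag1', 'diag2', 'SEP', 'diag3', 'SEP']
    print(index_seg(tokens))     # [0, 0, 0, 0, 1, 1]  alternating visit segments
    print(position_idx(tokens))  # [0, 0, 0, 0, 1, 1]  visit-level position ids
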
/pytorch_pretrained_bert/tokenization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 | unicode_literals)
18 |
19 | import json
20 | import logging
21 | import os
22 | import re
23 | import sys
24 | from io import open
25 |
26 | from tqdm import tqdm
27 |
28 | from .file_utils import cached_path
29 | from .tokenization import BasicTokenizer
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
35 | }
36 | PRETRAINED_MERGES_ARCHIVE_MAP = {
37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
38 | }
39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
40 | 'openai-gpt': 512,
41 | }
42 | VOCAB_NAME = 'vocab.json'
43 | MERGES_NAME = 'merges.txt'
44 |
45 | def get_pairs(word):
46 | """
47 | Return set of symbol pairs in a word.
48 | word is represented as tuple of symbols (symbols being variable-length strings)
49 | """
50 | pairs = set()
51 | prev_char = word[0]
52 | for char in word[1:]:
53 | pairs.add((prev_char, char))
54 | prev_char = char
55 | return pairs
56 |
57 | def text_standardize(text):
58 | """
59 | fixes some issues the spacy tokenizer had on books corpus
60 | also does some whitespace standardization
61 | """
62 | text = text.replace('—', '-')
63 | text = text.replace('–', '-')
64 | text = text.replace('―', '-')
65 | text = text.replace('…', '...')
66 | text = text.replace('´', "'")
67 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
68 | text = re.sub(r'\s*\n\s*', ' \n ', text)
69 | text = re.sub(r'[^\S\n]+', ' ', text)
70 | return text.strip()
71 |
72 | class OpenAIGPTTokenizer(object):
73 | """
74 | BPE tokenizer. Peculiarities:
75 | - lower case all inputs
76 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
77 | - argument special_tokens and function set_special_tokens:
78 | can be used to add additional symbols (ex: "__classify__") to a vocabulary.
79 | """
80 | @classmethod
81 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
82 | """
83 | Instantiate a PreTrainedBertModel from a pre-trained model file.
84 | Download and cache the pre-trained model file if needed.
85 | """
86 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
87 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
88 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
89 | else:
90 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
91 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
92 | # redirect to the cache, if necessary
93 | try:
94 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
95 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
96 | except EnvironmentError:
97 | logger.error(
98 | "Model name '{}' was not found in model name list ({}). "
99 | "We assumed '{}' was a path or url but couldn't find files {} and {} "
100 | "at this path or url.".format(
101 | pretrained_model_name_or_path,
102 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
103 | pretrained_model_name_or_path,
104 | vocab_file, merges_file))
105 | return None
106 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
107 | logger.info("loading vocabulary file {}".format(vocab_file))
108 | logger.info("loading merges file {}".format(merges_file))
109 | else:
110 | logger.info("loading vocabulary file {} from cache at {}".format(
111 | vocab_file, resolved_vocab_file))
112 | logger.info("loading merges file {} from cache at {}".format(
113 | merges_file, resolved_merges_file))
114 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
115 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
116 | # than the number of positional embeddings
117 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
118 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
119 | # Instantiate tokenizer.
120 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
121 | return tokenizer
122 |
123 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
124 | try:
125 | import ftfy
126 | import spacy
127 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
128 | self.fix_text = ftfy.fix_text
129 | except ImportError:
130 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
131 | self.nlp = BasicTokenizer(do_lower_case=True,
132 | never_split=special_tokens if special_tokens is not None else [])
133 | self.fix_text = None
134 |
135 | self.max_len = max_len if max_len is not None else int(1e12)
136 | self.encoder = json.load(open(vocab_file, encoding="utf-8"))
137 | self.decoder = {v:k for k,v in self.encoder.items()}
138 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
139 | merges = [tuple(merge.split()) for merge in merges]
140 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
141 | self.cache = {}
142 | self.set_special_tokens(special_tokens)
143 |
144 | def __len__(self):
145 | return len(self.encoder) + len(self.special_tokens)
146 |
147 | def set_special_tokens(self, special_tokens):
148 | """ Add a list of additional tokens to the encoder.
149 | The additional tokens are indexed starting from the last index of the
150 | current vocabulary in the order of the `special_tokens` list.
151 | """
152 | if not special_tokens:
153 | self.special_tokens = {}
154 | self.special_tokens_decoder = {}
155 | return
156 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
157 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
158 | if self.fix_text is None:
159 | # Using BERT's BasicTokenizer: we can update the tokenizer
160 | self.nlp.never_split = special_tokens
161 | logger.info("Special tokens {}".format(self.special_tokens))
162 |
163 | def bpe(self, token):
164 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
165 | if token in self.cache:
166 | return self.cache[token]
167 | pairs = get_pairs(word)
168 |
169 | if not pairs:
170 |             return token+'</w>'
171 |
172 | while True:
173 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
174 | if bigram not in self.bpe_ranks:
175 | break
176 | first, second = bigram
177 | new_word = []
178 | i = 0
179 | while i < len(word):
180 | try:
181 | j = word.index(first, i)
182 | new_word.extend(word[i:j])
183 | i = j
184 |                 except ValueError:  # first not found in the remainder of word
185 | new_word.extend(word[i:])
186 | break
187 |
188 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
189 | new_word.append(first+second)
190 | i += 2
191 | else:
192 | new_word.append(word[i])
193 | i += 1
194 | new_word = tuple(new_word)
195 | word = new_word
196 | if len(word) == 1:
197 | break
198 | else:
199 | pairs = get_pairs(word)
200 | word = ' '.join(word)
201 |         if word == '\n  </w>':
202 |             word = '\n</w>'
203 | self.cache[token] = word
204 | return word
205 |
206 | def tokenize(self, text):
207 | """ Tokenize a string. """
208 | split_tokens = []
209 | if self.fix_text is None:
210 | # Using BERT's BasicTokenizer
211 | text = self.nlp.tokenize(text)
212 | for token in text:
213 | split_tokens.extend([t for t in self.bpe(token).split(' ')])
214 | else:
215 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
216 | text = self.nlp(text_standardize(self.fix_text(text)))
217 | for token in text:
218 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
219 | return split_tokens
220 |
221 | def convert_tokens_to_ids(self, tokens):
222 | """ Converts a sequence of tokens into ids using the vocab. """
223 | ids = []
224 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
225 | if tokens in self.special_tokens:
226 | return self.special_tokens[tokens]
227 | else:
228 | return self.encoder.get(tokens, 0)
229 | for token in tokens:
230 | if token in self.special_tokens:
231 | ids.append(self.special_tokens[token])
232 | else:
233 | ids.append(self.encoder.get(token, 0))
234 | if len(ids) > self.max_len:
235 | logger.warning(
236 | "Token indices sequence length is longer than the specified maximum "
237 | " sequence length for this OpenAI GPT model ({} > {}). Running this"
238 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
239 | )
240 | return ids
241 |
242 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
243 | """Converts a sequence of ids in BPE tokens using the vocab."""
244 | tokens = []
245 | for i in ids:
246 | if i in self.special_tokens_decoder:
247 | if not skip_special_tokens:
248 | tokens.append(self.special_tokens_decoder[i])
249 | else:
250 | tokens.append(self.decoder[i])
251 | return tokens
252 |
253 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
254 | """Converts a sequence of ids in a string."""
255 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
256 |         out_string = ''.join(tokens).replace('</w>', ' ').strip()
257 | if clean_up_tokenization_spaces:
258 |             out_string = out_string.replace('<unk>', '')
259 |             out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
260 |                 ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't"
261 |                 ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
262 |                 ).replace(" 've", "'ve")
263 | return out_string
264 |
--------------------------------------------------------------------------------
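
A quick sketch of the pre-BPE standardization step above (pure string processing, runnable without the vocab files):

    from pytorch_pretrained_bert.tokenization_openai import text_standardize

    print(text_standardize("It was—so to speak—done…"))
    # 'It was - so to speak - done...'
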
/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Utilities for PyTorch Transformer XL model.
17 | Directly adapted from https://github.com/kimiyoung/transformer-xl.
18 | """
19 |
20 | from collections import defaultdict
21 |
22 | import numpy as np
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1])
30 |
31 | class ProjectedAdaptiveLogSoftmax(nn.Module):
32 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
33 | keep_order=False):
34 | super(ProjectedAdaptiveLogSoftmax, self).__init__()
35 |
36 | self.n_token = n_token
37 | self.d_embed = d_embed
38 | self.d_proj = d_proj
39 |
40 | self.cutoffs = cutoffs + [n_token]
41 | self.cutoff_ends = [0] + self.cutoffs
42 | self.div_val = div_val
43 |
44 | self.shortlist_size = self.cutoffs[0]
45 | self.n_clusters = len(self.cutoffs) - 1
46 | self.head_size = self.shortlist_size + self.n_clusters
47 |
48 | if self.n_clusters > 0:
49 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
50 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
51 |
52 | self.out_layers = nn.ModuleList()
53 | self.out_projs = nn.ParameterList()
54 |
55 | if div_val == 1:
56 | for i in range(len(self.cutoffs)):
57 | if d_proj != d_embed:
58 | self.out_projs.append(
59 | nn.Parameter(torch.Tensor(d_proj, d_embed))
60 | )
61 | else:
62 | self.out_projs.append(None)
63 |
64 | self.out_layers.append(nn.Linear(d_embed, n_token))
65 | else:
66 | for i in range(len(self.cutoffs)):
67 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
68 | d_emb_i = d_embed // (div_val ** i)
69 |
70 | self.out_projs.append(
71 | nn.Parameter(torch.Tensor(d_proj, d_emb_i))
72 | )
73 |
74 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
75 |
76 | self.keep_order = keep_order
77 |
78 | def _compute_logit(self, hidden, weight, bias, proj):
79 | if proj is None:
80 | logit = F.linear(hidden, weight, bias=bias)
81 | else:
82 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
83 | proj_hid = F.linear(hidden, proj.t().contiguous())
84 | logit = F.linear(proj_hid, weight, bias=bias)
85 | # else:
86 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
87 | # if bias is not None:
88 | # logit = logit + bias
89 |
90 | return logit
91 |
92 | def forward(self, hidden, target=None, keep_order=False):
93 | '''
94 | Params:
95 | hidden :: [len*bsz x d_proj]
96 | target :: [len*bsz]
97 | Return:
98 |             if target is None:
99 |                 out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
100 |             else:
101 |                 out :: [len*bsz] Negative log likelihood
102 |             We could replace this implementation with the native PyTorch one
103 |             if theirs had an option to set bias on all clusters.
104 | here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
105 | '''
106 |
107 | if target is not None:
108 | target = target.view(-1)
109 | if hidden.size(0) != target.size(0):
110 | raise RuntimeError('Input and target should have the same size '
111 | 'in the batch dimension.')
112 |
113 | if self.n_clusters == 0:
114 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
115 | self.out_layers[0].bias, self.out_projs[0])
116 | if target is not None:
117 | output = -F.log_softmax(logit, dim=-1) \
118 | .gather(1, target.unsqueeze(1)).squeeze(1)
119 | else:
120 | output = F.log_softmax(logit, dim=-1)
121 | else:
122 | # construct weights and biases
123 | weights, biases = [], []
124 | for i in range(len(self.cutoffs)):
125 | if self.div_val == 1:
126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
127 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
128 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
129 | else:
130 | weight_i = self.out_layers[i].weight
131 | bias_i = self.out_layers[i].bias
132 |
133 | if i == 0:
134 | weight_i = torch.cat(
135 | [weight_i, self.cluster_weight], dim=0)
136 | bias_i = torch.cat(
137 | [bias_i, self.cluster_bias], dim=0)
138 |
139 | weights.append(weight_i)
140 | biases.append(bias_i)
141 |
142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
143 |
144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
145 | head_logprob = F.log_softmax(head_logit, dim=1)
146 |
147 | if target is None:
148 | out = hidden.new_empty((head_logit.size(0), self.n_token))
149 | else:
150 | out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
151 |
152 | offset = 0
153 | cutoff_values = [0] + self.cutoffs
154 | for i in range(len(cutoff_values) - 1):
155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
156 |
157 | if target is not None:
158 | mask_i = (target >= l_idx) & (target < r_idx)
159 | indices_i = mask_i.nonzero().squeeze()
160 |
161 | if indices_i.numel() == 0:
162 | continue
163 |
164 | target_i = target.index_select(0, indices_i) - l_idx
165 | head_logprob_i = head_logprob.index_select(0, indices_i)
166 | hidden_i = hidden.index_select(0, indices_i)
167 | else:
168 | hidden_i = hidden
169 |
170 | if i == 0:
171 | if target is not None:
172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
173 | else:
174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
175 | else:
176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
177 |
178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster
181 | if target is not None:
182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \
183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
184 | else:
185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
186 | out[:, l_idx:r_idx] = logprob_i
187 |
188 | if target is not None:
189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
190 | out.index_copy_(0, indices_i, -logprob_i)
191 | else:
192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
193 | offset += logprob_i.size(0)
194 |
195 | return out
196 |
197 |
198 | def log_prob(self, hidden):
199 | r""" Computes log probabilities for all :math:`n\_classes`
200 | From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py
201 | Args:
202 | hidden (Tensor): a minibatch of examples
203 | Returns:
204 | log-probabilities of for each class :math:`c`
205 | in range :math:`0 <= c <= n\_classes`, where :math:`n\_classes` is a
206 | parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
207 | Shape:
208 | - Input: :math:`(N, in\_features)`
209 | - Output: :math:`(N, n\_classes)`
210 | """
211 | if self.n_clusters == 0:
212 | logit = self._compute_logit(hidden, self.out_layers[0].weight,
213 | self.out_layers[0].bias, self.out_projs[0])
214 | return F.log_softmax(logit, dim=-1)
215 | else:
216 | # construct weights and biases
217 | weights, biases = [], []
218 | for i in range(len(self.cutoffs)):
219 | if self.div_val == 1:
220 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
221 | weight_i = self.out_layers[0].weight[l_idx:r_idx]
222 | bias_i = self.out_layers[0].bias[l_idx:r_idx]
223 | else:
224 | weight_i = self.out_layers[i].weight
225 | bias_i = self.out_layers[i].bias
226 |
227 | if i == 0:
228 | weight_i = torch.cat(
229 | [weight_i, self.cluster_weight], dim=0)
230 | bias_i = torch.cat(
231 | [bias_i, self.cluster_bias], dim=0)
232 |
233 | weights.append(weight_i)
234 | biases.append(bias_i)
235 |
236 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
237 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
238 |
239 | out = hidden.new_empty((head_logit.size(0), self.n_token))
240 | head_logprob = F.log_softmax(head_logit, dim=1)
241 |
242 | cutoff_values = [0] + self.cutoffs
243 | for i in range(len(cutoff_values) - 1):
244 | start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]
245 |
246 | if i == 0:
247 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
248 | else:
249 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
250 |
251 | tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
252 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
253 |
254 |                     logprob_i = head_logprob[:, self.cutoffs[0] + i - 1, None] + tail_logprob_i  # cluster log-prob, indexed as in forward()
255 |                     out[:, start_idx:stop_idx] = logprob_i
256 |
257 | return out
258 |
259 |
260 | class LogUniformSampler(object):
261 | def __init__(self, range_max, n_sample):
262 | """
263 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
264 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
265 |
266 | expected count can be approximated by 1 - (1 - p)^n
267 | and we use a numerically stable version -expm1(num_tries * log1p(-p))
268 |
269 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
270 | """
271 | with torch.no_grad():
272 | self.range_max = range_max
273 | log_indices = torch.arange(1., range_max+2., 1.).log_()
274 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
275 | # print('P', self.dist.numpy().tolist()[-30:])
276 |
277 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
278 |
279 | self.n_sample = n_sample
280 |
281 | def sample(self, labels):
282 | """
283 | labels: [b1, b2]
284 | Return
285 | true_log_probs: [b1, b2]
286 | samp_log_probs: [n_sample]
287 | neg_samples: [n_sample]
288 | """
289 |
290 | # neg_samples = torch.empty(0).long()
291 | n_sample = self.n_sample
292 | n_tries = 2 * n_sample
293 |
294 | with torch.no_grad():
295 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
296 | device = labels.device
297 | neg_samples = neg_samples.to(device)
298 | true_log_probs = self.log_q[labels].to(device)
299 | samp_log_probs = self.log_q[neg_samples].to(device)
300 | return true_log_probs, samp_log_probs, neg_samples
301 |
302 | def sample_logits(embedding, bias, labels, inputs, sampler):
303 | """
304 | embedding: an nn.Embedding layer
305 | bias: [n_vocab]
306 | labels: [b1, b2]
307 | inputs: [b1, b2, n_emb]
308 | sampler: you may use a LogUniformSampler
309 | Return
310 | logits: [b1, b2, 1 + n_sample]
311 | """
312 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
313 | n_sample = neg_samples.size(0)
314 | b1, b2 = labels.size(0), labels.size(1)
315 | all_ids = torch.cat([labels.view(-1), neg_samples])
316 | all_w = embedding(all_ids)
317 | true_w = all_w[: -n_sample].view(b1, b2, -1)
318 | sample_w = all_w[- n_sample:].view(n_sample, -1)
319 |
320 | all_b = bias[all_ids]
321 | true_b = all_b[: -n_sample].view(b1, b2)
322 | sample_b = all_b[- n_sample:]
323 |
324 | hit = (labels[:, :, None] == neg_samples).detach()
325 |
326 | true_logits = torch.einsum('ijk,ijk->ij',
327 | [true_w, inputs]) + true_b - true_log_probs
328 | sample_logits = torch.einsum('lk,ijk->ijl',
329 | [sample_w, inputs]) + sample_b - samp_log_probs
330 | sample_logits.masked_fill_(hit, -1e30)
331 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
332 |
333 | return logits
334 |
335 |
336 | # class LogUniformSampler(object):
337 | # def __init__(self, range_max, unique=False):
338 | # """
339 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
340 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
341 | # """
342 | # self.range_max = range_max
343 | # log_indices = torch.arange(1., range_max+2., 1.).log_()
344 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
345 |
346 | # self.unique = unique
347 |
348 | # if self.unique:
349 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
350 |
351 | # def sample(self, n_sample, labels):
352 | # pos_sample, new_labels = labels.unique(return_inverse=True)
353 | # n_pos_sample = pos_sample.size(0)
354 | # n_neg_sample = n_sample - n_pos_sample
355 |
356 | # if self.unique:
357 | # self.exclude_mask.index_fill_(0, pos_sample, 1)
358 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
359 | # self.exclude_mask.index_fill_(0, pos_sample, 0)
360 | # else:
361 | # sample_dist = self.dist
362 |
363 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample)
364 |
365 | # sample = torch.cat([pos_sample, neg_sample])
366 | # sample_prob = self.dist[sample]
367 |
368 | # return new_labels, sample, sample_prob
369 |
370 |
371 | if __name__ == '__main__':
372 | S, B = 3, 4
373 | n_vocab = 10000
374 | n_sample = 5
375 | H = 32
376 |
377 | labels = torch.LongTensor(S, B).random_(0, n_vocab)
378 |
379 | # sampler = LogUniformSampler(n_vocab, unique=False)
380 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
381 |
382 | sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True)
383 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
384 |
385 | # print('true_probs', true_probs.numpy().tolist())
386 | # print('samp_probs', samp_probs.numpy().tolist())
387 | # print('neg_samples', neg_samples.numpy().tolist())
388 |
389 | # print('sum', torch.sum(sampler.dist).item())
390 |
391 | # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
392 |
393 | embedding = nn.Embedding(n_vocab, H)
394 | bias = torch.zeros(n_vocab)
395 | inputs = torch.Tensor(S, B, H).normal_()
396 |
397 |     # sample_logits takes five arguments and returns only the logits tensor;
398 |     # its shape is [S, B, 1 + n_sample] (true logit followed by sampled negatives)
399 |     logits = sample_logits(embedding, bias, labels, inputs, sampler)
400 |     print('logits', logits.detach().numpy().tolist())
401 |     print('logits shape', logits.size())
402 |
403 |
--------------------------------------------------------------------------------
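
A quick numerical check of the sampler above (a standalone sketch, not part of the library): it re-derives the log-uniform distribution P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1) and the expected-count approximation 1 - (1 - p)^n from the docstring, via the numerically stable form -expm1(num_tries * log1p(-p)).

    import torch

    range_max, n_sample = 10000, 5
    n_tries = 2 * n_sample  # LogUniformSampler fixes num_tries at 2 * n_sample
    log_indices = torch.arange(1., range_max + 2.).log_()
    dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
    assert torch.isclose(dist.sum(), torch.tensor(1.))  # telescoping sum -> 1

    # expected number of times each class appears among the negative samples
    expected_count = -torch.expm1(n_tries * torch.log1p(-dist.double()))
    print(expected_count[:3])   # frequent (low-index) classes: drawn often
    print(expected_count[-3:])  # rare (high-index) classes: almost never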
/pytorch_pretrained_bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 |
25 | from .file_utils import cached_path
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
37 | }
38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
39 | 'bert-base-uncased': 512,
40 | 'bert-large-uncased': 512,
41 | 'bert-base-cased': 512,
42 | 'bert-large-cased': 512,
43 | 'bert-base-multilingual-uncased': 512,
44 | 'bert-base-multilingual-cased': 512,
45 | 'bert-base-chinese': 512,
46 | }
47 | VOCAB_NAME = 'vocab.txt'
48 |
49 |
50 | def load_vocab(vocab_file):
51 | """Loads a vocabulary file into a dictionary."""
52 | vocab = collections.OrderedDict()
53 | index = 0
54 | with open(vocab_file, "r", encoding="utf-8") as reader:
55 | while True:
56 | token = reader.readline()
57 | if not token:
58 | break
59 | token = token.strip()
60 | vocab[token] = index
61 | index += 1
62 | return vocab
63 |
64 |
65 | def whitespace_tokenize(text):
66 | """Runs basic whitespace cleaning and splitting on a piece of text."""
67 | text = text.strip()
68 | if not text:
69 | return []
70 | tokens = text.split()
71 | return tokens
72 |
73 |
74 | class BertTokenizer(object):
75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
76 |
77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
79 | """Constructs a BertTokenizer.
80 |
81 | Args:
82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
83 | do_lower_case: Whether to lower case the input
84 | Only has an effect when do_wordpiece_only=False
85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
86 | max_len: An artificial maximum length to truncate tokenized sequences to;
87 | Effective maximum length is always the minimum of this
88 | value (if specified) and the underlying BERT model's
89 | sequence length.
90 | never_split: List of tokens which will never be split during tokenization.
91 | Only has an effect when do_wordpiece_only=False
92 | """
93 | if not os.path.isfile(vocab_file):
94 | raise ValueError(
95 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
97 | self.vocab = load_vocab(vocab_file)
98 | self.ids_to_tokens = collections.OrderedDict(
99 | [(ids, tok) for tok, ids in self.vocab.items()])
100 | self.do_basic_tokenize = do_basic_tokenize
101 | if do_basic_tokenize:
102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
103 | never_split=never_split)
104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
105 | self.max_len = max_len if max_len is not None else int(1e12)
106 |
107 | def tokenize(self, text):
108 | split_tokens = []
109 | if self.do_basic_tokenize:
110 | for token in self.basic_tokenizer.tokenize(text):
111 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
112 | split_tokens.append(sub_token)
113 | else:
114 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
115 | return split_tokens
116 |
117 | def convert_tokens_to_ids(self, tokens):
118 | """Converts a sequence of tokens into ids using the vocab."""
119 | ids = []
120 | for token in tokens:
121 | ids.append(self.vocab[token])
122 | if len(ids) > self.max_len:
123 | logger.warning(
124 | "Token indices sequence length is longer than the specified maximum "
125 |                 "sequence length for this BERT model ({} > {}). Running this"
126 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
127 | )
128 | return ids
129 |
130 | def convert_ids_to_tokens(self, ids):
131 |         """Converts a sequence of ids into wordpiece tokens using the vocab."""
132 | tokens = []
133 | for i in ids:
134 | tokens.append(self.ids_to_tokens[i])
135 | return tokens
136 |
137 | @classmethod
138 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
139 | """
140 | Instantiate a PreTrainedBertModel from a pre-trained model file.
141 | Download and cache the pre-trained model file if needed.
142 | """
143 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
144 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
145 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
146 | logger.warning("The pre-trained model you are loading is a cased model but you have not set "
147 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
148 | "you may want to check this behavior.")
149 | kwargs['do_lower_case'] = False
150 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
151 | logger.warning("The pre-trained model you are loading is an uncased model but you have set "
152 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
153 | "but you may want to check this behavior.")
154 | kwargs['do_lower_case'] = True
155 | else:
156 | vocab_file = pretrained_model_name_or_path
157 | if os.path.isdir(vocab_file):
158 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
159 | # redirect to the cache, if necessary
160 | try:
161 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
162 | except EnvironmentError:
163 | logger.error(
164 | "Model name '{}' was not found in model name list ({}). "
165 | "We assumed '{}' was a path or url but couldn't find any file "
166 | "associated to this path or url.".format(
167 | pretrained_model_name_or_path,
168 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
169 | vocab_file))
170 | return None
171 | if resolved_vocab_file == vocab_file:
172 | logger.info("loading vocabulary file {}".format(vocab_file))
173 | else:
174 | logger.info("loading vocabulary file {} from cache at {}".format(
175 | vocab_file, resolved_vocab_file))
176 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
177 |             # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
178 | # than the number of positional embeddings
179 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
180 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
181 | # Instantiate tokenizer.
182 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
183 | return tokenizer
184 |
185 |
186 | class BasicTokenizer(object):
187 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
188 |
189 | def __init__(self,
190 | do_lower_case=True,
191 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
192 | """Constructs a BasicTokenizer.
193 |
194 | Args:
195 | do_lower_case: Whether to lower case the input.
196 | """
197 | self.do_lower_case = do_lower_case
198 | self.never_split = never_split
199 |
200 | def tokenize(self, text):
201 | """Tokenizes a piece of text."""
202 | text = self._clean_text(text)
203 | # This was added on November 1st, 2018 for the multilingual and Chinese
204 | # models. This is also applied to the English models now, but it doesn't
205 | # matter since the English models were not trained on any Chinese data
206 | # and generally don't have any Chinese data in them (there are Chinese
207 | # characters in the vocabulary because Wikipedia does have some Chinese
208 | # words in the English Wikipedia.).
209 | text = self._tokenize_chinese_chars(text)
210 | orig_tokens = whitespace_tokenize(text)
211 | split_tokens = []
212 | for token in orig_tokens:
213 | if self.do_lower_case and token not in self.never_split:
214 | token = token.lower()
215 | token = self._run_strip_accents(token)
216 | split_tokens.extend(self._run_split_on_punc(token))
217 |
218 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
219 | return output_tokens
220 |
221 | def _run_strip_accents(self, text):
222 | """Strips accents from a piece of text."""
223 | text = unicodedata.normalize("NFD", text)
224 | output = []
225 | for char in text:
226 | cat = unicodedata.category(char)
227 | if cat == "Mn":
228 | continue
229 | output.append(char)
230 | return "".join(output)
231 |
232 | def _run_split_on_punc(self, text):
233 | """Splits punctuation on a piece of text."""
234 | if text in self.never_split:
235 | return [text]
236 | chars = list(text)
237 | i = 0
238 | start_new_word = True
239 | output = []
240 | while i < len(chars):
241 | char = chars[i]
242 | if _is_punctuation(char):
243 | output.append([char])
244 | start_new_word = True
245 | else:
246 | if start_new_word:
247 | output.append([])
248 | start_new_word = False
249 | output[-1].append(char)
250 | i += 1
251 |
252 | return ["".join(x) for x in output]
253 |
254 | def _tokenize_chinese_chars(self, text):
255 | """Adds whitespace around any CJK character."""
256 | output = []
257 | for char in text:
258 | cp = ord(char)
259 | if self._is_chinese_char(cp):
260 | output.append(" ")
261 | output.append(char)
262 | output.append(" ")
263 | else:
264 | output.append(char)
265 | return "".join(output)
266 |
267 | def _is_chinese_char(self, cp):
268 | """Checks whether CP is the codepoint of a CJK character."""
269 | # This defines a "chinese character" as anything in the CJK Unicode block:
270 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
271 | #
272 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
273 | # despite its name. The modern Korean Hangul alphabet is a different block,
274 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
275 | # space-separated words, so they are not treated specially and handled
276 |         # like all of the other languages.
277 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
278 | (cp >= 0x3400 and cp <= 0x4DBF) or #
279 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
280 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
281 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
282 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
283 | (cp >= 0xF900 and cp <= 0xFAFF) or #
284 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
285 | return True
286 |
287 | return False
288 |
289 | def _clean_text(self, text):
290 | """Performs invalid character removal and whitespace cleanup on text."""
291 | output = []
292 | for char in text:
293 | cp = ord(char)
294 | if cp == 0 or cp == 0xfffd or _is_control(char):
295 | continue
296 | if _is_whitespace(char):
297 | output.append(" ")
298 | else:
299 | output.append(char)
300 | return "".join(output)
301 |
302 |
303 | class WordpieceTokenizer(object):
304 | """Runs WordPiece tokenization."""
305 |
306 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
307 | self.vocab = vocab
308 | self.unk_token = unk_token
309 | self.max_input_chars_per_word = max_input_chars_per_word
310 |
311 | def tokenize(self, text):
312 | """Tokenizes a piece of text into its word pieces.
313 |
314 | This uses a greedy longest-match-first algorithm to perform tokenization
315 | using the given vocabulary.
316 |
317 | For example:
318 | input = "unaffable"
319 | output = ["un", "##aff", "##able"]
320 |
321 | Args:
322 | text: A single token or whitespace separated tokens. This should have
323 | already been passed through `BasicTokenizer`.
324 |
325 | Returns:
326 | A list of wordpiece tokens.
327 | """
328 |
329 | output_tokens = []
330 | for token in whitespace_tokenize(text):
331 | chars = list(token)
332 | if len(chars) > self.max_input_chars_per_word:
333 | output_tokens.append(self.unk_token)
334 | continue
335 |
336 | is_bad = False
337 | start = 0
338 | sub_tokens = []
339 | while start < len(chars):
340 | end = len(chars)
341 | cur_substr = None
342 | while start < end:
343 | substr = "".join(chars[start:end])
344 | if start > 0:
345 | substr = "##" + substr
346 | if substr in self.vocab:
347 | cur_substr = substr
348 | break
349 | end -= 1
350 | if cur_substr is None:
351 | is_bad = True
352 | break
353 | sub_tokens.append(cur_substr)
354 | start = end
355 |
356 | if is_bad:
357 | output_tokens.append(self.unk_token)
358 | else:
359 | output_tokens.extend(sub_tokens)
360 | return output_tokens
361 |
362 |
363 | def _is_whitespace(char):
364 | """Checks whether `chars` is a whitespace character."""
365 |     # \t, \n, and \r are technically control characters but we treat them
366 | # as whitespace since they are generally considered as such.
367 | if char == " " or char == "\t" or char == "\n" or char == "\r":
368 | return True
369 | cat = unicodedata.category(char)
370 | if cat == "Zs":
371 | return True
372 | return False
373 |
374 |
375 | def _is_control(char):
376 | """Checks whether `chars` is a control character."""
377 | # These are technically control characters but we count them as whitespace
378 | # characters.
379 | if char == "\t" or char == "\n" or char == "\r":
380 | return False
381 | cat = unicodedata.category(char)
382 | if cat.startswith("C"):
383 | return True
384 | return False
385 |
386 |
387 | def _is_punctuation(char):
388 | """Checks whether `chars` is a punctuation character."""
389 | cp = ord(char)
390 | # We treat all non-letter/number ASCII as punctuation.
391 | # Characters such as "^", "$", and "`" are not in the Unicode
392 | # Punctuation class but we treat them as punctuation anyways, for
393 | # consistency.
394 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
395 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
396 | return True
397 | cat = unicodedata.category(char)
398 | if cat.startswith("P"):
399 | return True
400 | return False
401 |
--------------------------------------------------------------------------------
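
To make the greedy longest-match-first algorithm above concrete, here is a minimal sketch. The vocabulary is a toy one (hypothetical tokens, not a real BERT vocab); WordpieceTokenizer only needs membership tests, so a plain dict suffices.

    toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    wp = WordpieceTokenizer(vocab=toy_vocab)

    print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able']
    print(wp.tokenize("xyz"))        # ['[UNK]'] -- no vocab entry matches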
/pytorch_pretrained_bert/tokenization_transfo_xl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Tokenization classes for Transformer XL model.
17 | Adapted from https://github.com/kimiyoung/transformer-xl.
18 | """
19 | from __future__ import (absolute_import, division, print_function,
20 | unicode_literals)
21 |
22 | import glob
23 | import logging
24 | import os
25 | import sys
26 | from collections import Counter, OrderedDict
27 | from io import open
28 | import unicodedata
29 |
30 | import torch
31 | import numpy as np
32 |
33 | from .file_utils import cached_path
34 |
35 | if sys.version_info[0] == 2:
36 | import cPickle as pickle
37 | else:
38 | import pickle
39 |
40 |
41 | logger = logging.getLogger(__name__)
42 |
43 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
44 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
45 | }
46 | VOCAB_NAME = 'vocab.bin'
47 |
48 | PRETRAINED_CORPUS_ARCHIVE_MAP = {
49 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
50 | }
51 | CORPUS_NAME = 'corpus.bin'
52 |
53 | class TransfoXLTokenizer(object):
54 | """
55 | Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
56 | """
57 | @classmethod
58 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
59 | """
60 |         Instantiate a TransfoXLTokenizer from a pre-trained model vocabulary.
61 |         Download and cache the vocabulary file if needed.
62 | """
63 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
64 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
65 | else:
66 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
67 | # redirect to the cache, if necessary
68 | try:
69 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
70 | except EnvironmentError:
71 | logger.error(
72 | "Model name '{}' was not found in model name list ({}). "
73 | "We assumed '{}' was a path or url but couldn't find files {} "
74 | "at this path or url.".format(
75 | pretrained_model_name_or_path,
76 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
77 | pretrained_model_name_or_path,
78 | vocab_file))
79 | return None
80 | if resolved_vocab_file == vocab_file:
81 | logger.info("loading vocabulary file {}".format(vocab_file))
82 | else:
83 | logger.info("loading vocabulary file {} from cache at {}".format(
84 | vocab_file, resolved_vocab_file))
85 |
86 | # Instantiate tokenizer.
87 | tokenizer = cls(*inputs, **kwargs)
88 | vocab_dict = torch.load(resolved_vocab_file)
89 | for key, value in vocab_dict.items():
90 | tokenizer.__dict__[key] = value
91 | return tokenizer
92 |
93 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
94 |                  delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
95 | self.counter = Counter()
96 | self.special = special
97 | self.min_freq = min_freq
98 | self.max_size = max_size
99 | self.lower_case = lower_case
100 | self.delimiter = delimiter
101 | self.vocab_file = vocab_file
102 | self.never_split = never_split
103 |
104 | def count_file(self, path, verbose=False, add_eos=False):
105 | if verbose: print('counting file {} ...'.format(path))
106 | assert os.path.exists(path)
107 |
108 | sents = []
109 | with open(path, 'r', encoding='utf-8') as f:
110 | for idx, line in enumerate(f):
111 | if verbose and idx > 0 and idx % 500000 == 0:
112 | print(' line {}'.format(idx))
113 | symbols = self.tokenize(line, add_eos=add_eos)
114 | self.counter.update(symbols)
115 | sents.append(symbols)
116 |
117 | return sents
118 |
119 | def count_sents(self, sents, verbose=False):
120 | """
121 | sents : a list of sentences, each a list of tokenized symbols
122 | """
123 | if verbose: print('counting {} sents ...'.format(len(sents)))
124 | for idx, symbols in enumerate(sents):
125 | if verbose and idx > 0 and idx % 500000 == 0:
126 | print(' line {}'.format(idx))
127 | self.counter.update(symbols)
128 |
129 | def _build_from_file(self, vocab_file):
130 | self.idx2sym = []
131 | self.sym2idx = OrderedDict()
132 |
133 | with open(vocab_file, 'r', encoding='utf-8') as f:
134 | for line in f:
135 | symb = line.strip().split()[0]
136 | self.add_symbol(symb)
137 |         if '<UNK>' in self.sym2idx:
138 |             self.unk_idx = self.sym2idx['<UNK>']
139 |         elif '<unk>' in self.sym2idx:
140 |             self.unk_idx = self.sym2idx['<unk>']
141 |         else:
142 |             raise ValueError('No <unk> token in vocabulary')
143 |
144 | def build_vocab(self):
145 | if self.vocab_file:
146 | print('building vocab from {}'.format(self.vocab_file))
147 | self._build_from_file(self.vocab_file)
148 | print('final vocab size {}'.format(len(self)))
149 | else:
150 | print('building vocab with min_freq={}, max_size={}'.format(
151 | self.min_freq, self.max_size))
152 | self.idx2sym = []
153 | self.sym2idx = OrderedDict()
154 |
155 | for sym in self.special:
156 | self.add_special(sym)
157 |
158 | for sym, cnt in self.counter.most_common(self.max_size):
159 | if cnt < self.min_freq: break
160 | self.add_symbol(sym)
161 |
162 | print('final vocab size {} from {} unique tokens'.format(
163 | len(self), len(self.counter)))
164 |
165 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
166 | add_double_eos=False):
167 | if verbose: print('encoding file {} ...'.format(path))
168 | assert os.path.exists(path)
169 | encoded = []
170 | with open(path, 'r', encoding='utf-8') as f:
171 | for idx, line in enumerate(f):
172 | if verbose and idx > 0 and idx % 500000 == 0:
173 | print(' line {}'.format(idx))
174 | symbols = self.tokenize(line, add_eos=add_eos,
175 | add_double_eos=add_double_eos)
176 | encoded.append(self.convert_to_tensor(symbols))
177 |
178 | if ordered:
179 | encoded = torch.cat(encoded)
180 |
181 | return encoded
182 |
183 | def encode_sents(self, sents, ordered=False, verbose=False):
184 | if verbose: print('encoding {} sents ...'.format(len(sents)))
185 | encoded = []
186 | for idx, symbols in enumerate(sents):
187 | if verbose and idx > 0 and idx % 500000 == 0:
188 | print(' line {}'.format(idx))
189 | encoded.append(self.convert_to_tensor(symbols))
190 |
191 | if ordered:
192 | encoded = torch.cat(encoded)
193 |
194 | return encoded
195 |
196 | def add_special(self, sym):
197 | if sym not in self.sym2idx:
198 | self.idx2sym.append(sym)
199 | self.sym2idx[sym] = len(self.idx2sym) - 1
200 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
201 |
202 | def add_symbol(self, sym):
203 | if sym not in self.sym2idx:
204 | self.idx2sym.append(sym)
205 | self.sym2idx[sym] = len(self.idx2sym) - 1
206 |
207 | def get_sym(self, idx):
208 | assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
209 | return self.idx2sym[idx]
210 |
211 | def get_idx(self, sym):
212 | if sym in self.sym2idx:
213 | return self.sym2idx[sym]
214 | else:
215 | # print('encounter unk {}'.format(sym))
216 |             # assert '<eos>' not in sym
217 |             if hasattr(self, 'unk_idx'):
218 |                 return self.sym2idx.get(sym, self.unk_idx)
219 |             # Backward compatibility with pre-trained models
220 |             elif '<unk>' in self.sym2idx:
221 |                 return self.sym2idx['<unk>']
222 |             elif '<UNK>' in self.sym2idx:
223 |                 return self.sym2idx['<UNK>']
224 |             else:
225 |                 raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
226 |
227 | def convert_ids_to_tokens(self, indices):
228 |         """Converts a sequence of indices into symbols using the vocab."""
229 | return [self.get_sym(idx) for idx in indices]
230 |
231 | def convert_tokens_to_ids(self, symbols):
232 | """Converts a sequence of symbols into ids using the vocab."""
233 | return [self.get_idx(sym) for sym in symbols]
234 |
235 | def convert_to_tensor(self, symbols):
236 | return torch.LongTensor(self.convert_tokens_to_ids(symbols))
237 |
238 | def decode(self, indices, exclude=None):
239 |         """Converts a sequence of indices into a string."""
240 | if exclude is None:
241 | return ' '.join([self.get_sym(idx) for idx in indices])
242 | else:
243 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
244 |
245 | def __len__(self):
246 | return len(self.idx2sym)
247 |
248 | def _run_split_on_punc(self, text):
249 | """Splits punctuation on a piece of text."""
250 | if text in self.never_split:
251 | return [text]
252 | chars = list(text)
253 | i = 0
254 | start_new_word = True
255 | output = []
256 | while i < len(chars):
257 | char = chars[i]
258 | if _is_punctuation(char):
259 | output.append([char])
260 | start_new_word = True
261 | else:
262 | if start_new_word:
263 | output.append([])
264 | start_new_word = False
265 | output[-1].append(char)
266 | i += 1
267 |
268 | return ["".join(x) for x in output]
269 |
270 | def _run_strip_accents(self, text):
271 | """Strips accents from a piece of text."""
272 | text = unicodedata.normalize("NFD", text)
273 | output = []
274 | for char in text:
275 | cat = unicodedata.category(char)
276 | if cat == "Mn":
277 | continue
278 | output.append(char)
279 | return "".join(output)
280 |
281 | def _clean_text(self, text):
282 | """Performs invalid character removal and whitespace cleanup on text."""
283 | output = []
284 | for char in text:
285 | cp = ord(char)
286 | if cp == 0 or cp == 0xfffd or _is_control(char):
287 | continue
288 | if _is_whitespace(char):
289 | output.append(" ")
290 | else:
291 | output.append(char)
292 | return "".join(output)
293 |
294 | def whitespace_tokenize(self, text):
295 | """Runs basic whitespace cleaning and splitting on a piece of text."""
296 | text = text.strip()
297 | if not text:
298 | return []
299 | if self.delimiter == '':
300 | tokens = text
301 | else:
302 | tokens = text.split(self.delimiter)
303 | return tokens
304 |
305 | def tokenize(self, line, add_eos=False, add_double_eos=False):
306 | line = self._clean_text(line)
307 | line = line.strip()
308 |
309 | symbols = self.whitespace_tokenize(line)
310 |
311 | split_symbols = []
312 | for symbol in symbols:
313 | if self.lower_case and symbol not in self.never_split:
314 | symbol = symbol.lower()
315 | symbol = self._run_strip_accents(symbol)
316 | split_symbols.extend(self._run_split_on_punc(symbol))
317 |
318 |         if add_double_eos: # lm1b
319 |             return ['<S>'] + split_symbols + ['<S>']
320 |         elif add_eos:
321 |             return split_symbols + ['<eos>']
322 | else:
323 | return split_symbols
324 |
325 |
326 | class LMOrderedIterator(object):
327 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
328 | """
329 | data -- LongTensor -- the LongTensor is strictly ordered
330 | """
331 | self.bsz = bsz
332 | self.bptt = bptt
333 | self.ext_len = ext_len if ext_len is not None else 0
334 |
335 | self.device = device
336 |
337 | # Work out how cleanly we can divide the dataset into bsz parts.
338 | self.n_step = data.size(0) // bsz
339 |
340 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
341 | data = data.narrow(0, 0, self.n_step * bsz)
342 |
343 | # Evenly divide the data across the bsz batches.
344 | self.data = data.view(bsz, -1).t().contiguous().to(device)
345 |
346 | # Number of mini-batches
347 | self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
348 |
349 | def get_batch(self, i, bptt=None):
350 | if bptt is None: bptt = self.bptt
351 | seq_len = min(bptt, self.data.size(0) - 1 - i)
352 |
353 | end_idx = i + seq_len
354 | beg_idx = max(0, i - self.ext_len)
355 |
356 | data = self.data[beg_idx:end_idx]
357 | target = self.data[i+1:i+1+seq_len]
358 |
359 | data_out = data.transpose(0, 1).contiguous().to(self.device)
360 | target_out = target.transpose(0, 1).contiguous().to(self.device)
361 |
362 | return data_out, target_out, seq_len
363 |
364 | def get_fixlen_iter(self, start=0):
365 | for i in range(start, self.data.size(0) - 1, self.bptt):
366 | yield self.get_batch(i)
367 |
368 | def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
369 | max_len = self.bptt + max_deviation * std
370 | i = start
371 | while True:
372 | bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
373 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
374 | data, target, seq_len = self.get_batch(i, bptt)
375 | i += seq_len
376 | yield data, target, seq_len
377 | if i >= self.data.size(0) - 2:
378 | break
379 |
380 | def __iter__(self):
381 | return self.get_fixlen_iter()
382 |
383 |
384 | class LMShuffledIterator(object):
385 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
386 | """
387 | data -- list[LongTensor] -- there is no order among the LongTensors
388 | """
389 | self.data = data
390 |
391 | self.bsz = bsz
392 | self.bptt = bptt
393 | self.ext_len = ext_len if ext_len is not None else 0
394 |
395 | self.device = device
396 | self.shuffle = shuffle
397 |
398 | def get_sent_stream(self):
399 | # index iterator
400 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
401 | else np.array(range(len(self.data)))
402 |
403 | # sentence iterator
404 | for idx in epoch_indices:
405 | yield self.data[idx]
406 |
407 | def stream_iterator(self, sent_stream):
408 | # streams for each data in the batch
409 | streams = [None] * self.bsz
410 |
411 | data = torch.LongTensor(self.bptt, self.bsz)
412 | target = torch.LongTensor(self.bptt, self.bsz)
413 |
414 | n_retain = 0
415 |
416 | while True:
417 | # data : [n_retain+bptt x bsz]
418 | # target : [bptt x bsz]
419 | data[n_retain:].fill_(-1)
420 | target.fill_(-1)
421 |
422 | valid_batch = True
423 |
424 | for i in range(self.bsz):
425 | n_filled = 0
426 | try:
427 | while n_filled < self.bptt:
428 | if streams[i] is None or len(streams[i]) <= 1:
429 | streams[i] = next(sent_stream)
430 | # number of new tokens to fill in
431 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
432 | # first n_retain tokens are retained from last batch
433 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
434 | streams[i][:n_new]
435 | target[n_filled:n_filled+n_new, i] = \
436 | streams[i][1:n_new+1]
437 | streams[i] = streams[i][n_new:]
438 | n_filled += n_new
439 | except StopIteration:
440 | valid_batch = False
441 | break
442 |
443 | if not valid_batch:
444 | return
445 |
446 | data_out = data.transpose(0, 1).contiguous().to(self.device)
447 | target_out = target.transpose(0, 1).contiguous().to(self.device)
448 |
449 | yield data_out, target_out, self.bptt
450 |
451 | n_retain = min(data.size(0), self.ext_len)
452 | if n_retain > 0:
453 | data[:n_retain] = data[-n_retain:]
454 | data.resize_(n_retain + self.bptt, data.size(1))
455 |
456 | def __iter__(self):
457 | # sent_stream is an iterator
458 | sent_stream = self.get_sent_stream()
459 |
460 | for batch in self.stream_iterator(sent_stream):
461 | yield batch
462 |
463 |
464 | class LMMultiFileIterator(LMShuffledIterator):
465 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
466 | shuffle=False):
467 |
468 | self.paths = paths
469 | self.vocab = vocab
470 |
471 | self.bsz = bsz
472 | self.bptt = bptt
473 | self.ext_len = ext_len if ext_len is not None else 0
474 |
475 | self.device = device
476 | self.shuffle = shuffle
477 |
478 | def get_sent_stream(self, path):
479 | sents = self.vocab.encode_file(path, add_double_eos=True)
480 | if self.shuffle:
481 | np.random.shuffle(sents)
482 | sent_stream = iter(sents)
483 |
484 | return sent_stream
485 |
486 | def __iter__(self):
487 | if self.shuffle:
488 | np.random.shuffle(self.paths)
489 |
490 | for path in self.paths:
491 | # sent_stream is an iterator
492 | sent_stream = self.get_sent_stream(path)
493 | for batch in self.stream_iterator(sent_stream):
494 | yield batch
495 |
496 |
497 | class TransfoXLCorpus(object):
498 | @classmethod
499 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
500 | """
501 | Instantiate a pre-processed corpus.
502 | """
503 | vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
504 | if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
505 | corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
506 | else:
507 | corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
508 | # redirect to the cache, if necessary
509 | try:
510 | resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
511 | except EnvironmentError:
512 | logger.error(
513 | "Corpus '{}' was not found in corpus list ({}). "
514 | "We assumed '{}' was a path or url but couldn't find files {} "
515 | "at this path or url.".format(
516 | pretrained_model_name_or_path,
517 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
518 | pretrained_model_name_or_path,
519 | corpus_file))
520 | return None
521 | if resolved_corpus_file == corpus_file:
522 | logger.info("loading corpus file {}".format(corpus_file))
523 | else:
524 | logger.info("loading corpus file {} from cache at {}".format(
525 | corpus_file, resolved_corpus_file))
526 |
527 | # Instantiate tokenizer.
528 | corpus = cls(*inputs, **kwargs)
529 | corpus_dict = torch.load(resolved_corpus_file)
530 | for key, value in corpus_dict.items():
531 | corpus.__dict__[key] = value
532 | corpus.vocab = vocab
533 | if corpus.train is not None:
534 | corpus.train = torch.tensor(corpus.train, dtype=torch.long)
535 | if corpus.valid is not None:
536 | corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
537 | if corpus.test is not None:
538 | corpus.test = torch.tensor(corpus.test, dtype=torch.long)
539 | return corpus
540 |
541 | def __init__(self, *args, **kwargs):
542 | self.vocab = TransfoXLTokenizer(*args, **kwargs)
543 | self.dataset = None
544 | self.train = None
545 | self.valid = None
546 | self.test = None
547 |
548 | def build_corpus(self, path, dataset):
549 | self.dataset = dataset
550 |
551 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
552 | self.vocab.count_file(os.path.join(path, 'train.txt'))
553 | self.vocab.count_file(os.path.join(path, 'valid.txt'))
554 | self.vocab.count_file(os.path.join(path, 'test.txt'))
555 | elif self.dataset == 'wt103':
556 | self.vocab.count_file(os.path.join(path, 'train.txt'))
557 | elif self.dataset == 'lm1b':
558 | train_path_pattern = os.path.join(
559 | path, '1-billion-word-language-modeling-benchmark-r13output',
560 | 'training-monolingual.tokenized.shuffled', 'news.en-*')
561 | train_paths = glob.glob(train_path_pattern)
562 | # the vocab will load from file when build_vocab() is called
563 |
564 | self.vocab.build_vocab()
565 |
566 | if self.dataset in ['ptb', 'wt2', 'wt103']:
567 | self.train = self.vocab.encode_file(
568 | os.path.join(path, 'train.txt'), ordered=True)
569 | self.valid = self.vocab.encode_file(
570 | os.path.join(path, 'valid.txt'), ordered=True)
571 | self.test = self.vocab.encode_file(
572 | os.path.join(path, 'test.txt'), ordered=True)
573 | elif self.dataset in ['enwik8', 'text8']:
574 | self.train = self.vocab.encode_file(
575 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
576 | self.valid = self.vocab.encode_file(
577 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
578 | self.test = self.vocab.encode_file(
579 | os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
580 | elif self.dataset == 'lm1b':
581 | self.train = train_paths
582 | self.valid = self.vocab.encode_file(
583 | os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
584 | self.test = self.vocab.encode_file(
585 | os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
586 |
587 | def get_iterator(self, split, *args, **kwargs):
588 | if split == 'train':
589 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
590 | data_iter = LMOrderedIterator(self.train, *args, **kwargs)
591 | elif self.dataset == 'lm1b':
592 | kwargs['shuffle'] = True
593 | data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
594 | elif split in ['valid', 'test']:
595 | data = self.valid if split == 'valid' else self.test
596 | if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
597 | data_iter = LMOrderedIterator(data, *args, **kwargs)
598 | elif self.dataset == 'lm1b':
599 | data_iter = LMShuffledIterator(data, *args, **kwargs)
600 |
601 | return data_iter
602 |
603 |
604 | def get_lm_corpus(datadir, dataset):
605 | fn = os.path.join(datadir, 'cache.pt')
606 | fn_pickle = os.path.join(datadir, 'cache.pkl')
607 |     if os.path.exists(fn):
608 |         print('Loading cached dataset...')
609 |         corpus = torch.load(fn)
610 |     elif os.path.exists(fn_pickle):
611 |         print('Loading cached dataset from pickle...')
612 |         with open(fn_pickle, "rb") as fp:
613 |             corpus = pickle.load(fp)
614 | else:
615 | print('Producing dataset {}...'.format(dataset))
616 | kwargs = {}
617 | if dataset in ['wt103', 'wt2']:
618 |         kwargs['special'] = ['<eos>']
619 | kwargs['lower_case'] = False
620 | elif dataset == 'ptb':
621 |         kwargs['special'] = ['<eos>']
622 | kwargs['lower_case'] = True
623 | elif dataset == 'lm1b':
624 | kwargs['special'] = []
625 | kwargs['lower_case'] = False
626 | kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
627 | elif dataset in ['enwik8', 'text8']:
628 | pass
629 |
630 |         corpus = TransfoXLCorpus(**kwargs)
631 |         corpus.build_corpus(datadir, dataset)
632 |         torch.save(corpus, fn)
633 | return corpus
634 |
635 | def _is_whitespace(char):
636 | """Checks whether `chars` is a whitespace character."""
637 |     # \t, \n, and \r are technically control characters but we treat them
638 | # as whitespace since they are generally considered as such.
639 | if char == " " or char == "\t" or char == "\n" or char == "\r":
640 | return True
641 | cat = unicodedata.category(char)
642 | if cat == "Zs":
643 | return True
644 | return False
645 |
646 |
647 | def _is_control(char):
648 | """Checks whether `chars` is a control character."""
649 | # These are technically control characters but we count them as whitespace
650 | # characters.
651 | if char == "\t" or char == "\n" or char == "\r":
652 | return False
653 | cat = unicodedata.category(char)
654 | if cat.startswith("C"):
655 | return True
656 | return False
657 |
658 |
659 | def _is_punctuation(char):
660 | """Checks whether `chars` is a punctuation character."""
661 | cp = ord(char)
662 | # We treat all non-letter/number ASCII as punctuation.
663 | # Characters such as "^", "$", and "`" are not in the Unicode
664 | # Punctuation class but we treat them as punctuation anyways, for
665 | # consistency.
666 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
667 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
668 | return True
669 | cat = unicodedata.category(char)
670 | if cat.startswith("P"):
671 | return True
672 | return False
673 |
--------------------------------------------------------------------------------
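
A toy walk-through of LMOrderedIterator's batch layout (standalone sketch, assuming the classes above are importable): a strictly ordered stream of 20 token ids with bsz=2 and bptt=4 is split into two contiguous chunks, [0..9] and [10..19], then read bptt tokens at a time with targets shifted one step ahead.

    import torch

    data = torch.arange(20)  # a strictly ordered LongTensor stream
    it = LMOrderedIterator(data, bsz=2, bptt=4)
    for inp, tgt, seq_len in it:
        print(inp.tolist(), tgt.tolist(), seq_len)
    # first batch: inputs  [[0, 1, 2, 3], [10, 11, 12, 13]]
    #              targets [[1, 2, 3, 4], [11, 12, 13, 14]]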
/src/model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pytorch_pretrained_bert.module import BertModel
3 | import torch
4 | import torch.nn as nn
5 | import pytorch_pretrained_bert as Bert
6 | import numpy as np
7 | import copy
8 | import torch.nn.functional as F
9 | import math
10 | from src.vae import *
11 |
12 |
13 |
14 | def gelu(x):
15 |     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))  # tanh GELU approximation; the constant is sqrt(2/pi), not sqrt(pi/2)
16 | class BertConfig(Bert.modeling.BertConfig):
17 | def __init__(self, config):
18 | super(BertConfig, self).__init__(
19 | vocab_size_or_config_json_file=config.get('vocab_size'),
20 | hidden_size=config['hidden_size'],
21 | num_hidden_layers=config.get('num_hidden_layers'),
22 | num_attention_heads=config.get('num_attention_heads'),
23 | intermediate_size=config.get('intermediate_size'),
24 | hidden_act=config.get('hidden_act'),
25 | hidden_dropout_prob=config.get('hidden_dropout_prob'),
26 | attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
27 |             max_position_embeddings=config.get('max_position_embedding'),
28 | initializer_range=config.get('initializer_range'),
29 | )
30 | self.seg_vocab_size = config.get('seg_vocab_size')
31 | self.age_vocab_size = config.get('age_vocab_size')
32 | self.num_treatment = config.get('num_treatment')
33 | self.device = config.get('device')
34 | self.year_vocab_size = config.get('year_vocab_size')
35 |
36 | if config.get('poolingSize') is not None:
37 | self.poolingSize = config.get('poolingSize')
38 | if config.get('MEM') is not None:
39 | self.MEM = config.get('MEM')
40 | else:
41 | self.MEM=False
42 | if config.get('unsupSize') is not None:
43 | self.unsupSize = config.get('unsupSize')
44 | if config.get('unsupVAE') is not None:
45 | self.unsupVAE = True
46 | else:
47 | self.unsupVAE = False
48 | if config.get("vaeinchannels") is not None:
49 | self.vaeinchannels = config.get('vaeinchannels')
50 | if config.get("vaelatentdim") is not None:
51 | self.vaelatentdim = config.get('vaelatentdim')
52 | if config.get("vaehidden") is not None:
53 | self.vaehidden = config.get('vaehidden')
54 | if config.get("klpar") is not None:
55 | self.klpar = config.get('klpar')
56 | else:
57 | self.klpar = 1
58 | if config.get('BetaD') is not None:
59 | self.BetaD = config.get('BetaD')
60 | else:
61 | self.BetaD = False
62 |
63 | class BertEmbeddingsUnsup(nn.Module):
64 | """Construct the embeddings from word, segment, age
65 | """
66 |
67 | def __init__(self, config):
68 | super(BertEmbeddingsUnsup, self).__init__()
69 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
70 | # self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)
71 | self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)
72 | self.posi_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size).\
73 | from_pretrained(embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))
74 | self.year_embeddings = nn.Embedding(config.year_vocab_size, config.hidden_size)
75 | self.unsuplist = config.unsupSize
76 | sumInputTabular = sum([el[1] for el in self.unsuplist])
77 | self.unsupEmbeddings = nn.ModuleList([nn.Embedding(el[0], el[1]) for el in self.unsuplist])
78 | self.unsupLinear = nn.Linear(sumInputTabular, config.hidden_size)
79 |
80 | self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
81 |
82 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
83 |
84 | def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None):
85 | if seg_ids is None:
86 | seg_ids = torch.zeros_like(word_ids)
87 | if age_ids is None:
88 | age_ids = torch.zeros_like(word_ids)
89 | if posi_ids is None:
90 | posi_ids = torch.zeros_like(word_ids)
91 |
92 | tabularVar = word_ids[:,:len(self.unsuplist)]
93 | word_ids = word_ids[:,len(self.unsuplist):]
94 |
95 |
96 | word_embed = self.word_embeddings(word_ids)
97 | age_embed = self.age_embeddings(age_ids)
98 | posi_embeddings = self.posi_embeddings(posi_ids)
99 | year_embed = self.year_embeddings(year_ids)
100 | tabularVar = tabularVar.transpose(1,0)
101 | tabularVarembed = torch.cat([self.unsupEmbeddings[eliter](el) for eliter, el in enumerate(tabularVar)], dim=1)
102 | tabularVarembed = self.unsupLinear(tabularVarembed).unsqueeze(1)
103 | embeddings = word_embed + age_embed + year_embed + posi_embeddings
104 | embeddings = torch.cat((tabularVarembed, embeddings), dim=1)
105 |
106 | embeddings = self.LayerNorm(embeddings)
107 |
108 | embeddings = self.dropout(embeddings)
109 | return embeddings
110 |
111 | def _init_posi_embedding(self, max_position_embedding, hidden_size):
112 | def even_code(pos, idx):
113 | return np.sin(pos/(10000**(2*idx/hidden_size)))
114 |
115 | def odd_code(pos, idx):
116 | return np.cos(pos/(10000**(2*idx/hidden_size)))
117 |
118 | # initialize position embedding table
119 | lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)
120 |
121 | # reset table parameters with hard encoding
122 | # set even dimension
123 | for pos in range(max_position_embedding):
124 | for idx in np.arange(0, hidden_size, step=2):
125 | lookup_table[pos, idx] = even_code(pos, idx)
126 | # set odd dimension
127 | for pos in range(max_position_embedding):
128 | for idx in np.arange(1, hidden_size, step=2):
129 | lookup_table[pos, idx] = odd_code(pos, idx)
130 |
131 | return torch.tensor(lookup_table)
132 |
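# (Illustration only, not part of the model.) The table built by
# _init_posi_embedding above is a fixed sinusoidal encoding: for even idx,
# lookup_table[pos, idx] = sin(pos / 10000 ** (2 * idx / hidden_size)); for
# odd idx it is the cosine of the same argument. BertEmbeddingsUnsup loads it
# as frozen weights through nn.Embedding.from_pretrained. Quick sanity check
# (self is unused by the helper, so None can stand in for it): row 0 is
# sin(0) = 0 in the even slots and cos(0) = 1 in the odd slots.
_tbl = BertEmbeddingsUnsup._init_posi_embedding(None, 4, 6)
assert (_tbl[0, 0::2] == 0).all() and (_tbl[0, 1::2] == 1).all()
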
133 | class BertLastPooler(nn.Module):
134 | def __init__(self, config):
135 | super(BertLastPooler, self).__init__()
136 | self.dense = nn.Linear(config.hidden_size, config.poolingSize)
137 | self.activation = nn.Tanh()
138 |
139 | def forward(self, hidden_states):
140 | # We "pool" the model by simply taking the hidden state corresponding
141 | # to the last token. this is unlike first token pooling, just switched the ordering of the data
142 | first_token_tensor = hidden_states[:, -1]
143 | pooled_output = self.dense(first_token_tensor)
144 | pooled_output = self.activation(pooled_output)
145 | return pooled_output
146 |
147 | class BertVAEPooler(nn.Module):
148 | def __init__(self, config):
149 | super(BertVAEPooler, self).__init__()
150 | self.dense = nn.Linear(config.hidden_size, config.poolingSize)
151 | self.activation = nn.Tanh()
152 |
153 | def forward(self, hidden_states):
154 | # We "pool" the model by simply taking the hidden state corresponding
155 | # to the first token.
156 | first_token_tensor = hidden_states[:, 0]
157 | pooled_output = self.dense(first_token_tensor)
158 | pooled_output = self.activation(pooled_output)
159 | return pooled_output
160 |
161 |
162 |
163 | class TBEHRT(Bert.modeling.BertPreTrainedModel):
164 | def __init__(self, config, num_labels):
165 | super(TBEHRT, self).__init__(config)
166 | if 'cuda' in config.device:
167 |
168 | self.device = int(config.device[-1])
169 |
170 | self.otherDevice = self.device
171 | else:
172 | self.device = 'cpu'
173 | self.otherDevice = self.device
174 |         self.bert = SimpleBEHRT(config)
175 |
176 | # self.bert = BertModel(config)
177 | self.treatmentC = nn.Linear(config.poolingSize, config.num_treatment)
178 | self.OutcomeT1_1 = nn.Linear(config.poolingSize, config.poolingSize)
179 | self.OutcomeT2_1 = nn.Linear(config.poolingSize, config.poolingSize)
180 | self.OutcomeT3_1 = nn.Linear(config.poolingSize, 2)
181 |
182 | self.OutcomeT1_2 = nn.Linear(config.poolingSize, config.poolingSize)
183 | self.OutcomeT2_2 = nn.Linear(config.poolingSize, config.poolingSize)
184 | self.OutcomeT3_2 = nn.Linear(config.poolingSize, 2)
185 | self.config = config
186 | self.dropoutVAE = nn.Dropout(config.hidden_dropout_prob)
187 |
188 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
189 | self.treatmentC.to(self.otherDevice)
190 |
191 | self.OutcomeT1_1.to(self.otherDevice)
192 | self.OutcomeT2_1.to(self.otherDevice)
193 |
194 | self.OutcomeT3_1.to(self.otherDevice)
195 |
196 |
197 | self.OutcomeT1_2.to(self.otherDevice)
198 | self.OutcomeT2_2.to(self.otherDevice)
199 |
200 | self.OutcomeT3_2.to(self.otherDevice)
201 |         self.gelu = nn.ELU()  # note: despite the name, this is an ELU activation
202 | self.num_labels = num_labels
203 | self.num_treatment = config.num_treatment
204 |         self.logS = nn.LogSoftmax(dim=1)  # explicit dim avoids the implicit-dim deprecation warning
205 |
206 | self.VAEpooler = BertVAEPooler(config)
207 | self.VAEpooler.to(self.otherDevice)
208 | self.VAE = VAE(config)
209 | self.VAE.to(self.otherDevice)
210 | self.treatmentW = 1.0
211 | self.MEM = False
212 | self.config = config
213 | if config.MEM is True:
214 | self.MEM = True
215 | print('turning on the MEM....')
216 | self.cls = Bert.modeling.BertOnlyMLMHead(config, self.bert.bert.embeddings.word_embeddings.weight)
217 | self.cls.to(self.otherDevice)
218 | self.apply(self.init_bert_weights)
219 | print("full init completed...")
220 |
221 | def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids = None, attention_mask=None, masked_lm_labels=None,
222 | outcomeT=None, treatmentCLabel=None, fullEval=False, vaelabel = None):
223 | batchs = input_ids.shape[0]
224 | embed4MLM, pooled_out = self.bert(input_ids, age_ids, seg_ids, posi_ids , year_ids, attention_mask,
225 | output_all_encoded_layers=False, fullmask = None)
226 |
227 | treatmentCLabel = treatmentCLabel.to(self.otherDevice)
228 | pooled_outVAE = self.VAEpooler(embed4MLM)
229 | pooled_outVAE = self.dropoutVAE(pooled_outVAE)
230 |
231 | outcomeT = outcomeT.to(self.otherDevice)
232 | outputVAE = self.VAE(pooled_outVAE, vaelabel)
233 |
234 | if self.MEM==True:
235 | prediction_scores = self.cls(embed4MLM[:,1:])
236 |
237 | if masked_lm_labels is not None:
238 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
239 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
240 |
241 | else:
242 | prediction_scores = torch.ones([batchs, self.config.vocab_size]).to(self.otherDevice)
243 | masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice)
244 |
245 | treatmentOut = self.treatmentC(pooled_out)
246 | tloss_fuct = nn.NLLLoss(reduction='mean')
247 | lossT = tloss_fuct(self.logS(treatmentOut).view(-1, self.config.num_treatment), treatmentCLabel.view(-1))
248 | pureSoft = nn.Softmax(dim=1)
249 |
250 | treatmentOut = pureSoft(treatmentOut)
251 |
252 | out1 = self.gelu(self.OutcomeT1_1(pooled_out))
253 | out2 = self.gelu(self.OutcomeT2_1(out1))
254 | logits0 = (self.OutcomeT3_1(out2))
255 |
256 | out12 = self.gelu(self.OutcomeT1_2(pooled_out))
257 | out22 = self.gelu(self.OutcomeT2_2(out12))
258 | logits1 = (self.OutcomeT3_2(out22))
259 |
260 |
261 |
262 | outcome1loss = nn.CrossEntropyLoss(reduction='none')
263 | outcome0loss = nn.CrossEntropyLoss(reduction='none')
264 | lossRaw0 = outcome0loss(logits0,outcomeT.squeeze(-1))
265 | lossRaw1 = outcome1loss(logits1, outcomeT.squeeze(-1))
266 |
267 |         trueLoss1 = torch.mean(lossRaw1 * treatmentCLabel.float().squeeze(-1).to(self.otherDevice))
268 |         trueLoss0 = torch.mean(lossRaw0 * (1 - treatmentCLabel.float().squeeze(-1)).to(self.otherDevice))
269 |
270 |         tloss = trueLoss0 + trueLoss1 + self.treatmentW * lossT
271 |
272 |
273 | outlog1 = pureSoft(logits1)[:,1]
274 | outlog0 = pureSoft(logits0)[:,1]
275 | outlogits = torch.cat((outlog0.view(-1, 1).unsqueeze(0), outlog1.view(-1, 1).unsqueeze(0)), dim=0 )
276 | fulTout = treatmentOut
277 | outputTreatIndex = treatmentCLabel
278 | outlabelsfull = outcomeT
279 |
280 |
281 | # vae loss
282 | vaeloss = self.VAE.loss_function(outputVAE)
283 | vaeloss_total = vaeloss['loss']
284 |         masked_lm_loss = masked_lm_loss + vaeloss_total
285 |         return masked_lm_loss, tloss, prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.contiguous().view(-1), fulTout, outputTreatIndex, outlogits, outlabelsfull, outputTreatIndex, 0, vaeloss
286 |
287 |
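# --- Illustrative sketch (not part of the original file) ---
# The model above returns `outlogits` of shape [2, batch, 1]: index 0 holds the
# first head's P(outcome=1 | T=0) and index 1 holds the second head's
# P(outcome=1 | T=1). A minimal sketch, assuming that convention, of turning
# these potential-outcome predictions into a naive plug-in average treatment
# effect (ATE) estimate:
def naive_plugin_ate(outlogits):
    """Mean difference of predicted potential outcomes: E[Y(1)] - E[Y(0)]."""
    y0 = outlogits[0].squeeze(-1)  # predicted P(Y=1 | T=0) per patient
    y1 = outlogits[1].squeeze(-1)  # predicted P(Y=1 | T=1) per patient
    return (y1 - y0).mean()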
288 | class TARNET_MEM(Bert.modeling.BertPreTrainedModel):
289 | def __init__(self, config, num_labels):
290 | super(TARNET_MEM, self).__init__(config)
291 |         if 'cuda' in config.device:
292 |             # assumes a single-digit device index, e.g. 'cuda:0'
293 |             self.device = int(config.device[-1])
294 |
295 |             self.otherDevice = self.device
296 | else:
297 | self.device = 'cpu'
298 | self.otherDevice = self.device
299 |         self.bert = SimpleBEHRT(config)
300 |
301 |
302 | self.treatmentC = nn.Linear(config.poolingSize, config.num_treatment)
303 | self.OutcomeT1_1 = nn.Linear(config.poolingSize, config.poolingSize)
304 | self.OutcomeT2_1 = nn.Linear(config.poolingSize, config.poolingSize)
305 | self.OutcomeT3_1 = nn.Linear(config.poolingSize, 2)
306 |
307 |
308 | self.OutcomeT1_2 = nn.Linear(config.poolingSize, config.poolingSize)
309 | self.OutcomeT2_2 = nn.Linear(config.poolingSize, config.poolingSize)
310 | self.OutcomeT3_2 = nn.Linear(config.poolingSize, 2)
311 | self.config = config
312 | self.dropoutVAE = nn.Dropout(config.hidden_dropout_prob)
313 |
314 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
315 | self.treatmentC.to(self.otherDevice)
316 |
317 | self.OutcomeT1_1.to(self.otherDevice)
318 | self.OutcomeT2_1.to(self.otherDevice)
319 |
320 | self.OutcomeT3_1.to(self.otherDevice)
321 |
322 |
323 | self.OutcomeT1_2.to(self.otherDevice)
324 | self.OutcomeT2_2.to(self.otherDevice)
325 |
326 | self.OutcomeT3_2.to(self.otherDevice)
327 |         self.gelu = nn.ELU()  # note: named "gelu" but this is an ELU activation
328 |         self.num_labels = num_labels
329 |         self.num_treatment = config.num_treatment
330 |         self.logS = nn.LogSoftmax(dim=1)
331 |
332 |
333 |
334 |
335 |
336 | self.VAEpooler = BertVAEPooler(config)
337 | self.VAEpooler.to(self.otherDevice)
338 | self.VAE = VAE(config)
339 | self.VAE.to(self.otherDevice)
340 | self.treatmentW = 1.0
341 | self.MEM = False
342 |
343 | if config.MEM is True:
344 | self.MEM = True
345 | print('turning on the MEM....')
346 | self.cls = Bert.modeling.BertOnlyMLMHead(config, self.bert.bert.embeddings.word_embeddings.weight)
347 | self.cls.to(self.otherDevice)
348 | self.apply(self.init_bert_weights)
349 | print("full init completed...")
350 |
351 |     def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None, attention_mask=None, masked_lm_labels=None,
352 |                 outcomeT=None, treatmentCLabel=None, fullEval=False, vaelabel=None):
353 | batchs = input_ids.shape[0]
354 | embed4MLM, pooled_out = self.bert(input_ids, age_ids, seg_ids, posi_ids , year_ids, attention_mask,
355 | output_all_encoded_layers=False, fullmask = None)
356 |
357 | treatmentCLabel = treatmentCLabel.to(self.otherDevice)
358 |
359 | pooled_outVAE = self.VAEpooler(embed4MLM)
360 | pooled_outVAE = self.dropoutVAE(pooled_outVAE)
361 |
362 | outcomeT = outcomeT.to(self.otherDevice)
363 | outputVAE = self.VAE(pooled_outVAE, vaelabel)
364 |
365 | masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice)
366 |
367 |         if self.MEM:
368 | prediction_scores = self.cls(embed4MLM[:,1:])
369 |
370 | if masked_lm_labels is not None:
371 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
372 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
373 |
374 | else:
375 | prediction_scores = torch.ones([batchs, self.config.vocab_size]).to(self.otherDevice)
376 | masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice)
377 |
378 |
379 | treatmentOut = self.treatmentC(pooled_out)
380 |         # unlike DRAGONNET below, TARNET adds no treatment-classification (propensity) loss term
381 | pureSoft = nn.Softmax(dim=1)
382 |
383 | treatmentOut = pureSoft(treatmentOut)
384 |
385 | out1 = self.gelu(self.OutcomeT1_1(pooled_out))
386 | out2 = self.gelu(self.OutcomeT2_1(out1))
387 | logits0 = (self.OutcomeT3_1(out2))
388 |
389 | out12 = self.gelu(self.OutcomeT1_2(pooled_out))
390 | out22 = self.gelu(self.OutcomeT2_2(out12))
391 | logits1 = (self.OutcomeT3_2(out22))
392 |
393 |
394 | outcome1loss = nn.CrossEntropyLoss(reduction='none')
395 | outcome0loss = nn.CrossEntropyLoss(reduction='none')
396 | lossRaw0 = outcome0loss(logits0,outcomeT.squeeze(-1))
397 | lossRaw1 = outcome1loss(logits1, outcomeT.squeeze(-1))
398 |
399 |         trueLoss1 = torch.mean(lossRaw1 * treatmentCLabel.float().squeeze(-1).to(self.otherDevice))
400 |         trueLoss0 = torch.mean(lossRaw0 * (1 - treatmentCLabel.float().squeeze(-1)).to(self.otherDevice))
401 |
402 |         tloss = trueLoss0 + trueLoss1
403 | outlog1 = pureSoft(logits1)[:,1]
404 |
405 | outlog0 = pureSoft(logits0)[:,1]
406 |
407 |
408 | outlogits = torch.cat((outlog0.view(-1, 1).unsqueeze(0), outlog1.view(-1, 1).unsqueeze(0)), dim=0 )
409 | fulTout = treatmentOut
410 |
411 | outputTreatIndex = treatmentCLabel
412 | outlabelsfull = outcomeT
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 | # vae loss
421 | vaeloss = self.VAE.loss_function(outputVAE)
422 | vaeloss_total = vaeloss['loss']
423 |         masked_lm_loss = masked_lm_loss + vaeloss_total
424 |         return masked_lm_loss, tloss, prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.contiguous().view(-1), fulTout, outputTreatIndex, outlogits, outlabelsfull, outputTreatIndex, 0, vaeloss
425 |
426 |
427 |
428 |
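# --- Illustrative sketch (not part of the original file) ---
# TARNET_MEM (above) and DRAGONNET (below) differ chiefly in the objective:
# TARNET trains each outcome head only on its own treatment arm, while
# DRAGONNET additionally fits the treatment classifier and adds its NLL,
# weighted by `treatmentW`, as a propensity regulariser on the shared
# representation. A minimal sketch of the two objectives, assuming per-example
# outcome losses `loss_raw0`/`loss_raw1` and a 0/1 treatment indicator `t`:
import torch

def tarnet_objective(loss_raw0, loss_raw1, t):
    # each head contributes only for the examples in its own treatment arm
    return torch.mean(loss_raw0 * (1 - t)) + torch.mean(loss_raw1 * t)

def dragonnet_objective(loss_raw0, loss_raw1, t, propensity_nll, treatment_w=1.0):
    # DRAGONNET adds a weighted propensity (treatment-classification) loss
    return tarnet_objective(loss_raw0, loss_raw1, t) + treatment_w * propensity_nll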
429 | class SimpleBEHRT(Bert.modeling.BertPreTrainedModel):
430 | def __init__(self, config, num_labels=1):
431 | super(SimpleBEHRT, self).__init__(config)
432 |
433 | self.bert = BEHRTBASE(config)
434 | if 'cuda' in config.device:
435 |
436 | self.device = int(config.device[-1])
437 |
438 | self.otherDevice = self.device
439 | else:
440 | self.device = 'cpu'
441 | self.otherDevice = self.device
442 |
443 |
444 | self.bert.to(self.device)
445 |
446 | self.num_labels = num_labels
447 |
448 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
449 |
450 | self.pooler = BertLastPooler(config)
451 | self.pooler.to(self.otherDevice)
452 |
453 |
454 | self.config = config
455 | self.apply(self.init_bert_weights)
456 |
457 |     def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None, attention_mask=None, output_all_encoded_layers=False, fullmask=None):
458 |         # BEHRTBASE fills slots 0, 2 and 3 with placeholders; only the encoder output is used here
459 |         sequence_output, embedding_outputLSTM, embedding_outputLSTM2, attention_maskLSTM = self.bert(input_ids, age_ids, seg_ids, posi_ids, year_ids, attention_mask, fullmask=fullmask)
460 |
461 | pooled_out = self.pooler(embedding_outputLSTM)
462 | pooled_out = self.dropout(pooled_out)
463 | return embedding_outputLSTM, pooled_out
464 |
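# --- Illustrative note (not part of the original file) ---
# SimpleBEHRT returns a pair: the last encoder layer's token embeddings
# (shape [batch, seq_len, hidden_size]) and a pooled summary vector
# (shape [batch, poolingSize]) produced by BertLastPooler plus dropout.
# The heads in the models above consume the pooled vector, while the token
# embeddings feed the MLM head and the VAE pooler.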
465 | class BEHRTBASE(Bert.modeling.BertPreTrainedModel):
466 | def __init__(self, config):
467 | super(BEHRTBASE, self).__init__(config)
468 |
469 | self.embeddings = BertEmbeddingsUnsup(config=config)
470 |
471 | self.encoder = Bert.modeling.BertEncoder(config=config)
472 | self.config = config
473 |
474 | self.apply(self.init_bert_weights)
475 |
476 |     def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None, attention_mask=None, fullmask=None):
477 | if attention_mask is None:
478 | attention_mask = torch.ones_like(input_ids)
479 |         if age_ids is None:
480 |             age_ids = torch.zeros_like(input_ids)
481 |         if seg_ids is None:  # seg_ids defaulted too, mirroring the age/position handling
482 |             seg_ids = torch.zeros_like(input_ids)
483 |         if posi_ids is None:
484 |             posi_ids = torch.zeros_like(input_ids)
485 | # We create a 3D attention mask from a 2D tensor mask.
486 | # Sizes are [batch_size, 1, 1, to_seq_length]
487 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
488 | # this attention mask is more simple than the triangular masking of causal attention
489 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
490 |         encodermask = attention_mask
491 |         encodermask = encodermask.unsqueeze(1).unsqueeze(2)
492 |         encodermask = encodermask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
493 |         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
494 |         # masked positions, this operation creates a tensor which is 0.0 for
495 |         # positions we want to attend and -10000.0 for masked positions.
496 |         # Since we are adding it to the raw scores before the softmax, this is
497 |         # effectively the same as removing them entirely.
498 |         encodermask = (1.0 - encodermask) * -10000.0
499 |
500 |
501 |
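        # Worked example (illustrative): a 2D mask [[1, 1, 0]] becomes, after
        # unsqueezing to shape [1, 1, 1, 3] and applying (1.0 - mask) * -10000.0,
        # [[[[0.0, 0.0, -10000.0]]]] -- the padded position is pushed to
        # effectively zero attention weight after the softmax.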
502 |
503 | embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids, year_ids)
504 |
505 | encoded_layers = self.encoder(embedding_output,
506 | encodermask,
507 | output_all_encoded_layers=False)
508 |         sequenceOut = encoded_layers[-1]
509 |         return [0], sequenceOut, [0], [0]  # placeholder slots keep the 4-tuple interface; only the sequence output carries information
510 |
511 |
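# --- Illustrative sketch (not part of the original file) ---
# BEHRTBASE feeds five parallel id sequences into BertEmbeddingsUnsup. A
# minimal sketch of the BEHRT-style embedding composition this assumes (one
# lookup table per field, summed position-wise; the real BertEmbeddingsUnsup
# may differ in details such as LayerNorm and dropout). The class below is
# hypothetical and for illustration only:
import torch.nn as nn

class SummedFieldEmbeddings(nn.Module):
    def __init__(self, vocab_sizes, hidden_size):
        super().__init__()
        # hypothetical field order: token, age, segment, position, year
        self.tables = nn.ModuleList([nn.Embedding(v, hidden_size) for v in vocab_sizes])

    def forward(self, *id_seqs):
        # sum one embedded tensor per input field, each [batch, seq, hidden]
        return sum(table(ids) for table, ids in zip(self.tables, id_seqs))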
512 | class DRAGONNET(Bert.modeling.BertPreTrainedModel):
513 | def __init__(self, config, num_labels):
514 | super(DRAGONNET, self).__init__(config)
515 | if 'cuda' in config.device:
516 |
517 | self.device = int(config.device[-1])
518 |
519 | self.otherDevice = self.device
520 | else:
521 | self.device = 'cpu'
522 | self.otherDevice = self.device
523 |         self.bert = SimpleBEHRT(config)
524 |
525 | self.treatmentC = nn.Linear(config.poolingSize, config.num_treatment)
526 | self.OutcomeT1_1 = nn.Linear(config.poolingSize, config.poolingSize)
527 | self.OutcomeT2_1 = nn.Linear(config.poolingSize, config.poolingSize)
528 | self.OutcomeT3_1 = nn.Linear(config.poolingSize, 2)
529 |
530 |
531 | self.OutcomeT1_2 = nn.Linear(config.poolingSize, config.poolingSize)
532 | self.OutcomeT2_2 = nn.Linear(config.poolingSize, config.poolingSize)
533 | self.OutcomeT3_2 = nn.Linear(config.poolingSize, 2)
534 | self.config = config
535 | self.dropoutVAE = nn.Dropout(config.hidden_dropout_prob)
536 |
537 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
538 | self.treatmentC.to(self.otherDevice)
539 |
540 | self.OutcomeT1_1.to(self.otherDevice)
541 | self.OutcomeT2_1.to(self.otherDevice)
542 |
543 | self.OutcomeT3_1.to(self.otherDevice)
544 |
545 |
546 | self.OutcomeT1_2.to(self.otherDevice)
547 | self.OutcomeT2_2.to(self.otherDevice)
548 |
549 | self.OutcomeT3_2.to(self.otherDevice)
550 |         self.gelu = nn.ELU()  # note: named "gelu" but this is an ELU activation
551 |         self.num_labels = num_labels
552 |         self.num_treatment = config.num_treatment
553 |         self.logS = nn.LogSoftmax(dim=1)
554 |
555 |
556 |
557 | self.VAEpooler = BertVAEPooler(config)
558 | self.VAEpooler.to(self.otherDevice)
559 | self.VAE = VAE(config)
560 | self.VAE.to(self.otherDevice)
561 | self.treatmentW = 1.0
562 | self.MEM = False
563 |
564 | if config.MEM is True:
565 | self.MEM = True
566 | print('turning on the MEM....')
567 | self.cls = Bert.modeling.BertOnlyMLMHead(config, self.bert.bert.embeddings.word_embeddings.weight)
568 | self.cls.to(self.otherDevice)
569 | self.apply(self.init_bert_weights)
570 | print("full init completed...")
571 |
572 |     def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None, attention_mask=None, masked_lm_labels=None,
573 |                 outcomeT=None, treatmentCLabel=None, fullEval=False, vaelabel=None):
574 | batchs = input_ids.shape[0]
575 | embed4MLM, pooled_out = self.bert(input_ids, age_ids, seg_ids, posi_ids , year_ids, attention_mask,
576 | output_all_encoded_layers=False, fullmask = None)
577 |
578 | treatmentCLabel = treatmentCLabel.to(self.otherDevice)
579 | #
580 | pooled_outVAE = self.VAEpooler(embed4MLM)
581 | pooled_outVAE = self.dropoutVAE(pooled_outVAE)
582 |
583 | outcomeT = outcomeT.to(self.otherDevice)
584 | outputVAE = self.VAE(pooled_outVAE, vaelabel)
585 |
586 |         masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice)  # ensure the loss is defined even when MLM labels are absent
587 |         if self.MEM:
588 |             prediction_scores = self.cls(embed4MLM[:, 1:])
589 |             if masked_lm_labels is not None:
590 |                 loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
591 |                 masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
592 |
593 |         else:
594 |             prediction_scores = torch.ones([batchs, self.config.vocab_size]).to(self.otherDevice)
595 |             masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice)
596 |
597 | treatmentOut = self.treatmentC(pooled_out)
598 |         tloss_fct = nn.NLLLoss(reduction='mean')
599 |         lossT = tloss_fct(self.logS(treatmentOut).view(-1, self.config.num_treatment), treatmentCLabel.view(-1))
600 | pureSoft = nn.Softmax(dim=1)
601 |
602 | treatmentOut = pureSoft(treatmentOut)
603 |
604 | out1 = self.gelu(self.OutcomeT1_1(pooled_out))
605 | out2 = self.gelu(self.OutcomeT2_1(out1))
606 | logits0 = (self.OutcomeT3_1(out2))
607 |
608 | out12 = self.gelu(self.OutcomeT1_2(pooled_out))
609 | out22 = self.gelu(self.OutcomeT2_2(out12))
610 | logits1 = (self.OutcomeT3_2(out22))
611 | outcome1loss = nn.CrossEntropyLoss(reduction='none')
612 | outcome0loss = nn.CrossEntropyLoss(reduction='none')
613 | lossRaw0 = outcome0loss(logits0,outcomeT.squeeze(-1))
614 | lossRaw1 = outcome1loss(logits1, outcomeT.squeeze(-1))
615 |         trueLoss1 = torch.mean(lossRaw1 * treatmentCLabel.float().squeeze(-1).to(self.otherDevice))
616 |         trueLoss0 = torch.mean(lossRaw0 * (1 - treatmentCLabel.float().squeeze(-1)).to(self.otherDevice))
617 |
618 |         tloss = trueLoss0 + trueLoss1 + self.treatmentW * lossT
619 | outlog1 = pureSoft(logits1)[:,1]
620 |
621 | outlog0 = pureSoft(logits0)[:,1]
622 |
623 |
624 | outlogits = torch.cat((outlog0.view(-1, 1).unsqueeze(0), outlog1.view(-1, 1).unsqueeze(0)), dim=0 )
625 | fulTout = treatmentOut
626 |
627 | outputTreatIndex = treatmentCLabel
628 | outlabelsfull = outcomeT
629 |
630 |
631 |
632 | # vae loss
633 | vaeloss = self.VAE.loss_function(outputVAE)
634 | vaeloss_total = vaeloss['loss']
635 |         masked_lm_loss = masked_lm_loss + vaeloss_total
636 |         return masked_lm_loss, tloss, prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.contiguous().view(-1), fulTout, outputTreatIndex, outlogits, outlabelsfull, outputTreatIndex, 0, vaeloss
637 |
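# --- Illustrative sketch (not part of the original file) ---
# All three models above (the VAE-augmented head, TARNET_MEM and DRAGONNET)
# share the same 11-element return signature, so a training step can unpack
# them interchangeably. A minimal sketch, assuming `model` is any of the three
# and `batch` is a dict supplying the tensors named in forward(); the equal
# weighting of the two loss terms is an assumption, not the authors' recipe:
def training_step(model, optimizer, batch):
    (mlm_plus_vae_loss, treat_outcome_loss, _scores, _mlm_labels,
     _t_probs, _t_idx, _outlogits, _y_labels, _t_idx2, _unused, _vae) = model(**batch)
    loss = mlm_plus_vae_loss + treat_outcome_loss  # assumed equal weighting
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()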
--------------------------------------------------------------------------------