├── examples
    ├── test
    ├── __init__.py
    ├── save_models
    │   ├── tt
    │   ├── TBEHRT_Test__CUT0.bin
    │   ├── TBEHRT_Test__CUT1.bin
    │   ├── TBEHRT_Test__CUT2.bin
    │   ├── TBEHRT_Test__CUT3.bin
    │   └── TBEHRT_Test__CUT4.bin
    ├── test.parquet
    ├── TBEHRT_Test__CUT0.npz
    ├── TBEHRT_Test__CUT1.npz
    ├── TBEHRT_Test__CUT2.npz
    ├── TBEHRT_Test__CUT3.npz
    ├── TBEHRT_Test__CUT4.npz
    └── CVTMLE_example.ipynb
├── screenshot.png
├── .idea
    ├── misc.xml
    ├── vcs.xml
    ├── .gitignore
    ├── inspectionProfiles
    │   └── profiles_settings.xml
    ├── modules.xml
    └── Targeted_BEHRT.iml
├── LICENSE
├── README.md
├── pytorch_pretrained_bert
    ├── __init__.py
    ├── convert_tf_checkpoint_to_pytorch.py
    ├── convert_gpt2_checkpoint_to_pytorch.py
    ├── convert_openai_checkpoint_to_pytorch.py
    ├── optimizer.py
    ├── module.py
    ├── __main__.py
    ├── convert_transfo_xl_checkpoint_to_pytorch.py
    ├── optimization_openai.py
    ├── optimization.py
    ├── tokenization_gpt2.py
    ├── file_utils.py
    ├── tokenization_openai.py
    ├── modeling_transfo_xl_utilities.py
    ├── tokenization.py
    └── tokenization_transfo_xl.py
└── src
    ├── data.py
    ├── vae.py
    ├── CV_TMLE.py
    ├── utils.py
    └── model.py
/examples/test: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------
/examples/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/examples/save_models/tt: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------
/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/screenshot.png --------------------------------------------------------------------------------
/examples/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/test.parquet --------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT0.npz --------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT1.npz --------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT2.npz --------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT3.npz --------------------------------------------------------------------------------
/examples/TBEHRT_Test__CUT4.npz: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/TBEHRT_Test__CUT4.npz -------------------------------------------------------------------------------- /examples/save_models/TBEHRT_Test__CUT0.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT0.bin -------------------------------------------------------------------------------- /examples/save_models/TBEHRT_Test__CUT1.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT1.bin -------------------------------------------------------------------------------- /examples/save_models/TBEHRT_Test__CUT2.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT2.bin -------------------------------------------------------------------------------- /examples/save_models/TBEHRT_Test__CUT3.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT3.bin -------------------------------------------------------------------------------- /examples/save_models/TBEHRT_Test__CUT4.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmedicine/Targeted-BEHRT/HEAD/examples/save_models/TBEHRT_Test__CUT4.bin -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/Targeted_BEHRT.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shishir Rao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the 
Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Targeted BEHRT 2 | Repository for publication: Targeted-BEHRT: Deep Learning for Observational Causal Inference on Longitudinal Electronic Health Records
3 | IEEE Transactions on Neural Networks and Learning Systems; Special Issue on Causality
4 | https://ieeexplore.ieee.org/document/9804397/
5 | DOI: 10.1109/TNNLS.2022.3183864.
6 | 7 | ![Screenshot](screenshot.png) 8 | 9 | How to use:
10 | In "examples" folder, run the "run_TBEHRT.ipynb" file. A test.csv file is provided to test/play and demonstrate how the vocabulary/year/age/etc function (please read full paper linked above for further methodological details).
11 | Furthermore, to run the CV-TMLE estimator, run the "CVTMLE_example.ipynb" notebook in the examples folder. A set of synthetic fold data is provided to experiment with; it demonstrates how the CV-TMLE algorithm works (please read the CV-TMLE methods publication for further details).
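A minimal sketch of the estimator call, mirroring the notebook (each .npz file holds the nuisance estimates for one cross-validation fold):

```python
from src.CV_TMLE import CVTMLE

# one npz file of fold-level estimates per cross-validation fold
fold_files = ['TBEHRT_Test__CUT{}.npz'.format(i) for i in range(5)]

# patients with propensity estimates outside [0.03, 0.97] are excluded
# before the targeting step
tmle = CVTMLE(fromFolds=fold_files, truncate_level=0.03)
rr, ci_lower, ci_upper = tmle.run_tmle_binary()  # risk ratio with 95% CI bounds
```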
12 | 13 | The files in the "src" folder contain the model and data-handling packages, in addition to other necessary VAE-related files and helper functions. 14 | 15 | Requirements:
16 | torch >1.6.0
17 | numpy 1.19.2
18 | sklearn 0.23.2
19 | pandas 1.1.3
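The pinned versions can be installed from PyPI with, for example, `pip install "torch>1.6.0" numpy==1.19.2 scikit-learn==0.23.2 pandas==1.1.3` (note that sklearn is published on PyPI as scikit-learn).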
20 |
21 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.1" 2 | # from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | # from .tokenization_openai import OpenAIGPTTokenizer 4 | # from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | # from .tokenization_gpt2 import GPT2Tokenizer 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | 13 | 14 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 15 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 16 | load_tf_weights_in_openai_gpt) 17 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 18 | load_tf_weights_in_transfo_xl) 19 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 20 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 21 | load_tf_weights_in_gpt2) 22 | 23 | from .optimization import BertAdam 24 | from .optimization_openai import OpenAIAdam 25 | 26 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path 27 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path the TensorFlow checkpoint path.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /examples/CVTMLE_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "sys.path.insert(0, '/home/rnshishir/deepmed/TBEHRT_pl/')\n", 12 | "\n", 13 | "import scipy\n", 14 | "import pandas as pd\n", 15 | "import numpy as np\n", 16 | "from src.CV_TMLE import *" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "# CV TMLE tutorial" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 4, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": "stdout", 33 | "output_type": "stream", 34 | "text": [ 35 | "running CV-TMLE for binary outcomes...\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "# folds in the npz format\n", 41 | "foldNPZ = ['TBEHRT_Test__CUT0.npz', 'TBEHRT_Test__CUT1.npz', 'TBEHRT_Test__CUT2.npz', 'TBEHRT_Test__CUT3.npz', 'TBEHRT_Test__CUT4.npz' ]\n", 42 | "\n", 43 | "# cvtmle runner \n", 44 | "TMLErun = CVTMLE(fromFolds=foldNPZ,truncate_level=0.03 )\n", 45 | "\n", 46 | "# estiamte the risk ratio for binary outcome\n", 47 | "est = TMLErun.run_tmle_binary()" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 5, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "text/plain": [ 58 | "[0.10878099048487283, 5.2854239704810925e-08, 223885.61366048577]" 59 | ] 60 | }, 61 | "execution_count": 5, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "est\n", 68 | "# prints estimate and lower and 
upper conf interval bounds" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 7, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "data = pd.read_parquet('test.parquet')" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 11, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# raw\n", 87 | "# data[data.explabel ==1].label.mean()/data[data.explabel ==0].label.mean()" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "real3", 101 | "language": "python", 102 | "name": "py3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.6.8" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 4 119 | } 120 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimizer.py: -------------------------------------------------------------------------------- 1 | import pytorch_pretrained_bert as Bert 2 | 3 | def VAEadam(params, config=None): 4 | if config is None: 5 | config = { 6 | 'lr': 3e-5, 7 | 'warmup_proportion': 0.1 8 | } 9 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'Eps'] 10 | vae= ['VAE'] 11 | no_decayfull = no_decay+vae 12 | # print( {'params': [n for n, p in params if not any(nd in n for nd in no_decayfull)], 'weight_decay': 0.01, 'lr': config['lr']}, 13 | # {'params': [n for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': config['lr']}, 14 | # {'params': [n for n, p in params if any(nd in n for nd in vae)], 'weight_decay': 0.0, 'lr':1e-3 } 15 | # ) 16 | optimizer_grouped_parameters = [ 17 | {'params': [p for n, p in params if not any(nd in n for nd in no_decayfull)], 'weight_decay': 0.01, 'lr': config['lr']}, 18 | {'params': [p for n, p in params if (any(nd in n for nd in no_decay) and 'VAE' not in n)], 'weight_decay': 0.0, 'lr': config['lr']}, 19 | {'params': [p for n, p in params if any(nd in n for nd in vae)], 'weight_decay': 0.0, 'lr':1e-3 } 20 | 21 | ] 22 | 23 | optim = Bert.optimization.BertAdam(optimizer_grouped_parameters, 24 | warmup=config['warmup_proportion']) 25 | return optim 26 | 27 | 28 | def adam(params, config=None): 29 | if config is None: 30 | config = { 31 | 'lr': 3e-5, 32 | 'warmup_proportion': 0.1 33 | } 34 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'Eps','VAE'] 35 | optimizer_grouped_parameters = [ 36 | {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 37 | {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 38 | ] 39 | 40 | optim = Bert.optimization.BertAdam(optimizer_grouped_parameters, 41 | lr=config['lr'], 42 | warmup=config['warmup_proportion']) 43 | return optim 44 | 45 | def GPadam(params, gpLR, config=None): 46 | if config is None: 47 | config = { 48 | 'lr': 3e-5, 49 | 'warmup_proportion': 0.1 50 | } 51 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight','Eps'] 52 | gp = ['GP'] 53 | 54 | 55 | 56 | optimizer_grouped_parameters = [ 57 | {'params': [p for n, p in params if not any(nd in n for nd in no_decay) and not any(nd in n for nd in gp)], 'weight_decay': 0.01 , 'lr': config['lr'], 'warmup_proportion': 0.1}, 58 | {'params': [p for n, p in params if any(nd in n for nd in no_decay) and not any(nd in n for nd in gp)], 'weight_decay': 0.0, 'lr': config['lr'], 'warmup_proportion': 0.1}, 59 | {'params': [p for n, p in params if any(nd in n for nd in gp)], 'lr': gpLR} 60 | ] 61 | 62 | print([ 63 | {'params': [n for n, p in params if not any(nd in n for nd in no_decay) and not any(nd in n for nd in gp)], 'weight_decay': 0.01 , 'lr': config['lr'], 'warmup_proportion': 0.1}, 64 | {'params': [n for n, p in params if any(nd in n for nd in no_decay) and not any(nd in n for nd in gp)], 'weight_decay': 0.0, 'lr': config['lr'], 'warmup_proportion': 0.1}, 65 | {'params': [n for n, p in params if any(nd in n for nd in gp)], 'lr': gpLR} 66 | ]) 67 | optim = Bert.optimization.BertAdam(optimizer_grouped_parameters) 68 | return optim 69 | 
-------------------------------------------------------------------------------- /pytorch_pretrained_bert/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_pretrained_bert as Bert 4 | 5 | 6 | def sequence_mask(sequence_length, max_len=None, device=None): 7 | sequence_length = torch.tensor(sequence_length) 8 | max_len = torch.tensor(max_len) 9 | if max_len is None: 10 | max_len = sequence_length.data.max() 11 | batch_size = sequence_length.size(0) 12 | seq_range = torch.arange(0, max_len).long() 13 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 14 | 15 | if sequence_length.is_cuda: 16 | seq_range_expand = seq_range_expand.to(device) 17 | seq_length_expand = (sequence_length.unsqueeze(1).expand_as(seq_range_expand)) 18 | mask= seq_range_expand < seq_length_expand 19 | return mask.detach().long() 20 | 21 | 22 | class BertEmbeddings(nn.Module): 23 | """Construct the embeddings from word, segment, age 24 | """ 25 | 26 | def __init__(self, config): 27 | super(BertEmbeddings, self).__init__() 28 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 29 | self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size) 30 | self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size) 31 | 32 | self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12) 33 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 34 | 35 | def forward(self, word_ids, age_ids=None, seg_ids=None): 36 | if seg_ids is None: 37 | seg_ids = torch.zeros_like(word_ids) 38 | if age_ids is None: 39 | age_ids = torch.zeros_like(word_ids) 40 | 41 | word_embed = self.word_embeddings(word_ids) 42 | segment_embed = self.segment_embeddings(seg_ids) 43 | age_embed = self.age_embeddings(age_ids) 44 | 45 | embeddings = word_embed + segment_embed + age_embed 46 | embeddings = self.LayerNorm(embeddings) 47 | embeddings = self.dropout(embeddings) 48 | return embeddings 49 | 50 | 51 | class BertModel(Bert.modeling.BertPreTrainedModel): 52 | def __init__(self, config): 53 | super(BertModel, self).__init__(config) 54 | self.embeddings = BertEmbeddings(config=config) 55 | self.encoder = Bert.modeling.BertEncoder(config=config) 56 | self.pooler = Bert.modeling.BertPooler(config) 57 | self.apply(self.init_bert_weights) 58 | 59 | def forward(self, input_ids, age_ids=None, seg_ids=None, attention_mask=None, output_all_encoded_layers=True): 60 | if attention_mask is None: 61 | attention_mask = torch.ones_like(input_ids) 62 | if age_ids is None: 63 | age_ids = torch.zeros_like(input_ids) 64 | if seg_ids is None: 65 | seg_ids = torch.zeros_like(input_ids) 66 | 67 | # We create a 3D attention mask from a 2D tensor mask. 68 | # Sizes are [batch_size, 1, 1, to_seq_length] 69 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 70 | # this attention mask is more simple than the triangular masking of causal attention 71 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 72 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 73 | 74 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 75 | # masked positions, this operation will create a tensor which is 0.0 for 76 | # positions we want to attend and -10000.0 for masked positions. 77 | # Since we are adding it to the raw scores before the softmax, this is 78 | # effectively the same as removing these entirely. 
79 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 80 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 81 | 82 | embedding_output = self.embeddings(input_ids, age_ids, seg_ids) 83 | encoded_layers = self.encoder(embedding_output, 84 | extended_attention_mask, 85 | output_all_encoded_layers=output_all_encoded_layers) 86 | sequence_output = encoded_layers[-1] 87 | pooled_output = self.pooler(sequence_output) 88 | if not output_all_encoded_layers: 89 | encoded_layers = encoded_layers[-1] 90 | return encoded_layers, pooled_output -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. 
Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | import torch 3 | from src.utils import * 4 | # data for var autoencoder deep unsup learning with tbehrt 5 | 6 | 7 | class TBEHRT_data_formation(Dataset): 8 | def __init__(self, token2idx, dataframe, code= 'code', age = 'age', year = 'year' , static= 'static' , max_len=1000,expColumn='explabel', outcomeColumn='label', max_age=110, yvocab=None, list2avoid=None, MEM=True): 9 | """ 10 | The dataset class for the pytorch coded model, Targeted BEHRT 11 | 12 | token2idx - the dict that maps tokens in EHR to numbers /index 13 | dataframe - the pandas dataframe that has the code,age,year, and any static columns 14 | code - name of code column 15 | age - name of age column 16 | year - name of year column 17 | static - name of static column 18 | max_len - length of sequence 19 | yvocab - the year vocab for the year based sequence of variables 20 | expColumn - the exposure column for dataframe 21 | outcomeColumn - the outcome column 22 | MEM - the masked EHR modelling flag for unsupervised learning 23 | list2avoid - list of tokens /diseases to not include in the MEM masking procedure 24 | 25 | """ 26 | 27 | if list2avoid is None: 28 | self.acceptableVoc = token2idx 29 | else: 30 | self.acceptableVoc = {x: y for x, y in token2idx.items() if x not in list2avoid} 31 | print("old Vocab size: ", len(token2idx), ", and new Vocab size: ", len(self.acceptableVoc)) 32 | self.vocab = token2idx 33 | self.max_len = max_len 34 | self.code = dataframe[code] 35 | self.age = dataframe[age] 36 | self.year = dataframe[year] 37 | if outcomeColumn is None: 38 | self.label = dataframe.deathLabel 39 | else: 40 | self.label = dataframe[outcomeColumn] 41 | self.age2idx, _ = age_vocab(110, year, symbol=None) 42 | 43 | if expColumn is None: 44 | self.treatmentLabel = dataframe.diseaseLabel 45 | else: 46 | self.treatmentLabel = dataframe[expColumn] 47 | self.year2idx = yvocab 48 | self.codeS = dataframe[static] 49 | self.MEM = MEM 50 | def __getitem__(self, index): 51 | """ 52 | return: age, code, position, segmentation, mask, label 53 | """ 54 | 55 | # extract data 56 | 57 | age = self.age[index] 58 | 59 | 
code = self.code[index] 60 | year = self.year[index] 61 | 62 | age = age[(-self.max_len + 1):] 63 | code = code[(-self.max_len + 1):] 64 | year = year[(-self.max_len + 1):] 65 | 66 | 67 | treatmentOutcome = torch.LongTensor([self.treatmentLabel[index]]) 68 | 69 | # avoid data cut with first element to be 'SEP' 70 | labelOutcome = self.label[index] 71 | 72 | 73 | # moved CLS to end as opposed to beginning. 74 | code[-1] = 'CLS' 75 | 76 | mask = np.ones(self.max_len) 77 | mask[:-len(code)] = 0 78 | mask = np.append(np.array([1]), mask) 79 | 80 | 81 | tokensReal, code2 = code2index(code, self.vocab) 82 | # pad age sequence and code sequence 83 | year = seq_padding_reverse(year, self.max_len, token2idx=self.year2idx) 84 | 85 | age = seq_padding_reverse(age, self.max_len, token2idx=self.age2idx) 86 | 87 | if self.MEM == False: 88 | tokens, codeMLM, labelMLM = nonMASK(code, self.vocab) 89 | else: 90 | tokens, codeMLM, labelMLM = randommaskreal(code, self.acceptableVoc) 91 | 92 | # get position code and segment code 93 | tokens = seq_padding_reverse(tokens, self.max_len) 94 | position = position_idx(tokens) 95 | segment = index_seg(tokens) 96 | 97 | code2 = seq_padding_reverse(code2, self.max_len, symbol=self.vocab['PAD']) 98 | 99 | codeMLM = seq_padding_reverse(codeMLM, self.max_len, symbol=self.vocab['PAD']) 100 | labelMLM = seq_padding_reverse(labelMLM, self.max_len, symbol=-1) 101 | 102 | outCodeS = [int(xx) for xx in self.codeS[index]] 103 | fixedcovar = np.array(outCodeS ) 104 | labelcovar = np.array(([-1] * len(outCodeS)) + [-1, -1]) 105 | if self.MEM == True: 106 | fixedcovar, labelcovar = covarUnsupMaker(fixedcovar) 107 | code2 = np.append(fixedcovar, code2) 108 | codeMLM = np.append(fixedcovar, codeMLM) 109 | 110 | 111 | 112 | # code2 is the fixed static covariates while the codeMLM are the longutidunal one 113 | return torch.LongTensor(age), torch.LongTensor(code2), torch.LongTensor(codeMLM), torch.LongTensor( 114 | position), torch.LongTensor(segment), torch.LongTensor(year), \ 115 | torch.LongTensor(mask), torch.LongTensor(labelMLM), torch.LongTensor( 116 | [labelOutcome]), treatmentOutcome, torch.LongTensor(labelcovar) 117 | 118 | 119 | def __len__(self): 120 | return len(self.code) 121 | 122 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder 
to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. \n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /src/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_pretrained_bert as Bert 4 | 5 | # borrowed boilerplate vae code from https://github.com/AntixK/PyTorch-VAE 6 | 7 | 8 | class VAE(Bert.modeling.BertPreTrainedModel): 9 | def __init__(self, config): 10 | super(VAE, self).__init__(config) 11 | 12 | self.unsuplist = config.unsupSize 13 | 14 | 15 | 16 | 17 | vaelatentdim = config.vaelatentdim 18 | vaeinchannels = config.vaeinchannels 19 | 20 | modules = [] 21 | vaehidden = [config.poolingSize] 22 | self.linearFC = nn.Linear(config.hidden_size, config.poolingSize) 23 | self.activ = nn.ReLU() 24 | 25 | # Build Encoder 26 | self.fc_mu = nn.Linear(vaehidden[-1], vaelatentdim) 27 | self.fc_var = nn.Linear(vaehidden[-1], vaelatentdim) 28 | 29 | 30 | # Build Decoder 31 | modules = [] 32 | 33 | self.decoder1 = nn.Linear(vaelatentdim, vaehidden[-1]) 34 | self.decoder2 = nn.Linear(vaehidden[-1],int( vaehidden[-1])) 35 | 36 | self.logSoftmax = nn.LogSoftmax(dim=1) 37 | 38 | self.linearOut = nn.ModuleList([nn.Linear (int( vaehidden[-1]), el[0]) for el in self.unsuplist]) 39 | self.BetaD = config.BetaD 40 | 41 | self.apply(self.init_bert_weights) 42 | 43 | def encode(self, input: torch.Tensor) : 44 | """ 45 | Encodes the input by passing through the encoder network 46 | and returns the latent codes. 47 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 48 | :return: (Tensor) List of latent codes 49 | """ 50 | # result = self.activ (self.linearFC(input)) 51 | 52 | mu = self.fc_mu(input) 53 | log_var = self.fc_var(input) 54 | 55 | return [mu, log_var] 56 | 57 | def decode(self, z: torch.Tensor) -> torch.Tensor: 58 | """ 59 | Maps the given latent codes 60 | onto the image space. 61 | :param z: (Tensor) [B x D] 62 | :return: (Tensor) [B x C x H x W] 63 | """ 64 | result = self.activ(self.decoder1(z)) 65 | result = self.activ(self.decoder2(result)) 66 | outs = [] 67 | 68 | 69 | for outputiter , linoutnetwork in enumerate(self.linearOut): 70 | resout = self.logSoftmax(linoutnetwork(result)) 71 | outs.append(resout) 72 | 73 | outs = torch.cat((outs), dim=1) 74 | return outs 75 | 76 | def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: 77 | """ 78 | Reparameterization trick to sample from N(mu, var) from 79 | N(0,1). 
80 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 81 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 82 | :return: (Tensor) [B x D] 83 | """ 84 | std = torch.exp(0.5 * logvar) 85 | eps = torch.randn_like(std) 86 | return eps * std + mu 87 | 88 | def forward(self, input: torch.Tensor, label: torch.Tensor): 89 | 90 | if self.BetaD==False: 91 | mu, log_var = self.encode(input) 92 | z = self.reparameterize(mu, log_var) 93 | return [self.decode(z), label, mu, log_var] 94 | else: 95 | mu, log_var = self.encode(input) 96 | z = self.reparameterize(mu, log_var) 97 | return [self.decode(z), label, mu, log_var] 98 | def loss_function(self,dictout) -> dict: 99 | """ 100 | Computes the VAE loss function. 101 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 102 | :param args: 103 | :param kwargs: 104 | :return: 105 | """ 106 | recons = dictout[0].transpose(1,0) 107 | input = dictout[1].transpose(1,0) 108 | 109 | mu = dictout[2] 110 | log_var = dictout[3] 111 | if self.BetaD==False: 112 | 113 | kld_weight = self.config.klpar # Account for the minibatch samples from the dataset 114 | reconsloss = 0 115 | startindx = 0 116 | 117 | outs = [] 118 | labs = [] 119 | for outputiter , output in enumerate(self.unsuplist): 120 | elementssize = output[0] 121 | chunkrecons = recons[startindx:startindx+elementssize].transpose(1,0) 122 | labels= input[outputiter] 123 | lossF = nn.NLLLoss(reduction='none', ignore_index=-1) 124 | temploss = lossF(chunkrecons,labels).sum() 125 | reconsloss =reconsloss+ temploss 126 | 127 | outs.append(chunkrecons) 128 | labs.append(labels) 129 | startindx = startindx+elementssize 130 | 131 | 132 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 133 | 134 | loss = (reconsloss + kld_weight * kld_loss)/len(dictout[0]) 135 | 136 | if self.config.klpar<1: 137 | self.config.klpar = self.config.klpar + 1e-5 138 | 139 | return {'loss': loss, 'Reconstruction_Loss':reconsloss, 'KLD':-kld_loss, 'outs':outs, 'labs':labs} 140 | else: 141 | 142 | 143 | return 0 144 | 145 | def sample(self, 146 | num_samples:int, 147 | current_device: int, **kwargs) -> torch.Tensor: 148 | """ 149 | Samples from the latent space and return the corresponding 150 | image space map. 151 | :param num_samples: (Int) Number of samples 152 | :param current_device: (Int) Device to run the model 153 | :return: (Tensor) 154 | """ 155 | z = torch.randn(num_samples, 156 | self.vaelatentdim) 157 | 158 | z = z.to(current_device) 159 | 160 | samples = self.decode(z) 161 | return samples 162 | 163 | def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 164 | """ 165 | Given an input image x, returns the reconstructed image 166 | :param x: (Tensor) [B x C x H x W] 167 | :return: (Tensor) [B x C x H x W] 168 | """ 169 | 170 | return self.forward(x)[0] 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | x_ = (x - warmup) / (1 - warmup) # progress after warmup 30 | return 0.5 * (1. + math.cos(math.pi * x_)) 31 | 32 | def warmup_constant(x, warmup=0.002): 33 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps. 34 | Learning rate is 1. afterwards. """ 35 | if x < warmup: 36 | return x/warmup 37 | return 1.0 38 | 39 | def warmup_linear(x, warmup=0.002): 40 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step. 41 | After `t_total`-th training step, learning rate is zero. """ 42 | if x < warmup: 43 | return x/warmup 44 | return max((x-1.)/(warmup-1.), 0) 45 | 46 | SCHEDULES = { 47 | 'warmup_cosine':warmup_cosine, 48 | 'warmup_constant':warmup_constant, 49 | 'warmup_linear':warmup_linear, 50 | } 51 | 52 | 53 | class OpenAIAdam(Optimizer): 54 | """Implements Open AI version of Adam algorithm with weight decay fix. 55 | """ 56 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 57 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 58 | vector_l2=False, max_grad_norm=-1, **kwargs): 59 | if lr is not required and lr < 0.0: 60 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 61 | if schedule not in SCHEDULES: 62 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 63 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 64 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 65 | if not 0.0 <= b1 < 1.0: 66 | raise ValueError("Invalid b1 parameter: {}".format(b1)) 67 | if not 0.0 <= b2 < 1.0: 68 | raise ValueError("Invalid b2 parameter: {}".format(b2)) 69 | if not e >= 0.0: 70 | raise ValueError("Invalid epsilon value: {}".format(e)) 71 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 72 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 73 | max_grad_norm=max_grad_norm) 74 | super(OpenAIAdam, self).__init__(params, defaults) 75 | 76 | def get_lr(self): 77 | lr = [] 78 | for group in self.param_groups: 79 | for p in group['params']: 80 | state = self.state[p] 81 | if len(state) == 0: 82 | return [0] 83 | if group['t_total'] != -1: 84 | schedule_fct = SCHEDULES[group['schedule']] 85 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 86 | else: 87 | lr_scheduled = group['lr'] 88 | lr.append(lr_scheduled) 89 | return lr 90 | 91 | def step(self, closure=None): 92 | """Performs a single optimization step. 93 | 94 | Arguments: 95 | closure (callable, optional): A closure that reevaluates the model 96 | and returns the loss. 
97 | """ 98 | loss = None 99 | if closure is not None: 100 | loss = closure() 101 | 102 | warned_for_t_total = False 103 | 104 | for group in self.param_groups: 105 | for p in group['params']: 106 | if p.grad is None: 107 | continue 108 | grad = p.grad.data 109 | if grad.is_sparse: 110 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 111 | 112 | state = self.state[p] 113 | 114 | # State initialization 115 | if len(state) == 0: 116 | state['step'] = 0 117 | # Exponential moving average of gradient values 118 | state['exp_avg'] = torch.zeros_like(p.data) 119 | # Exponential moving average of squared gradient values 120 | state['exp_avg_sq'] = torch.zeros_like(p.data) 121 | 122 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 123 | beta1, beta2 = group['b1'], group['b2'] 124 | 125 | state['step'] += 1 126 | 127 | # Add grad clipping 128 | if group['max_grad_norm'] > 0: 129 | clip_grad_norm_(p, group['max_grad_norm']) 130 | 131 | # Decay the first and second moment running average coefficient 132 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 133 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 134 | denom = exp_avg_sq.sqrt().add_(group['e']) 135 | 136 | bias_correction1 = 1 - beta1 ** state['step'] 137 | bias_correction2 = 1 - beta2 ** state['step'] 138 | 139 | if group['t_total'] != -1: 140 | schedule_fct = SCHEDULES[group['schedule']] 141 | progress = state['step']/group['t_total'] 142 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 143 | # warning for exceeding t_total (only active with warmup_linear 144 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: 145 | logger.warning( 146 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " 147 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) 148 | warned_for_t_total = True 149 | # end warning 150 | else: 151 | lr_scheduled = group['lr'] 152 | 153 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 154 | 155 | p.data.addcdiv_(-step_size, exp_avg, denom) 156 | 157 | # Add weight decay at the end (fixed version) 158 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 159 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 160 | 161 | return loss 162 | -------------------------------------------------------------------------------- /src/CV_TMLE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import logit, expit 3 | from scipy.optimize import minimize 4 | 5 | import numpy as np 6 | from scipy.special import logit 7 | import itertools 8 | import sklearn.linear_model as lm 9 | 10 | import numpy as np 11 | from sklearn.feature_extraction.text import CountVectorizer 12 | 13 | np.random.seed(0) 14 | 15 | 16 | class CVTMLE: 17 | def __init__(self, q_t0=None, q_t1=None, g=None, t=None, y=None, fromFolds=None, est_keys=None, 18 | truncate_level=0.05): 19 | """ 20 | CVTMLE as conceived by Levi, 2018: 21 | Levy, Jonathan. "An easy implementation of CV-TMLE." arXiv preprint arXiv:1811.04573 (2018). 
22 | 23 | :param q_t0: initial estimate with control exposure 24 | :param q_t1: initial estimate with treatment/non-control exposure 25 | :param g: prediction of propensity score 26 | :param t: treatment label 27 | :param y: factual outcome 28 | :param fromFolds: if files for estimates per fold are provided (type list) which are npz files, then no need to provide first five parameterse 29 | :param est_keys: once npz files are read, the keys are needed to extract estimates (i.e., first five parameters in this init) 30 | :param truncate_level: truncation for propensity scores (0.05 default means that only patients with estimates between 0.05 and 0.95 will be considered) 31 | """ 32 | 33 | 34 | self.q_t0 = q_t0 35 | self.q_t1 = q_t1 36 | self.g = g 37 | self.t = t 38 | self.y = y 39 | self.est_keys = est_keys 40 | self.truncate_level = truncate_level 41 | if fromFolds is not None: 42 | self.q_t0, self.q_t1, self.y, self.g, self.t = self.collateFromFolds(fromFolds) 43 | 44 | def _perturbed_model_bin_outcome(self, q_t0, q_t1, g, t, eps): 45 | """ 46 | Helper for psi_tmle_bin_outcome 47 | 48 | Returns q_\eps (t,x) and the h term 49 | (i.e., value of perturbed predictor at t, eps, x; where q_t0, q_t1, g are all evaluated at x 50 | """ 51 | h = t * (1. / g) - (1. - t) / (1. - g) 52 | full_lq = (1. - t) * logit(q_t0) + t * logit(q_t1) # logit predictions from unperturbed model 53 | logit_perturb = full_lq + eps * h 54 | return expit(logit_perturb), h 55 | 56 | def run_tmle_binary(self): 57 | """ 58 | This is for CV-TMLE on binary outcomes yielding risk ratio with 95% CI. Read Levi et al for methodological details. 59 | Influence curves coded from Gruber S, van der Laan, MJ. (2011). 60 | 61 | """ 62 | 63 | print('running CV-TMLE for binary outcomes...') 64 | q_t0, q_t1, g, t, y, truncatel = np.copy(self.q_t0), np.copy(self.q_t1), np.copy(self.g), np.copy( 65 | self.t), np.copy(self.y), np.copy(self.truncate_level) 66 | q_t0, q_t1, g, t, y = self.truncate_all_by_g(q_t0, q_t1, g, t, y, truncatel) 67 | 68 | eps_hat = minimize( 69 | lambda eps: self.cross_entropy(y, self._perturbed_model_bin_outcome(q_t0, q_t1, g, t, eps)[0]), 0., 70 | method='Nelder-Mead') 71 | eps_hat = eps_hat.x[0] 72 | 73 | def q1(t_cf): 74 | return self._perturbed_model_bin_outcome(q_t0, q_t1, g, t_cf, eps_hat) 75 | 76 | qall = ((1. - t) * (q_t0)) + (t * (q_t1)) # full predictions from unperturbed model 77 | 78 | qq1, h1 = q1(np.ones_like(t)) 79 | qq0, h0 = q1(np.zeros_like(t)) 80 | rr = np.mean(qq1) / np.mean(qq0) 81 | 82 | ic = (1 / np.mean(qq1) * (h1 * (y - qall) + qq1 - np.mean(qq1)) - 83 | (1 / np.mean(qq0)) * (-1 * h0 * (y - qall) + qq0 - np.mean(qq0))) 84 | psi_tmle_std = 1.96 * np.sqrt(np.var(ic) / (t.shape[0])) 85 | 86 | return [rr, np.exp(np.log(rr) - psi_tmle_std), np.exp(np.log(rr) + psi_tmle_std)] 87 | 88 | def run_tmle_continuous(self): 89 | """ 90 | This is for CV-TMLE on continuous outcomes yielding ATE/MD with 95% CI. Read Levi et al for methodological details. 91 | Influence curves coded from Gruber S, van der Laan, MJ. (2011). 
92 | 93 | """ 94 | print('running CV-TMLE for continuous outcomes...') 95 | 96 | q_t0, q_t1, g, t, y, truncatel = np.copy(self.q_t0), np.copy(self.q_t1), np.copy(self.g), np.copy( 97 | self.t), np.copy(self.y), np.copy(self.truncate_level) 98 | q_t0, q_t1, g, t, y = self.truncate_all_by_g(q_t0, q_t1, g, t, y, truncatel) 99 | 100 | h = t * (1.0 / g) - (1.0 - t) / (1.0 - g) 101 | full_q = (1.0 - t) * q_t0 + t * q_t1 102 | eps_hat = np.sum(h * (y - full_q)) / np.sum(np.square(h)) 103 | 104 | def q1(t_cf): 105 | h_cf = t_cf * (1.0 / g) - (1.0 - t_cf) / (1.0 - g) 106 | full_q = ((1.0 - t_cf) * q_t0) + (t_cf * q_t1) 107 | return full_q + eps_hat * h_cf, h_cf 108 | 109 | qq1, h_cf1 = q1(np.ones_like(t)) 110 | qq0, h_cf0 = q1(np.zeros_like(t)) 111 | haw = h_cf0 + h_cf1 112 | 113 | rd = np.mean(qq1 - qq0) 114 | ic = (haw * (y - full_q)) + (qq1 - qq0) - rd 115 | psi_tmle_std = 1.96 * np.sqrt(np.var(ic) / (t.shape[0])) 116 | 117 | return [rd, rd - psi_tmle_std, rd + psi_tmle_std] 118 | 119 | def truncate_by_g(self, attribute, g, level=0.1): 120 | keep_these = np.logical_and(g >= level, g <= 1. - level) 121 | return attribute[keep_these] 122 | 123 | def truncate_all_by_g(self, q_t0, q_t1, g, t, y, truncate_level=0.05): 124 | """ 125 | Helper function to clean up nuisance parameter estimates. 126 | """ 127 | orig_g = np.copy(g) 128 | q_t0 = self.truncate_by_g(np.copy(q_t0), orig_g, truncate_level) 129 | q_t1 = self.truncate_by_g(np.copy(q_t1), orig_g, truncate_level) 130 | g = self.truncate_by_g(np.copy(g), orig_g, truncate_level) 131 | t = self.truncate_by_g(np.copy(t), orig_g, truncate_level) 132 | y = self.truncate_by_g(np.copy(y), orig_g, truncate_level) 133 | return q_t0, q_t1, g, t, y 134 | 135 | def cross_entropy(self, y, p): 136 | return -np.mean((y * np.log(p) + (1. - y) * np.log(1. - p))) 137 | 138 | def collateFromFolds(self, foldNPZ): 139 | """ 140 | FYI: keys can be provided but default is below 141 | est_keys = { 142 | 'treatment_label_key' : 'treatment_label', 143 | 'outcome_key' : 'outcome', 144 | 'treatment_pred_key' : 'treatment', 145 | 'outcome_label_key' : 'outcome_label'} 146 | 147 | """ 148 | if self.est_keys is None: 149 | self.est_keys = { 150 | 'treatment_label_key': 'treatment_label', 151 | 'outcome_key': 'outcome', 152 | 'treatment_pred_key': 'treatment', 153 | 'outcome_label_key': 'outcome_label'} 154 | t_all = [] 155 | q1_all = [] 156 | q0_all = [] 157 | g_all = [] 158 | y_all = [] 159 | for fold in foldNPZ: 160 | ld = np.load(fold) 161 | t_all.append(ld[self.est_keys['treatment_label_key']]) 162 | q0_all.append(ld[self.est_keys['outcome_key']][:, 0]) 163 | q1_all.append(ld[self.est_keys['outcome_key']][:, 1]) 164 | y_all.append(ld[self.est_keys['outcome_label_key']]) 165 | g_all.append(ld[self.est_keys['treatment_pred_key']][:, 1]) 166 | t_all = np.array(list(itertools.chain(*t_all))).flatten() 167 | g_all = np.array(list(itertools.chain(*g_all))).flatten() 168 | q0_all = np.array(list(itertools.chain(*q0_all))).flatten() 169 | q1_all = np.array(list(itertools.chain(*q1_all))).flatten() 170 | y_all = np.array(list(itertools.chain(*y_all))).flatten() 171 | return q0_all, q1_all, y_all, g_all, t_all 172 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | x_ = (x - warmup) / (1 - warmup) # progress after warmup - 30 | return 0.5 * (1. + math.cos(math.pi * x_)) 31 | 32 | def warmup_constant(x, warmup=0.002): 33 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 34 | Learning rate is 1. afterwards. """ 35 | if x < warmup: 36 | return x/warmup 37 | return 1.0 38 | 39 | def warmup_linear(x, warmup=0.002): 40 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 41 | After `t_total`-th training step, learning rate is zero. """ 42 | if x < warmup: 43 | return x/warmup 44 | return max((x-1.)/(warmup-1.), 0) 45 | 46 | SCHEDULES = { 47 | 'warmup_cosine': warmup_cosine, 48 | 'warmup_constant': warmup_constant, 49 | 'warmup_linear': warmup_linear, 50 | } 51 | 52 | 53 | class BertAdam(Optimizer): 54 | """Implements BERT version of Adam algorithm with weight decay fix. 55 | Params: 56 | lr: learning rate 57 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 58 | t_total: total number of training steps for the learning 59 | rate schedule, -1 means constant learning rate. Default: -1 60 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 61 | b1: Adams b1. Default: 0.9 62 | b2: Adams b2. Default: 0.999 63 | e: Adams epsilon. Default: 1e-6 64 | weight_decay: Weight decay. Default: 0.01 65 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 66 | """ 67 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 68 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 69 | max_grad_norm=1.0): 70 | if lr is not required and lr < 0.0: 71 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 72 | if schedule not in SCHEDULES: 73 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 74 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 75 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 76 | if not 0.0 <= b1 < 1.0: 77 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 78 | if not 0.0 <= b2 < 1.0: 79 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 80 | if not e >= 0.0: 81 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 82 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 83 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 84 | max_grad_norm=max_grad_norm) 85 | super(BertAdam, self).__init__(params, defaults) 86 | 87 | def get_lr(self): 88 | lr = [] 89 | for group in self.param_groups: 90 | for p in group['params']: 91 | state = self.state[p] 92 | if len(state) == 0: 93 | return [0] 94 | if group['t_total'] != -1: 95 | schedule_fct = SCHEDULES[group['schedule']] 96 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 97 | else: 98 | lr_scheduled = group['lr'] 99 | lr.append(lr_scheduled) 100 | return lr 101 | 102 | def step(self, closure=None): 103 | """Performs a single optimization step. 104 | 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | warned_for_t_total = False 114 | 115 | for group in self.param_groups: 116 | for p in group['params']: 117 | if p.grad is None: 118 | continue 119 | grad = p.grad.data 120 | if grad.is_sparse: 121 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 122 | 123 | state = self.state[p] 124 | 125 | # State initialization 126 | if len(state) == 0: 127 | state['step'] = 0 128 | # Exponential moving average of gradient values 129 | state['next_m'] = torch.zeros_like(p.data) 130 | # Exponential moving average of squared gradient values 131 | state['next_v'] = torch.zeros_like(p.data) 132 | 133 | next_m, next_v = state['next_m'], state['next_v'] 134 | beta1, beta2 = group['b1'], group['b2'] 135 | 136 | # Add grad clipping 137 | if group['max_grad_norm'] > 0: 138 | clip_grad_norm_(p, group['max_grad_norm']) 139 | 140 | # Decay the first and second moment running average coefficient 141 | # In-place operations to update the averages at the same time 142 | next_m.mul_(beta1).add_(1 - beta1, grad) 143 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 
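                # Concretely, this is decoupled ("fixed") weight decay in the style of
                # AdamW (Loshchilov & Hutter): the decay term is added to the Adam update
                # and scaled by the scheduled learning rate below,
                #     update = m / (sqrt(v) + eps) + weight_decay * p
                #     p      <-  p - lr_scheduled * update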
153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | # warning for exceeding t_total (only active with warmup_linear 161 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: 162 | logger.warning( 163 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " 164 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) 165 | warned_for_t_total = True 166 | # end warning 167 | else: 168 | lr_scheduled = group['lr'] 169 | 170 | update_with_lr = lr_scheduled * update 171 | p.data.add_(-update_with_lr) 172 | 173 | state['step'] += 1 174 | 175 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 176 | # No bias correction 177 | # bias_correction1 = 1 - beta1 ** state['step'] 178 | # bias_correction2 = 1 - beta2 ** state['step'] 179 | 180 | return loss 181 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import regex as re 23 | from io import open 24 | 25 | try: 26 | from functools import lru_cache 27 | except ImportError: 28 | # Just a dummy decorator to get the checks to run on python2 29 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 30 | def lru_cache(): 31 | return lambda func: func 32 | 33 | from .file_utils import cached_path 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 38 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 39 | } 40 | PRETRAINED_MERGES_ARCHIVE_MAP = { 41 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 42 | } 43 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 44 | 'gpt2': 1024, 45 | } 46 | VOCAB_NAME = 'vocab.json' 47 | MERGES_NAME = 'merges.txt' 48 | 49 | @lru_cache() 50 | def bytes_to_unicode(): 51 | """ 52 | Returns list of utf-8 byte and a corresponding list of unicode strings. 53 | The reversible bpe codes work on unicode strings. 54 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 55 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 
56 | This is a signficant percentage of your normal, say, 32K bpe vocab. 57 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 58 | And avoids mapping to whitespace/control characters the bpe code barfs on. 59 | """ 60 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 61 | cs = bs[:] 62 | n = 0 63 | for b in range(2**8): 64 | if b not in bs: 65 | bs.append(b) 66 | cs.append(2**8+n) 67 | n += 1 68 | cs = [chr(n) for n in cs] 69 | return dict(zip(bs, cs)) 70 | 71 | def get_pairs(word): 72 | """Return set of symbol pairs in a word. 73 | 74 | Word is represented as tuple of symbols (symbols being variable-length strings). 75 | """ 76 | pairs = set() 77 | prev_char = word[0] 78 | for char in word[1:]: 79 | pairs.add((prev_char, char)) 80 | prev_char = char 81 | return pairs 82 | 83 | class GPT2Tokenizer(object): 84 | """ 85 | GPT-2 BPE tokenizer. Peculiarities: 86 | - Byte-level BPE 87 | """ 88 | @classmethod 89 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 90 | """ 91 | Instantiate a PreTrainedBertModel from a pre-trained model file. 92 | Download and cache the pre-trained model file if needed. 93 | """ 94 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 95 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 96 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 97 | else: 98 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 99 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 100 | # redirect to the cache, if necessary 101 | try: 102 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 103 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 104 | except EnvironmentError: 105 | logger.error( 106 | "Model name '{}' was not found in model name list ({}). " 107 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 108 | "at this path or url.".format( 109 | pretrained_model_name_or_path, 110 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 111 | pretrained_model_name_or_path, 112 | vocab_file, merges_file)) 113 | return None 114 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 115 | logger.info("loading vocabulary file {}".format(vocab_file)) 116 | logger.info("loading merges file {}".format(merges_file)) 117 | else: 118 | logger.info("loading vocabulary file {} from cache at {}".format( 119 | vocab_file, resolved_vocab_file)) 120 | logger.info("loading merges file {} from cache at {}".format( 121 | merges_file, resolved_merges_file)) 122 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 123 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 124 | # than the number of positional embeddings 125 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 126 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 127 | # Instantiate tokenizer. 
128 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) 129 | return tokenizer 130 | 131 | def __init__(self, vocab_file, merges_file, errors='replace', max_len=None): 132 | self.max_len = max_len if max_len is not None else int(1e12) 133 | self.encoder = json.load(open(vocab_file)) 134 | self.decoder = {v:k for k,v in self.encoder.items()} 135 | self.errors = errors # how to handle errors in decoding 136 | self.byte_encoder = bytes_to_unicode() 137 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 138 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 139 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 140 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 141 | self.cache = {} 142 | 143 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 144 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 145 | 146 | def __len__(self): 147 | return len(self.encoder) 148 | 149 | def bpe(self, token): 150 | if token in self.cache: 151 | return self.cache[token] 152 | word = tuple(token) 153 | pairs = get_pairs(word) 154 | 155 | if not pairs: 156 | return token 157 | 158 | while True: 159 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 160 | if bigram not in self.bpe_ranks: 161 | break 162 | first, second = bigram 163 | new_word = [] 164 | i = 0 165 | while i < len(word): 166 | try: 167 | j = word.index(first, i) 168 | new_word.extend(word[i:j]) 169 | i = j 170 | except: 171 | new_word.extend(word[i:]) 172 | break 173 | 174 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 175 | new_word.append(first+second) 176 | i += 2 177 | else: 178 | new_word.append(word[i]) 179 | i += 1 180 | new_word = tuple(new_word) 181 | word = new_word 182 | if len(word) == 1: 183 | break 184 | else: 185 | pairs = get_pairs(word) 186 | word = ' '.join(word) 187 | self.cache[token] = word 188 | return word 189 | 190 | def encode(self, text): 191 | bpe_tokens = [] 192 | for token in re.findall(self.pat, text): 193 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 194 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 195 | if len(bpe_tokens) > self.max_len: 196 | logger.warning( 197 | "Token indices sequence length is longer than the specified maximum " 198 | " sequence length for this OpenAI GPT-2 model ({} > {}). Running this" 199 | " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len) 200 | ) 201 | return bpe_tokens 202 | 203 | def decode(self, tokens): 204 | text = ''.join([self.decoder[token] for token in tokens]) 205 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 206 | return text 207 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
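
Cache layout: files are stored under PYTORCH_PRETRAINED_BERT_CACHE, named by the
sha256 hash of their URL (with the hash of the server's ETag appended after a
period, when available), next to a '<name>.json' sidecar recording the original
url and etag. For example (URL and etag are illustrative):

    name = url_to_filename('https://example.com/vocab.txt', etag='"abc"')
    # name == sha256(url).hexdigest() + '.' + sha256(etag).hexdigest()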
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 | 
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import tempfile
13 | from functools import wraps
14 | from hashlib import sha256
15 | import sys
16 | from io import open
17 | 
18 | import boto3
19 | import requests
20 | from botocore.exceptions import ClientError
21 | from tqdm import tqdm
22 | 
23 | try:
24 |     from urllib.parse import urlparse
25 | except ImportError:
26 |     from urlparse import urlparse
27 | 
28 | try:
29 |     from pathlib import Path
30 |     PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
31 |                                                    Path.home() / '.pytorch_pretrained_bert'))
32 | except (AttributeError, ImportError):
33 |     PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
34 |                                               os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
35 | 
36 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
37 | 
38 | 
39 | def url_to_filename(url, etag=None):
40 |     """
41 |     Convert `url` into a hashed filename in a repeatable way.
42 |     If `etag` is specified, append its hash to the url's, delimited
43 |     by a period.
44 |     """
45 |     url_bytes = url.encode('utf-8')
46 |     url_hash = sha256(url_bytes)
47 |     filename = url_hash.hexdigest()
48 | 
49 |     if etag:
50 |         etag_bytes = etag.encode('utf-8')
51 |         etag_hash = sha256(etag_bytes)
52 |         filename += '.' + etag_hash.hexdigest()
53 | 
54 |     return filename
55 | 
56 | 
57 | def filename_to_url(filename, cache_dir=None):
58 |     """
59 |     Return the url and etag (which may be ``None``) stored for `filename`.
60 |     Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61 |     """
62 |     if cache_dir is None:
63 |         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
64 |     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65 |         cache_dir = str(cache_dir)
66 | 
67 |     cache_path = os.path.join(cache_dir, filename)
68 |     if not os.path.exists(cache_path):
69 |         raise EnvironmentError("file {} not found".format(cache_path))
70 | 
71 |     meta_path = cache_path + '.json'
72 |     if not os.path.exists(meta_path):
73 |         raise EnvironmentError("file {} not found".format(meta_path))
74 | 
75 |     with open(meta_path, encoding="utf-8") as meta_file:
76 |         metadata = json.load(meta_file)
77 |     url = metadata['url']
78 |     etag = metadata['etag']
79 | 
80 |     return url, etag
81 | 
82 | 
83 | def cached_path(url_or_filename, cache_dir=None):
84 |     """
85 |     Given something that might be a URL (or might be a local path),
86 |     determine which. If it's a URL, download the file and cache it, and
87 |     return the path to the cached file. If it's already a local path,
88 |     make sure the file exists and then return the path.
89 |     """
90 |     if cache_dir is None:
91 |         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
92 |     if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93 |         url_or_filename = str(url_or_filename)
94 |     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95 |         cache_dir = str(cache_dir)
96 | 
97 |     parsed = urlparse(url_or_filename)
98 | 
99 |     if parsed.scheme in ('http', 'https', 's3'):
100 |         # URL, so get it from the cache (downloading if necessary)
101 |         return get_from_cache(url_or_filename, cache_dir)
102 |     elif os.path.exists(url_or_filename):
103 |         # File, and it exists.
104 |         return url_or_filename
105 |     elif parsed.scheme == '':
106 |         # File, but it doesn't exist.
107 | raise EnvironmentError("file {} not found".format(url_or_filename)) 108 | else: 109 | # Something unknown 110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 111 | 112 | 113 | def split_s3_path(url): 114 | """Split a full s3 path into the bucket name and path.""" 115 | parsed = urlparse(url) 116 | if not parsed.netloc or not parsed.path: 117 | raise ValueError("bad s3 path {}".format(url)) 118 | bucket_name = parsed.netloc 119 | s3_path = parsed.path 120 | # Remove '/' at beginning of path. 121 | if s3_path.startswith("/"): 122 | s3_path = s3_path[1:] 123 | return bucket_name, s3_path 124 | 125 | 126 | def s3_request(func): 127 | """ 128 | Wrapper function for s3 requests in order to create more helpful error 129 | messages. 130 | """ 131 | 132 | @wraps(func) 133 | def wrapper(url, *args, **kwargs): 134 | try: 135 | return func(url, *args, **kwargs) 136 | except ClientError as exc: 137 | if int(exc.response["Error"]["Code"]) == 404: 138 | raise EnvironmentError("file {} not found".format(url)) 139 | else: 140 | raise 141 | 142 | return wrapper 143 | 144 | 145 | @s3_request 146 | def s3_etag(url): 147 | """Check ETag on S3 object.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_object = s3_resource.Object(bucket_name, s3_path) 151 | return s3_object.e_tag 152 | 153 | 154 | @s3_request 155 | def s3_get(url, temp_file): 156 | """Pull a file directly from S3.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 160 | 161 | 162 | def http_get(url, temp_file): 163 | req = requests.get(url, stream=True) 164 | content_length = req.headers.get('Content-Length') 165 | total = int(content_length) if content_length is not None else None 166 | progress = tqdm(unit="B", total=total) 167 | for chunk in req.iter_content(chunk_size=1024): 168 | if chunk: # filter out keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | response = requests.head(url, allow_redirects=True) 192 | if response.status_code != 200: 193 | raise IOError("HEAD request failed for url {} with status code {}" 194 | .format(url, response.status_code)) 195 | etag = response.headers.get("ETag") 196 | 197 | filename = url_to_filename(url, etag) 198 | 199 | # get cache path to put the file 200 | cache_path = os.path.join(cache_dir, filename) 201 | 202 | if not os.path.exists(cache_path): 203 | # Download to temporary file, then copy to cache dir once finished. 204 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
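        # NamedTemporaryFile is deleted automatically when the 'with' block below
        # exits, so the flush()/seek(0)/copyfileobj() sequence must run while the
        # handle is still open; other readers only ever see a fully written
        # cache_path (plus its '.json' metadata sidecar).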
205 |         with tempfile.NamedTemporaryFile() as temp_file:
206 |             logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207 | 
208 |             # GET file object
209 |             if url.startswith("s3://"):
210 |                 s3_get(url, temp_file)
211 |             else:
212 |                 http_get(url, temp_file)
213 | 
214 |             # we are copying the file before closing it, so flush to avoid truncation
215 |             temp_file.flush()
216 |             # shutil.copyfileobj() starts at the current position, so go to the start
217 |             temp_file.seek(0)
218 | 
219 |             logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220 |             with open(cache_path, 'wb') as cache_file:
221 |                 shutil.copyfileobj(temp_file, cache_file)
222 | 
223 |             logger.info("creating metadata file for %s", cache_path)
224 |             meta = {'url': url, 'etag': etag}
225 |             meta_path = cache_path + '.json'
226 |             with open(meta_path, 'w', encoding="utf-8") as meta_file:
227 |                 json.dump(meta, meta_file)
228 | 
229 |             logger.info("removing temp file %s", temp_file.name)
230 | 
231 |     return cache_path
232 | 
233 | 
234 | def read_set_from_file(filename):
235 |     '''
236 |     Extract a de-duped collection (set) of text from a file.
237 |     Expected file format is one item per line.
238 |     '''
239 |     collection = set()
240 |     with open(filename, 'r', encoding='utf-8') as file_:
241 |         for line in file_:
242 |             collection.add(line.rstrip())
243 |     return collection
244 | 
245 | 
246 | def get_file_extension(path, dot=True, lower=True):
247 |     ext = os.path.splitext(path)[1]
248 |     ext = ext if dot else ext[1:]
249 |     return ext.lower() if lower else ext
250 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import _pickle as pickle
3 | import random
4 | import torch.nn as nn
5 | import torch
6 | import os
7 | import sklearn
8 | import sklearn.metrics as skm
9 | import warnings
10 | 
11 | 
12 | def nonMASK(tokens, token2idx):
13 |     output_label = []
14 |     output_token = []
15 |     for i, token in enumerate(tokens):
16 |         prob = random.random()
17 |         # mask token with 15% probability (disabled here: prob is never < 0, so no token is masked)
18 |         if prob < 0:
19 |             prob /= 0.15
20 | 
21 |             # 80% randomly change token to mask token
22 |             if prob < 0.8:
23 |                 output_token.append(token2idx["MASK"])
24 | 
25 |             # 10% randomly change token to random token
26 |             elif prob < 0.9:
27 |                 output_token.append(random.choice(list(token2idx.values())))
28 | 
29 |             # -> rest 10% randomly keep current token
30 | 
31 |             # append current token to output (we will predict these later)
32 |             output_label.append(token2idx.get(token, token2idx['UNK']))
33 |         else:
34 |             # no masking token (will be ignored by loss function later)
35 |             output_label.append(-1)
36 |             output_token.append(token2idx.get(token, token2idx['UNK']))
37 | 
38 |     return tokens, output_token, output_label
39 | 
40 | 
41 | # static var masking
42 | def covarUnsupMaker(covar, covarprobb=0.4):
43 |     inputcovar = []
44 |     labelcovar = []
45 |     for i,x in enumerate(covar):
46 |         prob = random.random()
47 |         if x != 0:
48 |             if prob < covarprobb:
49 |                 inputcovar.append(1)  # replace the covariate with a mask placeholder
50 |                 labelcovar.append(x)
51 |             else:
52 |                 inputcovar.append(x)
53 |                 labelcovar.append(-1)
54 |         else:
55 |             inputcovar.append(x)
56 |             labelcovar.append(-1)
57 |     return inputcovar, labelcovar
58 | 
59 | 
60 | def MASK(tokens, token2idx):
61 |     """BERT-style random masking: each token is masked with 15% probability
62 |     (80% to MASK, 10% to a random token, 10% kept); positions that are not
63 |     masked get label -1 so they are ignored by the MLM loss."""
64 |     output_label = []
65 |     output_token = []
66 |     for i, token in enumerate(tokens):
67 |         prob = random.random()
68 |         # mask token with 15% probability
69 |         if prob < 0.15:
70 |             prob /= 0.15
71 | 
72 |             # 80% randomly change token to mask token
73 |             if prob < 0.8:
74 |                 output_token.append(token2idx["MASK"])
75 |                 output_label.append(token2idx.get(token, token2idx['UNK']))
76 | 
77 |             # 10% randomly change token to random token
78 |             elif prob < 0.9:
79 |                 output_token.append(random.choice(list(token2idx.values())))
80 |                 output_label.append(token2idx.get(token, token2idx['UNK']))
81 | 
82 | 
83 | 
84 | 
85 |             # -> rest 10% randomly keep current token
86 |             else:
87 |                 output_label.append(-1)
88 | 
89 |                 # append current token to output (we will predict these later)
90 |                 output_token.append(token2idx.get(token, token2idx['UNK']))
91 | 
92 | 
93 | 
94 |         else:
95 |             # no masking token (will be ignored by loss function later)
96 |             output_label.append(-1)
97 |             output_token.append(token2idx.get(token, token2idx['UNK']))
98 | 
99 |     return tokens, output_token, output_label
100 | 
101 | 
102 | 
103 | def save_obj(obj, name):
104 | with open(name + '.pkl', 'wb') as f: 105 | pickle.dump(obj, f) 106 | 107 | 108 | def load_obj(name): 109 | with open(name + '.pkl', 'rb') as f: 110 | return pickle.load(f) 111 | 112 | 113 | def code2index(tokens, token2idx): 114 | output_tokens = [] 115 | for i, token in enumerate(tokens): 116 | output_tokens.append(token2idx.get(token, token2idx['UNK'])) 117 | return tokens, output_tokens 118 | 119 | 120 | 121 | 122 | def index_seg(tokens, symbol='SEP'): 123 | flag = 0 124 | seg = [] 125 | 126 | for token in tokens: 127 | if token == symbol: 128 | seg.append(flag) 129 | if flag == 0: 130 | flag = 1 131 | else: 132 | flag = 0 133 | else: 134 | seg.append(flag) 135 | return seg 136 | 137 | 138 | def position_idx(tokens, symbol='SEP'): 139 | pos = [] 140 | flag = 0 141 | 142 | for token in tokens: 143 | if token == symbol: 144 | pos.append(flag) 145 | flag += 1 146 | else: 147 | pos.append(flag) 148 | return pos 149 | 150 | 151 | def age_vocab(max_age, year=False, symbol=None): 152 | age2idx = {} 153 | idx2age = {} 154 | if symbol is None: 155 | symbol = ['PAD', 'UNK'] 156 | 157 | for i in range(len(symbol)): 158 | age2idx[str(symbol[i])] = i 159 | idx2age[i] = str(symbol[i]) 160 | 161 | if year: 162 | for i in range(max_age): 163 | age2idx[str(i)] = len(symbol) + i 164 | idx2age[len(symbol) + i] = str(i) 165 | else: 166 | for i in range(max_age * 12): 167 | age2idx[str(i)] = len(symbol) + i 168 | idx2age[len(symbol) + i] = str(i) 169 | 170 | return age2idx, idx2age 171 | 172 | 173 | def seq_padding(tokens, max_len, token2idx=None, symbol=None): 174 | if symbol is None: 175 | symbol = 'PAD' 176 | 177 | seq = [] 178 | token_len = len(tokens) 179 | for i in range(max_len): 180 | if token2idx is None: 181 | if i < token_len: 182 | seq.append(tokens[i]) 183 | else: 184 | seq.append(symbol) 185 | else: 186 | if i < token_len: 187 | # 1 indicate UNK 188 | seq.append(token2idx.get(tokens[i], token2idx['UNK'])) 189 | else: 190 | seq.append(token2idx.get(symbol)) 191 | return seq 192 | 193 | 194 | def seq_padding_reverse(tokens, max_len, token2idx=None, symbol=None): 195 | if symbol is None: 196 | symbol = 'PAD' 197 | 198 | seq = [] 199 | token_len = len(tokens) 200 | tokens = tokens[::-1] 201 | for i in range(max_len): 202 | if token2idx is None: 203 | if i < token_len: 204 | seq.append(tokens[i]) 205 | else: 206 | seq.append(symbol) 207 | else: 208 | if i < token_len: 209 | # 1 indicate UNK 210 | seq.append(token2idx.get(tokens[i], token2idx['UNK'])) 211 | else: 212 | seq.append(token2idx.get(symbol)) 213 | return seq[::-1] 214 | 215 | 216 | def age_seq_padding(tokens, max_len, token2idx=None, symbol=None): 217 | if symbol is None: 218 | symbol = 'PAD' 219 | 220 | seq = [] 221 | token_len = len(tokens) 222 | for i in range(max_len): 223 | if token2idx is None: 224 | if i < token_len: 225 | seq.append(tokens[i]) 226 | else: 227 | seq.append(symbol) 228 | else: 229 | if i < token_len: 230 | # 1 indicate UNK 231 | seq.append(token2idx[tokens[i]]) 232 | else: 233 | seq.append(token2idx[symbol]) 234 | return seq 235 | 236 | 237 | 238 | def cal_acc(label, pred, logS=True): 239 | logs = nn.LogSoftmax() 240 | label = label.cpu().numpy() 241 | ind = np.where(label != -1)[0] 242 | truepred = pred.detach().cpu().numpy() 243 | truepred = truepred[ind] 244 | truelabel = label[ind] 245 | if logS == True: 246 | truepred = logs(torch.tensor(truepred)) 247 | else: 248 | truepred = torch.tensor(truepred) 249 | outs = [np.argmax(pred_x) for pred_x in truepred.numpy()] 250 | precision = 
skm.precision_score(truelabel, outs, average='micro')
251 | 
252 |     return precision
253 | 
254 | 
255 | 
256 | 
257 | 
258 | 
259 | 
260 | 
261 | 
262 | 
263 | 
264 | 
265 | 
266 | 
267 | 
268 | 
269 | 
270 | def partition(values, indices):
271 |     idx = 0
272 |     for index in indices:
273 |         sublist = []
274 |         idxfill = []
275 |         while idx < len(values) and values[idx] <= index:
276 |             # sublist.append(values[idx])
277 |             idxfill.append(idx)
278 | 
279 |             idx += 1
280 |         if idxfill:
281 |             yield idxfill
282 | 
283 | 
284 | def toLoad(model, filepath, custom=None):
285 |     pre_bert = filepath
286 | 
287 |     pretrained_dict = torch.load(pre_bert, map_location='cpu')
288 |     modeld = model.state_dict()
289 |     # 1. filter out unnecessary keys
290 |     if custom is None:
291 |         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in modeld}
292 |     else:
293 |         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in modeld and k not in custom}
294 |     # 2. overwrite entries in the existing state dict
295 |     modeld.update(pretrained_dict)
296 |     # 3. load the new state dict
297 |     model.load_state_dict(modeld)
298 |     return model
299 | 
300 | 
301 | 
302 | 
303 | def OutcomePrecision(logits, label, sig=True):
304 |     sigm = nn.Sigmoid()
305 |     if sig == True:
306 |         output = sigm(logits)
307 |     else:
308 |         output = logits
309 |     label, output = label.cpu(), output.detach().cpu()
310 |     tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy())
311 |     return tempprc, output, label
312 | 
313 | 
314 | def set_requires_grad(model, requires_grad=True):
315 |     for param in model.parameters():
316 |         param.requires_grad = requires_grad
317 | 
318 | 
319 | def precision_test(logits, label, sig=True):
320 |     sigm = nn.Sigmoid()
321 |     if sig == True:
322 |         output = sigm(logits)
323 |     else:
324 |         output = logits
325 |     label, output = label.cpu(), output.detach().cpu()
326 | 
327 |     tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy())
328 |     return tempprc, output, label
329 | 
330 | 
331 | def roc_auc(logits, label, sig=True):
332 |     sigm = nn.Sigmoid()
333 |     if sig == True:
334 |         output = sigm(logits)
335 |     else:
336 |         output = logits
337 |     label, output = label.cpu(), output.detach().cpu()
338 | 
339 |     tempprc = sklearn.metrics.roc_auc_score(label.numpy(), output.numpy())
340 |     return tempprc, output, label
341 | 
342 | 
343 | # global function
344 | def create_folder(path):
345 |     if not os.path.exists(path):
346 |         os.mkdir(path)
347 | 
348 | 
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/tokenization_openai.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | import sys 24 | from io import open 25 | 26 | from tqdm import tqdm 27 | 28 | from .file_utils import cached_path 29 | from .tokenization import BasicTokenizer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 35 | } 36 | PRETRAINED_MERGES_ARCHIVE_MAP = { 37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'openai-gpt': 512, 41 | } 42 | VOCAB_NAME = 'vocab.json' 43 | MERGES_NAME = 'merges.txt' 44 | 45 | def get_pairs(word): 46 | """ 47 | Return set of symbol pairs in a word. 48 | word is represented as tuple of symbols (symbols being variable-length strings) 49 | """ 50 | pairs = set() 51 | prev_char = word[0] 52 | for char in word[1:]: 53 | pairs.add((prev_char, char)) 54 | prev_char = char 55 | return pairs 56 | 57 | def text_standardize(text): 58 | """ 59 | fixes some issues the spacy tokenizer had on books corpus 60 | also does some whitespace standardization 61 | """ 62 | text = text.replace('—', '-') 63 | text = text.replace('–', '-') 64 | text = text.replace('―', '-') 65 | text = text.replace('…', '...') 66 | text = text.replace('´', "'") 67 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 68 | text = re.sub(r'\s*\n\s*', ' \n ', text) 69 | text = re.sub(r'[^\S\n]+', ' ', text) 70 | return text.strip() 71 | 72 | class OpenAIGPTTokenizer(object): 73 | """ 74 | BPE tokenizer. Peculiarities: 75 | - lower case all inputs 76 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 77 | - argument special_tokens and function set_special_tokens: 78 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 79 | """ 80 | @classmethod 81 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 82 | """ 83 | Instantiate a PreTrainedBertModel from a pre-trained model file. 84 | Download and cache the pre-trained model file if needed. 
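
        A minimal usage sketch (the 'openai-gpt' shortcut resolves via
        PRETRAINED_VOCAB_ARCHIVE_MAP and is downloaded on first use):

            tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
            ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))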
85 | """ 86 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 87 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 88 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | else: 90 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 91 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 92 | # redirect to the cache, if necessary 93 | try: 94 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 95 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 96 | except EnvironmentError: 97 | logger.error( 98 | "Model name '{}' was not found in model name list ({}). " 99 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 100 | "at this path or url.".format( 101 | pretrained_model_name_or_path, 102 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 103 | pretrained_model_name_or_path, 104 | vocab_file, merges_file)) 105 | return None 106 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 107 | logger.info("loading vocabulary file {}".format(vocab_file)) 108 | logger.info("loading merges file {}".format(merges_file)) 109 | else: 110 | logger.info("loading vocabulary file {} from cache at {}".format( 111 | vocab_file, resolved_vocab_file)) 112 | logger.info("loading merges file {} from cache at {}".format( 113 | merges_file, resolved_merges_file)) 114 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 115 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 116 | # than the number of positional embeddings 117 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 118 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 119 | # Instantiate tokenizer. 120 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) 121 | return tokenizer 122 | 123 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): 124 | try: 125 | import ftfy 126 | import spacy 127 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 128 | self.fix_text = ftfy.fix_text 129 | except ImportError: 130 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 131 | self.nlp = BasicTokenizer(do_lower_case=True, 132 | never_split=special_tokens if special_tokens is not None else []) 133 | self.fix_text = None 134 | 135 | self.max_len = max_len if max_len is not None else int(1e12) 136 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 137 | self.decoder = {v:k for k,v in self.encoder.items()} 138 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 139 | merges = [tuple(merge.split()) for merge in merges] 140 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 141 | self.cache = {} 142 | self.set_special_tokens(special_tokens) 143 | 144 | def __len__(self): 145 | return len(self.encoder) + len(self.special_tokens) 146 | 147 | def set_special_tokens(self, special_tokens): 148 | """ Add a list of additional tokens to the encoder. 149 | The additional tokens are indexed starting from the last index of the 150 | current vocabulary in the order of the `special_tokens` list. 
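
        For example, if the BPE vocabulary has V entries, set_special_tokens(['__classify__'])
        assigns '__classify__' the index V and len(tokenizer) becomes V + 1.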
151 | """ 152 | if not special_tokens: 153 | self.special_tokens = {} 154 | self.special_tokens_decoder = {} 155 | return 156 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 157 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 158 | if self.fix_text is None: 159 | # Using BERT's BasicTokenizer: we can update the tokenizer 160 | self.nlp.never_split = special_tokens 161 | logger.info("Special tokens {}".format(self.special_tokens)) 162 | 163 | def bpe(self, token): 164 | word = tuple(token[:-1]) + (token[-1] + '',) 165 | if token in self.cache: 166 | return self.cache[token] 167 | pairs = get_pairs(word) 168 | 169 | if not pairs: 170 | return token+'' 171 | 172 | while True: 173 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 174 | if bigram not in self.bpe_ranks: 175 | break 176 | first, second = bigram 177 | new_word = [] 178 | i = 0 179 | while i < len(word): 180 | try: 181 | j = word.index(first, i) 182 | new_word.extend(word[i:j]) 183 | i = j 184 | except: 185 | new_word.extend(word[i:]) 186 | break 187 | 188 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 189 | new_word.append(first+second) 190 | i += 2 191 | else: 192 | new_word.append(word[i]) 193 | i += 1 194 | new_word = tuple(new_word) 195 | word = new_word 196 | if len(word) == 1: 197 | break 198 | else: 199 | pairs = get_pairs(word) 200 | word = ' '.join(word) 201 | if word == '\n ': 202 | word = '\n' 203 | self.cache[token] = word 204 | return word 205 | 206 | def tokenize(self, text): 207 | """ Tokenize a string. """ 208 | split_tokens = [] 209 | if self.fix_text is None: 210 | # Using BERT's BasicTokenizer 211 | text = self.nlp.tokenize(text) 212 | for token in text: 213 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 214 | else: 215 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 216 | text = self.nlp(text_standardize(self.fix_text(text))) 217 | for token in text: 218 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 219 | return split_tokens 220 | 221 | def convert_tokens_to_ids(self, tokens): 222 | """ Converts a sequence of tokens into ids using the vocab. """ 223 | ids = [] 224 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 225 | if tokens in self.special_tokens: 226 | return self.special_tokens[tokens] 227 | else: 228 | return self.encoder.get(tokens, 0) 229 | for token in tokens: 230 | if token in self.special_tokens: 231 | ids.append(self.special_tokens[token]) 232 | else: 233 | ids.append(self.encoder.get(token, 0)) 234 | if len(ids) > self.max_len: 235 | logger.warning( 236 | "Token indices sequence length is longer than the specified maximum " 237 | " sequence length for this OpenAI GPT model ({} > {}). 
Running this" 238 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 239 | ) 240 | return ids 241 | 242 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 243 | """Converts a sequence of ids in BPE tokens using the vocab.""" 244 | tokens = [] 245 | for i in ids: 246 | if i in self.special_tokens_decoder: 247 | if not skip_special_tokens: 248 | tokens.append(self.special_tokens_decoder[i]) 249 | else: 250 | tokens.append(self.decoder[i]) 251 | return tokens 252 | 253 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False): 254 | """Converts a sequence of ids in a string.""" 255 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 256 | out_string = ''.join(tokens).replace('', ' ').strip() 257 | if clean_up_tokenization_spaces: 258 | out_string = out_string.replace('', '') 259 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' 260 | ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't" 261 | ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m " 262 | ).replace(" 've", "'ve") 263 | return out_string 264 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/modeling_transfo_xl_utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Utilities for PyTorch Transformer XL model. 17 | Directly adapted from https://github.com/kimiyoung/transformer-xl. 
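
    The main export is ProjectedAdaptiveLogSoftmax, an adaptive softmax in which the
    vocabulary is split at `cutoffs` into a frequent-token head (scored together with
    one extra logit per tail cluster) and tail clusters whose embedding size shrinks
    as d_embed // (div_val ** i). For example, with n_token=10000 and
    cutoffs=[1000, 5000], the head scores 1000 tokens plus 2 cluster logits, and the
    two tails score 4000 and 5000 tokens respectively. A LogUniformSampler for
    sampled-softmax training is also included.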
18 | """
19 | 
20 | from collections import defaultdict
21 | 
22 | import numpy as np
23 | 
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 | 
28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1])
30 | 
31 | class ProjectedAdaptiveLogSoftmax(nn.Module):
32 |     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
33 |                  keep_order=False):
34 |         super(ProjectedAdaptiveLogSoftmax, self).__init__()
35 | 
36 |         self.n_token = n_token
37 |         self.d_embed = d_embed
38 |         self.d_proj = d_proj
39 | 
40 |         self.cutoffs = cutoffs + [n_token]
41 |         self.cutoff_ends = [0] + self.cutoffs
42 |         self.div_val = div_val
43 | 
44 |         self.shortlist_size = self.cutoffs[0]
45 |         self.n_clusters = len(self.cutoffs) - 1
46 |         self.head_size = self.shortlist_size + self.n_clusters
47 | 
48 |         if self.n_clusters > 0:
49 |             self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
50 |             self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
51 | 
52 |         self.out_layers = nn.ModuleList()
53 |         self.out_projs = nn.ParameterList()
54 | 
55 |         if div_val == 1:
56 |             for i in range(len(self.cutoffs)):
57 |                 if d_proj != d_embed:
58 |                     self.out_projs.append(
59 |                         nn.Parameter(torch.Tensor(d_proj, d_embed))
60 |                     )
61 |                 else:
62 |                     self.out_projs.append(None)
63 | 
64 |             self.out_layers.append(nn.Linear(d_embed, n_token))
65 |         else:
66 |             for i in range(len(self.cutoffs)):
67 |                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
68 |                 d_emb_i = d_embed // (div_val ** i)
69 | 
70 |                 self.out_projs.append(
71 |                     nn.Parameter(torch.Tensor(d_proj, d_emb_i))
72 |                 )
73 | 
74 |                 self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
75 | 
76 |         self.keep_order = keep_order
77 | 
78 |     def _compute_logit(self, hidden, weight, bias, proj):
79 |         if proj is None:
80 |             logit = F.linear(hidden, weight, bias=bias)
81 |         else:
82 |             # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
83 |             proj_hid = F.linear(hidden, proj.t().contiguous())
84 |             logit = F.linear(proj_hid, weight, bias=bias)
85 |             # else:
86 |             #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
87 |             #     if bias is not None:
88 |             #         logit = logit + bias
89 | 
90 |         return logit
91 | 
92 |     def forward(self, hidden, target=None, keep_order=False):
93 |         '''
94 |             Params:
95 |                 hidden :: [len*bsz x d_proj]
96 |                 target :: [len*bsz]
97 |             Return:
98 |                 if target is None:
99 |                     out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
100 |                 else:
101 |                     out :: [len*bsz] Negative log likelihood
102 |             We could replace this implementation by the native PyTorch one
103 |             if theirs had an option to set bias on all clusters in the native one.
104 | here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138 105 | ''' 106 | 107 | if target is not None: 108 | target = target.view(-1) 109 | if hidden.size(0) != target.size(0): 110 | raise RuntimeError('Input and target should have the same size ' 111 | 'in the batch dimension.') 112 | 113 | if self.n_clusters == 0: 114 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 115 | self.out_layers[0].bias, self.out_projs[0]) 116 | if target is not None: 117 | output = -F.log_softmax(logit, dim=-1) \ 118 | .gather(1, target.unsqueeze(1)).squeeze(1) 119 | else: 120 | output = F.log_softmax(logit, dim=-1) 121 | else: 122 | # construct weights and biases 123 | weights, biases = [], [] 124 | for i in range(len(self.cutoffs)): 125 | if self.div_val == 1: 126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 127 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 128 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 129 | else: 130 | weight_i = self.out_layers[i].weight 131 | bias_i = self.out_layers[i].bias 132 | 133 | if i == 0: 134 | weight_i = torch.cat( 135 | [weight_i, self.cluster_weight], dim=0) 136 | bias_i = torch.cat( 137 | [bias_i, self.cluster_bias], dim=0) 138 | 139 | weights.append(weight_i) 140 | biases.append(bias_i) 141 | 142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 143 | 144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 145 | head_logprob = F.log_softmax(head_logit, dim=1) 146 | 147 | if target is None: 148 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 149 | else: 150 | out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) 151 | 152 | offset = 0 153 | cutoff_values = [0] + self.cutoffs 154 | for i in range(len(cutoff_values) - 1): 155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 156 | 157 | if target is not None: 158 | mask_i = (target >= l_idx) & (target < r_idx) 159 | indices_i = mask_i.nonzero().squeeze() 160 | 161 | if indices_i.numel() == 0: 162 | continue 163 | 164 | target_i = target.index_select(0, indices_i) - l_idx 165 | head_logprob_i = head_logprob.index_select(0, indices_i) 166 | hidden_i = hidden.index_select(0, indices_i) 167 | else: 168 | hidden_i = hidden 169 | 170 | if i == 0: 171 | if target is not None: 172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) 173 | else: 174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 175 | else: 176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 177 | 178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster 181 | if target is not None: 182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \ 183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) 184 | else: 185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i 186 | out[:, l_idx:r_idx] = logprob_i 187 | 188 | if target is not None: 189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 190 | out.index_copy_(0, indices_i, -logprob_i) 191 | else: 192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 193 | offset += logprob_i.size(0) 194 | 195 | return out 196 | 197 | 198 | def log_prob(self, hidden): 199 | r""" Computes log probabilities for all :math:`n\_classes` 200 | From: 
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py
201 |         Args:
202 |             hidden (Tensor): a minibatch of examples
203 |         Returns:
204 |             log-probabilities for each class :math:`c`
205 |             in range :math:`0 <= c <= n\_classes`, where :math:`n\_classes` is a
206 |             parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
207 |         Shape:
208 |             - Input: :math:`(N, in\_features)`
209 |             - Output: :math:`(N, n\_classes)`
210 |         """
211 |         if self.n_clusters == 0:
212 |             logit = self._compute_logit(hidden, self.out_layers[0].weight,
213 |                                         self.out_layers[0].bias, self.out_projs[0])
214 |             return F.log_softmax(logit, dim=-1)
215 |         else:
216 |             # construct weights and biases
217 |             weights, biases = [], []
218 |             for i in range(len(self.cutoffs)):
219 |                 if self.div_val == 1:
220 |                     l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
221 |                     weight_i = self.out_layers[0].weight[l_idx:r_idx]
222 |                     bias_i = self.out_layers[0].bias[l_idx:r_idx]
223 |                 else:
224 |                     weight_i = self.out_layers[i].weight
225 |                     bias_i = self.out_layers[i].bias
226 | 
227 |                 if i == 0:
228 |                     weight_i = torch.cat(
229 |                         [weight_i, self.cluster_weight], dim=0)
230 |                     bias_i = torch.cat(
231 |                         [bias_i, self.cluster_bias], dim=0)
232 | 
233 |                 weights.append(weight_i)
234 |                 biases.append(bias_i)
235 | 
236 |             head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
237 |             head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
238 | 
239 |             out = hidden.new_empty((head_logit.size(0), self.n_token))
240 |             head_logprob = F.log_softmax(head_logit, dim=1)
241 | 
242 |             cutoff_values = [0] + self.cutoffs
243 |             for i in range(len(cutoff_values) - 1):
244 |                 start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]
245 | 
246 |                 if i == 0:
247 |                     out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
248 |                 else:
249 |                     weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
250 | 
251 |                     tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
252 |                     tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
253 | 
254 |                     logprob_i = head_logprob[:, self.cutoffs[0] + i - 1, None] + tail_logprob_i
255 |                     out[:, start_idx:stop_idx] = logprob_i
256 | 
257 |             return out
258 | 
259 | 
260 | class LogUniformSampler(object):
261 |     def __init__(self, range_max, n_sample):
262 |         """
263 |         Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
264 |         `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
265 | 
266 |         expected count can be approximated by 1 - (1 - p)^n
267 |         and we use a numerically stable version -expm1(num_tries * log1p(-p))
268 | 
269 |         Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
270 |         """
271 |         with torch.no_grad():
272 |             self.range_max = range_max
273 |             log_indices = torch.arange(1., range_max+2., 1.).log_()
274 |             self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
275 |             # print('P', self.dist.numpy().tolist()[-30:])
276 | 
277 |             self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
278 | 
279 |         self.n_sample = n_sample
280 | 
281 |     def sample(self, labels):
282 |         """
283 |             labels: [b1, b2]
284 |         Return
285 |             true_log_probs: [b1, b2]
286 |             samp_log_probs: [n_sample]
287 |             neg_samples: [n_sample]
288 |         """
289 | 
290 |         # neg_samples = torch.empty(0).long()
291 |         n_sample = self.n_sample
292 |         n_tries = 2 * n_sample
293 | 
294 |         with torch.no_grad():
295 |             neg_samples = torch.multinomial(self.dist, n_tries,
replacement=True).unique() 296 | device = labels.device 297 | neg_samples = neg_samples.to(device) 298 | true_log_probs = self.log_q[labels].to(device) 299 | samp_log_probs = self.log_q[neg_samples].to(device) 300 | return true_log_probs, samp_log_probs, neg_samples 301 | 302 | def sample_logits(embedding, bias, labels, inputs, sampler): 303 | """ 304 | embedding: an nn.Embedding layer 305 | bias: [n_vocab] 306 | labels: [b1, b2] 307 | inputs: [b1, b2, n_emb] 308 | sampler: you may use a LogUniformSampler 309 | Return 310 | logits: [b1, b2, 1 + n_sample] 311 | """ 312 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 313 | n_sample = neg_samples.size(0) 314 | b1, b2 = labels.size(0), labels.size(1) 315 | all_ids = torch.cat([labels.view(-1), neg_samples]) 316 | all_w = embedding(all_ids) 317 | true_w = all_w[: -n_sample].view(b1, b2, -1) 318 | sample_w = all_w[- n_sample:].view(n_sample, -1) 319 | 320 | all_b = bias[all_ids] 321 | true_b = all_b[: -n_sample].view(b1, b2) 322 | sample_b = all_b[- n_sample:] 323 | 324 | hit = (labels[:, :, None] == neg_samples).detach() 325 | 326 | true_logits = torch.einsum('ijk,ijk->ij', 327 | [true_w, inputs]) + true_b - true_log_probs 328 | sample_logits = torch.einsum('lk,ijk->ijl', 329 | [sample_w, inputs]) + sample_b - samp_log_probs 330 | sample_logits.masked_fill_(hit, -1e30) 331 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 332 | 333 | return logits 334 | 335 | 336 | # class LogUniformSampler(object): 337 | # def __init__(self, range_max, unique=False): 338 | # """ 339 | # Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 340 | # `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 341 | # """ 342 | # self.range_max = range_max 343 | # log_indices = torch.arange(1., range_max+2., 1.).log_() 344 | # self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 345 | 346 | # self.unique = unique 347 | 348 | # if self.unique: 349 | # self.exclude_mask = torch.ByteTensor(range_max).fill_(0) 350 | 351 | # def sample(self, n_sample, labels): 352 | # pos_sample, new_labels = labels.unique(return_inverse=True) 353 | # n_pos_sample = pos_sample.size(0) 354 | # n_neg_sample = n_sample - n_pos_sample 355 | 356 | # if self.unique: 357 | # self.exclude_mask.index_fill_(0, pos_sample, 1) 358 | # sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) 359 | # self.exclude_mask.index_fill_(0, pos_sample, 0) 360 | # else: 361 | # sample_dist = self.dist 362 | 363 | # neg_sample = torch.multinomial(sample_dist, n_neg_sample) 364 | 365 | # sample = torch.cat([pos_sample, neg_sample]) 366 | # sample_prob = self.dist[sample] 367 | 368 | # return new_labels, sample, sample_prob 369 | 370 | 371 | if __name__ == '__main__': 372 | S, B = 3, 4 373 | n_vocab = 10000 374 | n_sample = 5 375 | H = 32 376 | 377 | labels = torch.LongTensor(S, B).random_(0, n_vocab) 378 | 379 | # sampler = LogUniformSampler(n_vocab, unique=False) 380 | # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) 381 | 382 | sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True) 383 | # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) 384 | 385 | # print('true_probs', true_probs.numpy().tolist()) 386 | # print('samp_probs', samp_probs.numpy().tolist()) 387 | # print('neg_samples', neg_samples.numpy().tolist()) 388 | 389 | # print('sum', torch.sum(sampler.dist).item()) 390 | 391 | # assert 
torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
392 | 
393 |     embedding = nn.Embedding(n_vocab, H)
394 |     bias = torch.zeros(n_vocab)
395 |     inputs = torch.Tensor(S, B, H).normal_()
396 | 
397 |     # sample_logits takes five arguments and returns a single [S, B, 1 + n_sample] tensor (column 0 is the true class)
398 |     logits = sample_logits(embedding, bias, labels, inputs, sampler)
399 |     print('logits', logits.detach().numpy().tolist())
400 |     print('logits shape', logits.size())
401 | 
402 | 
403 | 
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 | 
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 | 
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 | 
25 | from .file_utils import cached_path
26 | 
27 | logger = logging.getLogger(__name__)
28 | 
29 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
30 |     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
31 |     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
32 |     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
33 |     'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
34 |     'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
35 |     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
36 |     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
37 | }
38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
39 |     'bert-base-uncased': 512,
40 |     'bert-large-uncased': 512,
41 |     'bert-base-cased': 512,
42 |     'bert-large-cased': 512,
43 |     'bert-base-multilingual-uncased': 512,
44 |     'bert-base-multilingual-cased': 512,
45 |     'bert-base-chinese': 512,
46 | }
47 | VOCAB_NAME = 'vocab.txt'
48 | 
49 | 
50 | def load_vocab(vocab_file):
51 |     """Loads a vocabulary file into a dictionary."""
52 |     vocab = collections.OrderedDict()
53 |     index = 0
54 |     with open(vocab_file, "r", encoding="utf-8") as reader:
55 |         while True:
56 |             token = reader.readline()
57 |             if not token:
58 |                 break
59 |             token = token.strip()
60 |             vocab[token] = index
61 |             index += 1
62 |     return vocab
63 | 
64 | 
65 | def whitespace_tokenize(text):
66 |     """Runs basic whitespace cleaning and splitting on a piece of text."""
67 |     text = text.strip()
68 |     if not text:
69 |         return []
70 |     tokens = text.split()
71 |     return 
tokens 72 | 73 | 74 | class BertTokenizer(object): 75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 76 | 77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 79 | """Constructs a BertTokenizer. 80 | 81 | Args: 82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 83 | do_lower_case: Whether to lower case the input 84 | Only has an effect when do_wordpiece_only=False 85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 86 | max_len: An artificial maximum length to truncate tokenized sequences to; 87 | Effective maximum length is always the minimum of this 88 | value (if specified) and the underlying BERT model's 89 | sequence length. 90 | never_split: List of tokens which will never be split during tokenization. 91 | Only has an effect when do_wordpiece_only=False 92 | """ 93 | if not os.path.isfile(vocab_file): 94 | raise ValueError( 95 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 97 | self.vocab = load_vocab(vocab_file) 98 | self.ids_to_tokens = collections.OrderedDict( 99 | [(ids, tok) for tok, ids in self.vocab.items()]) 100 | self.do_basic_tokenize = do_basic_tokenize 101 | if do_basic_tokenize: 102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 103 | never_split=never_split) 104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 105 | self.max_len = max_len if max_len is not None else int(1e12) 106 | 107 | def tokenize(self, text): 108 | split_tokens = [] 109 | if self.do_basic_tokenize: 110 | for token in self.basic_tokenizer.tokenize(text): 111 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 112 | split_tokens.append(sub_token) 113 | else: 114 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | """Converts a sequence of tokens into ids using the vocab.""" 119 | ids = [] 120 | for token in tokens: 121 | ids.append(self.vocab[token]) 122 | if len(ids) > self.max_len: 123 | logger.warning( 124 | "Token indices sequence length is longer than the specified maximum " 125 | " sequence length for this BERT model ({} > {}). Running this" 126 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 127 | ) 128 | return ids 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 132 | tokens = [] 133 | for i in ids: 134 | tokens.append(self.ids_to_tokens[i]) 135 | return tokens 136 | 137 | @classmethod 138 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 139 | """ 140 | Instantiate a PreTrainedBertModel from a pre-trained model file. 141 | Download and cache the pre-trained model file if needed. 142 | """ 143 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 144 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 145 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 146 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 147 | "`do_lower_case` to False. 
We are setting `do_lower_case=False` for you but " 148 | "you may want to check this behavior.") 149 | kwargs['do_lower_case'] = False 150 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 151 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 152 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you " 153 | "but you may want to check this behavior.") 154 | kwargs['do_lower_case'] = True 155 | else: 156 | vocab_file = pretrained_model_name_or_path 157 | if os.path.isdir(vocab_file): 158 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 159 | # redirect to the cache, if necessary 160 | try: 161 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 162 | except EnvironmentError: 163 | logger.error( 164 | "Model name '{}' was not found in model name list ({}). " 165 | "We assumed '{}' was a path or url but couldn't find any file " 166 | "associated to this path or url.".format( 167 | pretrained_model_name_or_path, 168 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 169 | vocab_file)) 170 | return None 171 | if resolved_vocab_file == vocab_file: 172 | logger.info("loading vocabulary file {}".format(vocab_file)) 173 | else: 174 | logger.info("loading vocabulary file {} from cache at {}".format( 175 | vocab_file, resolved_vocab_file)) 176 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 177 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 178 | # than the number of positional embeddings 179 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 180 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 181 | # Instantiate tokenizer. 182 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 183 | return tokenizer 184 | 185 | 186 | class BasicTokenizer(object): 187 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 188 | 189 | def __init__(self, 190 | do_lower_case=True, 191 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 192 | """Constructs a BasicTokenizer. 193 | 194 | Args: 195 | do_lower_case: Whether to lower case the input. 196 | """ 197 | self.do_lower_case = do_lower_case 198 | self.never_split = never_split 199 | 200 | def tokenize(self, text): 201 | """Tokenizes a piece of text.""" 202 | text = self._clean_text(text) 203 | # This was added on November 1st, 2018 for the multilingual and Chinese 204 | # models. This is also applied to the English models now, but it doesn't 205 | # matter since the English models were not trained on any Chinese data 206 | # and generally don't have any Chinese data in them (there are Chinese 207 | # characters in the vocabulary because Wikipedia does have some Chinese 208 | # words in the English Wikipedia.). 
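# [Editor's note - illustrative sketch, not part of the original file.]
# The call just below first whitespace-pads every CJK ideograph so that each one
# surfaces as its own token after splitting. A minimal standalone re-implementation
# of that idea (covering only the main CJK Unified Ideographs block, for brevity;
# `pad_cjk` is a hypothetical name, not a function of this module):
#
#     def pad_cjk(text):
#         out = []
#         for ch in text:
#             if 0x4E00 <= ord(ch) <= 0x9FFF:
#                 out.extend([" ", ch, " "])
#             else:
#                 out.append(ch)
#         return "".join(out)
#
#     pad_cjk("ab你好c").split()  # -> ['ab', '你', '好', 'c']
#
# [End editor's note.]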
209 | text = self._tokenize_chinese_chars(text) 210 | orig_tokens = whitespace_tokenize(text) 211 | split_tokens = [] 212 | for token in orig_tokens: 213 | if self.do_lower_case and token not in self.never_split: 214 | token = token.lower() 215 | token = self._run_strip_accents(token) 216 | split_tokens.extend(self._run_split_on_punc(token)) 217 | 218 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 219 | return output_tokens 220 | 221 | def _run_strip_accents(self, text): 222 | """Strips accents from a piece of text.""" 223 | text = unicodedata.normalize("NFD", text) 224 | output = [] 225 | for char in text: 226 | cat = unicodedata.category(char) 227 | if cat == "Mn": 228 | continue 229 | output.append(char) 230 | return "".join(output) 231 | 232 | def _run_split_on_punc(self, text): 233 | """Splits punctuation on a piece of text.""" 234 | if text in self.never_split: 235 | return [text] 236 | chars = list(text) 237 | i = 0 238 | start_new_word = True 239 | output = [] 240 | while i < len(chars): 241 | char = chars[i] 242 | if _is_punctuation(char): 243 | output.append([char]) 244 | start_new_word = True 245 | else: 246 | if start_new_word: 247 | output.append([]) 248 | start_new_word = False 249 | output[-1].append(char) 250 | i += 1 251 | 252 | return ["".join(x) for x in output] 253 | 254 | def _tokenize_chinese_chars(self, text): 255 | """Adds whitespace around any CJK character.""" 256 | output = [] 257 | for char in text: 258 | cp = ord(char) 259 | if self._is_chinese_char(cp): 260 | output.append(" ") 261 | output.append(char) 262 | output.append(" ") 263 | else: 264 | output.append(char) 265 | return "".join(output) 266 | 267 | def _is_chinese_char(self, cp): 268 | """Checks whether CP is the codepoint of a CJK character.""" 269 | # This defines a "chinese character" as anything in the CJK Unicode block: 270 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 271 | # 272 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 273 | # despite its name. The modern Korean Hangul alphabet is a different block, 274 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 275 | # space-separated words, so they are not treated specially and handled 276 | # like the all of the other languages. 277 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 278 | (cp >= 0x3400 and cp <= 0x4DBF) or # 279 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 280 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 281 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 282 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 283 | (cp >= 0xF900 and cp <= 0xFAFF) or # 284 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 285 | return True 286 | 287 | return False 288 | 289 | def _clean_text(self, text): 290 | """Performs invalid character removal and whitespace cleanup on text.""" 291 | output = [] 292 | for char in text: 293 | cp = ord(char) 294 | if cp == 0 or cp == 0xfffd or _is_control(char): 295 | continue 296 | if _is_whitespace(char): 297 | output.append(" ") 298 | else: 299 | output.append(char) 300 | return "".join(output) 301 | 302 | 303 | class WordpieceTokenizer(object): 304 | """Runs WordPiece tokenization.""" 305 | 306 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 307 | self.vocab = vocab 308 | self.unk_token = unk_token 309 | self.max_input_chars_per_word = max_input_chars_per_word 310 | 311 | def tokenize(self, text): 312 | """Tokenizes a piece of text into its word pieces. 
313 | 314 | This uses a greedy longest-match-first algorithm to perform tokenization 315 | using the given vocabulary. 316 | 317 | For example: 318 | input = "unaffable" 319 | output = ["un", "##aff", "##able"] 320 | 321 | Args: 322 | text: A single token or whitespace separated tokens. This should have 323 | already been passed through `BasicTokenizer`. 324 | 325 | Returns: 326 | A list of wordpiece tokens. 327 | """ 328 | 329 | output_tokens = [] 330 | for token in whitespace_tokenize(text): 331 | chars = list(token) 332 | if len(chars) > self.max_input_chars_per_word: 333 | output_tokens.append(self.unk_token) 334 | continue 335 | 336 | is_bad = False 337 | start = 0 338 | sub_tokens = [] 339 | while start < len(chars): 340 | end = len(chars) 341 | cur_substr = None 342 | while start < end: 343 | substr = "".join(chars[start:end]) 344 | if start > 0: 345 | substr = "##" + substr 346 | if substr in self.vocab: 347 | cur_substr = substr 348 | break 349 | end -= 1 350 | if cur_substr is None: 351 | is_bad = True 352 | break 353 | sub_tokens.append(cur_substr) 354 | start = end 355 | 356 | if is_bad: 357 | output_tokens.append(self.unk_token) 358 | else: 359 | output_tokens.extend(sub_tokens) 360 | return output_tokens 361 | 362 | 363 | def _is_whitespace(char): 364 | """Checks whether `chars` is a whitespace character.""" 365 | # \t, \n, and \r are technically contorl characters but we treat them 366 | # as whitespace since they are generally considered as such. 367 | if char == " " or char == "\t" or char == "\n" or char == "\r": 368 | return True 369 | cat = unicodedata.category(char) 370 | if cat == "Zs": 371 | return True 372 | return False 373 | 374 | 375 | def _is_control(char): 376 | """Checks whether `chars` is a control character.""" 377 | # These are technically control characters but we count them as whitespace 378 | # characters. 379 | if char == "\t" or char == "\n" or char == "\r": 380 | return False 381 | cat = unicodedata.category(char) 382 | if cat.startswith("C"): 383 | return True 384 | return False 385 | 386 | 387 | def _is_punctuation(char): 388 | """Checks whether `chars` is a punctuation character.""" 389 | cp = ord(char) 390 | # We treat all non-letter/number ASCII as punctuation. 391 | # Characters such as "^", "$", and "`" are not in the Unicode 392 | # Punctuation class but we treat them as punctuation anyways, for 393 | # consistency. 394 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 395 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 396 | return True 397 | cat = unicodedata.category(char) 398 | if cat.startswith("P"): 399 | return True 400 | return False 401 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Tokenization classes for Transformer XL model. 17 | Adapted from https://github.com/kimiyoung/transformer-xl. 18 | """ 19 | from __future__ import (absolute_import, division, print_function, 20 | unicode_literals) 21 | 22 | import glob 23 | import logging 24 | import os 25 | import sys 26 | from collections import Counter, OrderedDict 27 | from io import open 28 | import unicodedata 29 | 30 | import torch 31 | import numpy as np 32 | 33 | from .file_utils import cached_path 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 44 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin", 45 | } 46 | VOCAB_NAME = 'vocab.bin' 47 | 48 | PRETRAINED_CORPUS_ARCHIVE_MAP = { 49 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin", 50 | } 51 | CORPUS_NAME = 'corpus.bin' 52 | 53 | class TransfoXLTokenizer(object): 54 | """ 55 | Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl 56 | """ 57 | @classmethod 58 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 59 | """ 60 | Instantiate a TransfoXLTokenizer. 61 | The TransfoXLTokenizer. 62 | """ 63 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 64 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 65 | else: 66 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 67 | # redirect to the cache, if necessary 68 | try: 69 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 70 | except EnvironmentError: 71 | logger.error( 72 | "Model name '{}' was not found in model name list ({}). " 73 | "We assumed '{}' was a path or url but couldn't find files {} " 74 | "at this path or url.".format( 75 | pretrained_model_name_or_path, 76 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 77 | pretrained_model_name_or_path, 78 | vocab_file)) 79 | return None 80 | if resolved_vocab_file == vocab_file: 81 | logger.info("loading vocabulary file {}".format(vocab_file)) 82 | else: 83 | logger.info("loading vocabulary file {} from cache at {}".format( 84 | vocab_file, resolved_vocab_file)) 85 | 86 | # Instantiate tokenizer. 
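# [Editor's note - illustrative sketch, not part of the original file.]
# The block just below assumes the resolved vocab file is a torch-serialised dump
# of a tokenizer's attribute dictionary; loading re-creates an empty tokenizer and
# copies every saved attribute onto it. The matching (hypothetical) save side
# would be roughly:
#
#     tok = TransfoXLTokenizer(lower_case=True)
#     tok.count_file('train.txt')
#     tok.build_vocab()
#     torch.save(tok.__dict__, 'vocab.bin')
#
# Note there is no schema or version check: any key present in the saved dict
# silently overwrites the corresponding attribute of the fresh instance.
# [End editor's note.]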
87 |         tokenizer = cls(*inputs, **kwargs)
88 |         vocab_dict = torch.load(resolved_vocab_file)
89 |         for key, value in vocab_dict.items():
90 |             tokenizer.__dict__[key] = value
91 |         return tokenizer
92 | 
93 |     def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
94 |                  delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
95 |         self.counter = Counter()
96 |         self.special = special
97 |         self.min_freq = min_freq
98 |         self.max_size = max_size
99 |         self.lower_case = lower_case
100 |         self.delimiter = delimiter
101 |         self.vocab_file = vocab_file
102 |         self.never_split = never_split
103 | 
104 |     def count_file(self, path, verbose=False, add_eos=False):
105 |         if verbose: print('counting file {} ...'.format(path))
106 |         assert os.path.exists(path)
107 | 
108 |         sents = []
109 |         with open(path, 'r', encoding='utf-8') as f:
110 |             for idx, line in enumerate(f):
111 |                 if verbose and idx > 0 and idx % 500000 == 0:
112 |                     print('    line {}'.format(idx))
113 |                 symbols = self.tokenize(line, add_eos=add_eos)
114 |                 self.counter.update(symbols)
115 |                 sents.append(symbols)
116 | 
117 |         return sents
118 | 
119 |     def count_sents(self, sents, verbose=False):
120 |         """
121 |             sents : a list of sentences, each a list of tokenized symbols
122 |         """
123 |         if verbose: print('counting {} sents ...'.format(len(sents)))
124 |         for idx, symbols in enumerate(sents):
125 |             if verbose and idx > 0 and idx % 500000 == 0:
126 |                 print('    line {}'.format(idx))
127 |             self.counter.update(symbols)
128 | 
129 |     def _build_from_file(self, vocab_file):
130 |         self.idx2sym = []
131 |         self.sym2idx = OrderedDict()
132 | 
133 |         with open(vocab_file, 'r', encoding='utf-8') as f:
134 |             for line in f:
135 |                 symb = line.strip().split()[0]
136 |                 self.add_symbol(symb)
137 |         if '<UNK>' in self.sym2idx:
138 |             self.unk_idx = self.sym2idx['<UNK>']
139 |         elif '<unk>' in self.sym2idx:
140 |             self.unk_idx = self.sym2idx['<unk>']
141 |         else:
142 |             raise ValueError('No <unk> token in vocabulary')
143 | 
144 |     def build_vocab(self):
145 |         if self.vocab_file:
146 |             print('building vocab from {}'.format(self.vocab_file))
147 |             self._build_from_file(self.vocab_file)
148 |             print('final vocab size {}'.format(len(self)))
149 |         else:
150 |             print('building vocab with min_freq={}, max_size={}'.format(
151 |                 self.min_freq, self.max_size))
152 |             self.idx2sym = []
153 |             self.sym2idx = OrderedDict()
154 | 
155 |             for sym in self.special:
156 |                 self.add_special(sym)
157 | 
158 |             for sym, cnt in self.counter.most_common(self.max_size):
159 |                 if cnt < self.min_freq: break
160 |                 self.add_symbol(sym)
161 | 
162 |             print('final vocab size {} from {} unique tokens'.format(
163 |                 len(self), len(self.counter)))
164 | 
165 |     def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
166 |                     add_double_eos=False):
167 |         if verbose: print('encoding file {} ...'.format(path))
168 |         assert os.path.exists(path)
169 |         encoded = []
170 |         with open(path, 'r', encoding='utf-8') as f:
171 |             for idx, line in enumerate(f):
172 |                 if verbose and idx > 0 and idx % 500000 == 0:
173 |                     print('    line {}'.format(idx))
174 |                 symbols = self.tokenize(line, add_eos=add_eos,
175 |                                         add_double_eos=add_double_eos)
176 |                 encoded.append(self.convert_to_tensor(symbols))
177 | 
178 |         if ordered:
179 |             encoded = torch.cat(encoded)
180 | 
181 |         return encoded
182 | 
183 |     def encode_sents(self, sents, ordered=False, verbose=False):
184 |         if verbose: print('encoding {} sents ...'.format(len(sents)))
185 |         encoded = []
186 |         for idx, symbols in enumerate(sents):
187 |             if verbose and idx > 0 and idx % 500000 == 0:
188 |                 print('    line {}'.format(idx))
189 |             encoded.append(self.convert_to_tensor(symbols))
190 | 
191 |         if ordered:
192 |             encoded = torch.cat(encoded)
193 | 
194 |         return encoded
195 | 
196 |     def add_special(self, sym):
197 |         if sym not in self.sym2idx:
198 |             self.idx2sym.append(sym)
199 |             self.sym2idx[sym] = len(self.idx2sym) - 1
200 |             setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
201 | 
202 |     def add_symbol(self, sym):
203 |         if sym not in self.sym2idx:
204 |             self.idx2sym.append(sym)
205 |             self.sym2idx[sym] = len(self.idx2sym) - 1
206 | 
207 |     def get_sym(self, idx):
208 |         assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
209 |         return self.idx2sym[idx]
210 | 
211 |     def get_idx(self, sym):
212 |         if sym in self.sym2idx:
213 |             return self.sym2idx[sym]
214 |         else:
215 |             # print('encounter unk {}'.format(sym))
216 |             # assert '<eos>' not in sym
217 |             if hasattr(self, 'unk_idx'):
218 |                 return self.sym2idx.get(sym, self.unk_idx)
219 |             # Backward compatibility with pre-trained models
220 |             elif '<unk>' in self.sym2idx:
221 |                 return self.sym2idx['<unk>']
222 |             elif '<UNK>' in self.sym2idx:
223 |                 return self.sym2idx['<UNK>']
224 |             else:
225 |                 raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
226 | 
227 |     def convert_ids_to_tokens(self, indices):
228 |         """Converts a sequence of indices into symbols using the vocab."""
229 |         return [self.get_sym(idx) for idx in indices]
230 | 
231 |     def convert_tokens_to_ids(self, symbols):
232 |         """Converts a sequence of symbols into ids using the vocab."""
233 |         return [self.get_idx(sym) for sym in symbols]
234 | 
235 |     def convert_to_tensor(self, symbols):
236 |         return torch.LongTensor(self.convert_tokens_to_ids(symbols))
237 | 
238 |     def decode(self, indices, exclude=None):
239 |         """Converts a sequence of indices into a string."""
240 |         if exclude is None:
241 |             return ' '.join([self.get_sym(idx) for idx in indices])
242 |         else:
243 |             return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
244 | 
245 |     def __len__(self):
246 |         return len(self.idx2sym)
247 | 
248 |     def _run_split_on_punc(self, text):
249 |         """Splits punctuation on a piece of text."""
250 |         if text in self.never_split:
251 |             return [text]
252 |         chars = list(text)
253 |         i = 0
254 |         start_new_word = True
255 |         output = []
256 |         while i < len(chars):
257 |             char = chars[i]
258 |             if _is_punctuation(char):
259 |                 output.append([char])
260 |                 start_new_word = True
261 |             else:
262 |                 if start_new_word:
263 |                     output.append([])
264 |                 start_new_word = False
265 |                 output[-1].append(char)
266 |             i += 1
267 | 
268 |         return ["".join(x) for x in output]
269 | 
270 |     def _run_strip_accents(self, text):
271 |         """Strips accents from a piece of text."""
272 |         text = unicodedata.normalize("NFD", text)
273 |         output = []
274 |         for char in text:
275 |             cat = unicodedata.category(char)
276 |             if cat == "Mn":
277 |                 continue
278 |             output.append(char)
279 |         return "".join(output)
280 | 
281 |     def _clean_text(self, text):
282 |         """Performs invalid character removal and whitespace cleanup on text."""
283 |         output = []
284 |         for char in text:
285 |             cp = ord(char)
286 |             if cp == 0 or cp == 0xfffd or _is_control(char):
287 |                 continue
288 |             if _is_whitespace(char):
289 |                 output.append(" ")
290 |             else:
291 |                 output.append(char)
292 |         return "".join(output)
293 | 
294 |     def whitespace_tokenize(self, text):
295 |         """Runs basic whitespace cleaning and splitting on a piece of text."""
296 |         text = text.strip()
297 |         if not text:
298 |             return []
299 |         if self.delimiter == '':
300 |             tokens = text
301 |         else:
302 |             tokens = text.split(self.delimiter)
303 |         return tokens
304 | 
305 |     def tokenize(self, line, add_eos=False, add_double_eos=False):
306 |         line = self._clean_text(line)
307 |         line = line.strip()
308 | 
309 |         symbols = self.whitespace_tokenize(line)
310 | 
311 |         split_symbols = []
312 |         for symbol in symbols:
313 |             if self.lower_case and symbol not in self.never_split:
314 |                 symbol = symbol.lower()
315 |                 symbol = self._run_strip_accents(symbol)
316 |             split_symbols.extend(self._run_split_on_punc(symbol))
317 | 
318 |         if add_double_eos: # lm1b
319 |             return ['<S>'] + split_symbols + ['<S>']
320 |         elif add_eos:
321 |             return split_symbols + ['<eos>']
322 |         else:
323 |             return split_symbols
324 | 
325 | 
326 | class LMOrderedIterator(object):
327 |     def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
328 |         """
329 |             data -- LongTensor -- the LongTensor is strictly ordered
330 |         """
331 |         self.bsz = bsz
332 |         self.bptt = bptt
333 |         self.ext_len = ext_len if ext_len is not None else 0
334 | 
335 |         self.device = device
336 | 
337 |         # Work out how cleanly we can divide the dataset into bsz parts.
338 |         self.n_step = data.size(0) // bsz
339 | 
340 |         # Trim off any extra elements that wouldn't cleanly fit (remainders).
341 |         data = data.narrow(0, 0, self.n_step * bsz)
342 | 
343 |         # Evenly divide the data across the bsz batches.
344 |         self.data = data.view(bsz, -1).t().contiguous().to(device)
345 | 
346 |         # Number of mini-batches
347 |         self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
348 | 
349 |     def get_batch(self, i, bptt=None):
350 |         if bptt is None: bptt = self.bptt
351 |         seq_len = min(bptt, self.data.size(0) - 1 - i)
352 | 
353 |         end_idx = i + seq_len
354 |         beg_idx = max(0, i - self.ext_len)
355 | 
356 |         data = self.data[beg_idx:end_idx]
357 |         target = self.data[i+1:i+1+seq_len]
358 | 
359 |         data_out = data.transpose(0, 1).contiguous().to(self.device)
360 |         target_out = target.transpose(0, 1).contiguous().to(self.device)
361 | 
362 |         return data_out, target_out, seq_len
363 | 
364 |     def get_fixlen_iter(self, start=0):
365 |         for i in range(start, self.data.size(0) - 1, self.bptt):
366 |             yield self.get_batch(i)
367 | 
368 |     def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
369 |         max_len = self.bptt + max_deviation * std
370 |         i = start
371 |         while True:
372 |             bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
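# [Editor's note] The line above and the line below implement the variable-length
# BPTT schedule popularised by the AWD-LSTM/Transformer-XL training recipes: keep
# the base window with probability 0.95, otherwise halve it, then jitter the
# result with Gaussian noise of standard deviation `std` and clamp it into
# [min_len, max_len]. E.g. with bptt=70, std=5, max_deviation=3 the sampled
# windows lie in [5, 85], centred on 70 (or occasionally 35).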
373 | bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) 374 | data, target, seq_len = self.get_batch(i, bptt) 375 | i += seq_len 376 | yield data, target, seq_len 377 | if i >= self.data.size(0) - 2: 378 | break 379 | 380 | def __iter__(self): 381 | return self.get_fixlen_iter() 382 | 383 | 384 | class LMShuffledIterator(object): 385 | def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): 386 | """ 387 | data -- list[LongTensor] -- there is no order among the LongTensors 388 | """ 389 | self.data = data 390 | 391 | self.bsz = bsz 392 | self.bptt = bptt 393 | self.ext_len = ext_len if ext_len is not None else 0 394 | 395 | self.device = device 396 | self.shuffle = shuffle 397 | 398 | def get_sent_stream(self): 399 | # index iterator 400 | epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ 401 | else np.array(range(len(self.data))) 402 | 403 | # sentence iterator 404 | for idx in epoch_indices: 405 | yield self.data[idx] 406 | 407 | def stream_iterator(self, sent_stream): 408 | # streams for each data in the batch 409 | streams = [None] * self.bsz 410 | 411 | data = torch.LongTensor(self.bptt, self.bsz) 412 | target = torch.LongTensor(self.bptt, self.bsz) 413 | 414 | n_retain = 0 415 | 416 | while True: 417 | # data : [n_retain+bptt x bsz] 418 | # target : [bptt x bsz] 419 | data[n_retain:].fill_(-1) 420 | target.fill_(-1) 421 | 422 | valid_batch = True 423 | 424 | for i in range(self.bsz): 425 | n_filled = 0 426 | try: 427 | while n_filled < self.bptt: 428 | if streams[i] is None or len(streams[i]) <= 1: 429 | streams[i] = next(sent_stream) 430 | # number of new tokens to fill in 431 | n_new = min(len(streams[i]) - 1, self.bptt - n_filled) 432 | # first n_retain tokens are retained from last batch 433 | data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ 434 | streams[i][:n_new] 435 | target[n_filled:n_filled+n_new, i] = \ 436 | streams[i][1:n_new+1] 437 | streams[i] = streams[i][n_new:] 438 | n_filled += n_new 439 | except StopIteration: 440 | valid_batch = False 441 | break 442 | 443 | if not valid_batch: 444 | return 445 | 446 | data_out = data.transpose(0, 1).contiguous().to(self.device) 447 | target_out = target.transpose(0, 1).contiguous().to(self.device) 448 | 449 | yield data_out, target_out, self.bptt 450 | 451 | n_retain = min(data.size(0), self.ext_len) 452 | if n_retain > 0: 453 | data[:n_retain] = data[-n_retain:] 454 | data.resize_(n_retain + self.bptt, data.size(1)) 455 | 456 | def __iter__(self): 457 | # sent_stream is an iterator 458 | sent_stream = self.get_sent_stream() 459 | 460 | for batch in self.stream_iterator(sent_stream): 461 | yield batch 462 | 463 | 464 | class LMMultiFileIterator(LMShuffledIterator): 465 | def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, 466 | shuffle=False): 467 | 468 | self.paths = paths 469 | self.vocab = vocab 470 | 471 | self.bsz = bsz 472 | self.bptt = bptt 473 | self.ext_len = ext_len if ext_len is not None else 0 474 | 475 | self.device = device 476 | self.shuffle = shuffle 477 | 478 | def get_sent_stream(self, path): 479 | sents = self.vocab.encode_file(path, add_double_eos=True) 480 | if self.shuffle: 481 | np.random.shuffle(sents) 482 | sent_stream = iter(sents) 483 | 484 | return sent_stream 485 | 486 | def __iter__(self): 487 | if self.shuffle: 488 | np.random.shuffle(self.paths) 489 | 490 | for path in self.paths: 491 | # sent_stream is an iterator 492 | sent_stream = self.get_sent_stream(path) 493 | for batch in 
self.stream_iterator(sent_stream): 494 | yield batch 495 | 496 | 497 | class TransfoXLCorpus(object): 498 | @classmethod 499 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 500 | """ 501 | Instantiate a pre-processed corpus. 502 | """ 503 | vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 504 | if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP: 505 | corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path] 506 | else: 507 | corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME) 508 | # redirect to the cache, if necessary 509 | try: 510 | resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir) 511 | except EnvironmentError: 512 | logger.error( 513 | "Corpus '{}' was not found in corpus list ({}). " 514 | "We assumed '{}' was a path or url but couldn't find files {} " 515 | "at this path or url.".format( 516 | pretrained_model_name_or_path, 517 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 518 | pretrained_model_name_or_path, 519 | corpus_file)) 520 | return None 521 | if resolved_corpus_file == corpus_file: 522 | logger.info("loading corpus file {}".format(corpus_file)) 523 | else: 524 | logger.info("loading corpus file {} from cache at {}".format( 525 | corpus_file, resolved_corpus_file)) 526 | 527 | # Instantiate tokenizer. 528 | corpus = cls(*inputs, **kwargs) 529 | corpus_dict = torch.load(resolved_corpus_file) 530 | for key, value in corpus_dict.items(): 531 | corpus.__dict__[key] = value 532 | corpus.vocab = vocab 533 | if corpus.train is not None: 534 | corpus.train = torch.tensor(corpus.train, dtype=torch.long) 535 | if corpus.valid is not None: 536 | corpus.valid = torch.tensor(corpus.valid, dtype=torch.long) 537 | if corpus.test is not None: 538 | corpus.test = torch.tensor(corpus.test, dtype=torch.long) 539 | return corpus 540 | 541 | def __init__(self, *args, **kwargs): 542 | self.vocab = TransfoXLTokenizer(*args, **kwargs) 543 | self.dataset = None 544 | self.train = None 545 | self.valid = None 546 | self.test = None 547 | 548 | def build_corpus(self, path, dataset): 549 | self.dataset = dataset 550 | 551 | if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']: 552 | self.vocab.count_file(os.path.join(path, 'train.txt')) 553 | self.vocab.count_file(os.path.join(path, 'valid.txt')) 554 | self.vocab.count_file(os.path.join(path, 'test.txt')) 555 | elif self.dataset == 'wt103': 556 | self.vocab.count_file(os.path.join(path, 'train.txt')) 557 | elif self.dataset == 'lm1b': 558 | train_path_pattern = os.path.join( 559 | path, '1-billion-word-language-modeling-benchmark-r13output', 560 | 'training-monolingual.tokenized.shuffled', 'news.en-*') 561 | train_paths = glob.glob(train_path_pattern) 562 | # the vocab will load from file when build_vocab() is called 563 | 564 | self.vocab.build_vocab() 565 | 566 | if self.dataset in ['ptb', 'wt2', 'wt103']: 567 | self.train = self.vocab.encode_file( 568 | os.path.join(path, 'train.txt'), ordered=True) 569 | self.valid = self.vocab.encode_file( 570 | os.path.join(path, 'valid.txt'), ordered=True) 571 | self.test = self.vocab.encode_file( 572 | os.path.join(path, 'test.txt'), ordered=True) 573 | elif self.dataset in ['enwik8', 'text8']: 574 | self.train = self.vocab.encode_file( 575 | os.path.join(path, 'train.txt'), ordered=True, add_eos=False) 576 | self.valid = self.vocab.encode_file( 577 | os.path.join(path, 'valid.txt'), ordered=True, add_eos=False) 578 | self.test = 
self.vocab.encode_file(
579 |                 os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
580 |         elif self.dataset == 'lm1b':
581 |             self.train = train_paths
582 |             self.valid = self.vocab.encode_file(
583 |                 os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
584 |             self.test = self.vocab.encode_file(
585 |                 os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
586 | 
587 |     def get_iterator(self, split, *args, **kwargs):
588 |         if split == 'train':
589 |             if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
590 |                 data_iter = LMOrderedIterator(self.train, *args, **kwargs)
591 |             elif self.dataset == 'lm1b':
592 |                 kwargs['shuffle'] = True
593 |                 data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
594 |         elif split in ['valid', 'test']:
595 |             data = self.valid if split == 'valid' else self.test
596 |             if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
597 |                 data_iter = LMOrderedIterator(data, *args, **kwargs)
598 |             elif self.dataset == 'lm1b':
599 |                 data_iter = LMShuffledIterator(data, *args, **kwargs)
600 | 
601 |         return data_iter
602 | 
603 | 
604 | def get_lm_corpus(datadir, dataset):
605 |     fn = os.path.join(datadir, 'cache.pt')
606 |     fn_pickle = os.path.join(datadir, 'cache.pkl')
607 |     if os.path.exists(fn):
608 |         print('Loading cached dataset...')
609 |         corpus = torch.load(fn)
610 |     elif os.path.exists(fn_pickle):
611 |         print('Loading cached dataset from pickle...')
612 |         with open(fn_pickle, "rb") as fp:
613 |             corpus = pickle.load(fp)
614 |     else:
615 |         print('Producing dataset {}...'.format(dataset))
616 |         kwargs = {}
617 |         if dataset in ['wt103', 'wt2']:
618 |             kwargs['special'] = ['<eos>']
619 |             kwargs['lower_case'] = False
620 |         elif dataset == 'ptb':
621 |             kwargs['special'] = ['<eos>']
622 |             kwargs['lower_case'] = True
623 |         elif dataset == 'lm1b':
624 |             kwargs['special'] = []
625 |             kwargs['lower_case'] = False
626 |             kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
627 |         elif dataset in ['enwik8', 'text8']:
628 |             pass
629 | 
630 |         corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
631 |         torch.save(corpus, fn)
632 | 
633 |     return corpus
634 | 
635 | def _is_whitespace(char):
636 |     """Checks whether `chars` is a whitespace character."""
637 |     # \t, \n, and \r are technically control characters but we treat them
638 |     # as whitespace since they are generally considered as such.
639 |     if char == " " or char == "\t" or char == "\n" or char == "\r":
640 |         return True
641 |     cat = unicodedata.category(char)
642 |     if cat == "Zs":
643 |         return True
644 |     return False
645 | 
646 | 
647 | def _is_control(char):
648 |     """Checks whether `chars` is a control character."""
649 |     # These are technically control characters but we count them as whitespace
650 |     # characters.
651 |     if char == "\t" or char == "\n" or char == "\r":
652 |         return False
653 |     cat = unicodedata.category(char)
654 |     if cat.startswith("C"):
655 |         return True
656 |     return False
657 | 
658 | 
659 | def _is_punctuation(char):
660 |     """Checks whether `chars` is a punctuation character."""
661 |     cp = ord(char)
662 |     # We treat all non-letter/number ASCII as punctuation.
663 |     # Characters such as "^", "$", and "`" are not in the Unicode
664 |     # Punctuation class but we treat them as punctuation anyways, for
665 |     # consistency.
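# [Editor's note] The four ASCII ranges tested below are 33-47 ('!'..'/'),
# 58-64 (':'..'@'), 91-96 ('['..'`') and 123-126 ('{'..'~'), i.e. every
# printable non-alphanumeric ASCII character, which is why symbols like "^"
# and "$" are caught here even though Unicode does not class them as punctuation.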
666 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 667 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 668 | return True 669 | cat = unicodedata.category(char) 670 | if cat.startswith("P"): 671 | return True 672 | return False 673 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pytorch_pretrained_bert.module import BertModel 3 | import torch 4 | import torch.nn as nn 5 | import pytorch_pretrained_bert as Bert 6 | import numpy as np 7 | import copy 8 | import torch.nn.functional as F 9 | import math 10 | from src.vae import * 11 | 12 | 13 | 14 | def gelu(x): 15 | return 0.5 * x * (1 + torch.tanh(math.sqrt(math.pi / 2) * (x + 0.044715 * x ** 3))) 16 | class BertConfig(Bert.modeling.BertConfig): 17 | def __init__(self, config): 18 | super(BertConfig, self).__init__( 19 | vocab_size_or_config_json_file=config.get('vocab_size'), 20 | hidden_size=config['hidden_size'], 21 | num_hidden_layers=config.get('num_hidden_layers'), 22 | num_attention_heads=config.get('num_attention_heads'), 23 | intermediate_size=config.get('intermediate_size'), 24 | hidden_act=config.get('hidden_act'), 25 | hidden_dropout_prob=config.get('hidden_dropout_prob'), 26 | attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'), 27 | max_position_embeddings = config.get('max_position_embedding'), 28 | initializer_range=config.get('initializer_range'), 29 | ) 30 | self.seg_vocab_size = config.get('seg_vocab_size') 31 | self.age_vocab_size = config.get('age_vocab_size') 32 | self.num_treatment = config.get('num_treatment') 33 | self.device = config.get('device') 34 | self.year_vocab_size = config.get('year_vocab_size') 35 | 36 | if config.get('poolingSize') is not None: 37 | self.poolingSize = config.get('poolingSize') 38 | if config.get('MEM') is not None: 39 | self.MEM = config.get('MEM') 40 | else: 41 | self.MEM=False 42 | if config.get('unsupSize') is not None: 43 | self.unsupSize = config.get('unsupSize') 44 | if config.get('unsupVAE') is not None: 45 | self.unsupVAE = True 46 | else: 47 | self.unsupVAE = False 48 | if config.get("vaeinchannels") is not None: 49 | self.vaeinchannels = config.get('vaeinchannels') 50 | if config.get("vaelatentdim") is not None: 51 | self.vaelatentdim = config.get('vaelatentdim') 52 | if config.get("vaehidden") is not None: 53 | self.vaehidden = config.get('vaehidden') 54 | if config.get("klpar") is not None: 55 | self.klpar = config.get('klpar') 56 | else: 57 | self.klpar = 1 58 | if config.get('BetaD') is not None: 59 | self.BetaD = config.get('BetaD') 60 | else: 61 | self.BetaD = False 62 | 63 | class BertEmbeddingsUnsup(nn.Module): 64 | """Construct the embeddings from word, segment, age 65 | """ 66 | 67 | def __init__(self, config): 68 | super(BertEmbeddingsUnsup, self).__init__() 69 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 70 | # self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size) 71 | self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size) 72 | self.posi_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size).\ 73 | from_pretrained(embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size)) 74 | self.year_embeddings = nn.Embedding(config.year_vocab_size, config.hidden_size) 75 | self.unsuplist = config.unsupSize 76 | sumInputTabular = sum([el[1] for el in 
self.unsuplist]) 77 | self.unsupEmbeddings = nn.ModuleList([nn.Embedding(el[0], el[1]) for el in self.unsuplist]) 78 | self.unsupLinear = nn.Linear(sumInputTabular, config.hidden_size) 79 | 80 | self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12) 81 | 82 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 83 | 84 | def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None): 85 | if seg_ids is None: 86 | seg_ids = torch.zeros_like(word_ids) 87 | if age_ids is None: 88 | age_ids = torch.zeros_like(word_ids) 89 | if posi_ids is None: 90 | posi_ids = torch.zeros_like(word_ids) 91 | 92 | tabularVar = word_ids[:,:len(self.unsuplist)] 93 | word_ids = word_ids[:,len(self.unsuplist):] 94 | 95 | 96 | word_embed = self.word_embeddings(word_ids) 97 | age_embed = self.age_embeddings(age_ids) 98 | posi_embeddings = self.posi_embeddings(posi_ids) 99 | year_embed = self.year_embeddings(year_ids) 100 | tabularVar = tabularVar.transpose(1,0) 101 | tabularVarembed = torch.cat([self.unsupEmbeddings[eliter](el) for eliter, el in enumerate(tabularVar)], dim=1) 102 | tabularVarembed = self.unsupLinear(tabularVarembed).unsqueeze(1) 103 | embeddings = word_embed + age_embed + year_embed + posi_embeddings 104 | embeddings = torch.cat((tabularVarembed, embeddings), dim=1) 105 | 106 | embeddings = self.LayerNorm(embeddings) 107 | 108 | embeddings = self.dropout(embeddings) 109 | return embeddings 110 | 111 | def _init_posi_embedding(self, max_position_embedding, hidden_size): 112 | def even_code(pos, idx): 113 | return np.sin(pos/(10000**(2*idx/hidden_size))) 114 | 115 | def odd_code(pos, idx): 116 | return np.cos(pos/(10000**(2*idx/hidden_size))) 117 | 118 | # initialize position embedding table 119 | lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32) 120 | 121 | # reset table parameters with hard encoding 122 | # set even dimension 123 | for pos in range(max_position_embedding): 124 | for idx in np.arange(0, hidden_size, step=2): 125 | lookup_table[pos, idx] = even_code(pos, idx) 126 | # set odd dimension 127 | for pos in range(max_position_embedding): 128 | for idx in np.arange(1, hidden_size, step=2): 129 | lookup_table[pos, idx] = odd_code(pos, idx) 130 | 131 | return torch.tensor(lookup_table) 132 | 133 | class BertLastPooler(nn.Module): 134 | def __init__(self, config): 135 | super(BertLastPooler, self).__init__() 136 | self.dense = nn.Linear(config.hidden_size, config.poolingSize) 137 | self.activation = nn.Tanh() 138 | 139 | def forward(self, hidden_states): 140 | # We "pool" the model by simply taking the hidden state corresponding 141 | # to the last token. this is unlike first token pooling, just switched the ordering of the data 142 | first_token_tensor = hidden_states[:, -1] 143 | pooled_output = self.dense(first_token_tensor) 144 | pooled_output = self.activation(pooled_output) 145 | return pooled_output 146 | 147 | class BertVAEPooler(nn.Module): 148 | def __init__(self, config): 149 | super(BertVAEPooler, self).__init__() 150 | self.dense = nn.Linear(config.hidden_size, config.poolingSize) 151 | self.activation = nn.Tanh() 152 | 153 | def forward(self, hidden_states): 154 | # We "pool" the model by simply taking the hidden state corresponding 155 | # to the first token. 
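# [Editor's note] Contrast with BertLastPooler above: that pooler reads the
# *last* sequence position, while this one reads position 0. Given that
# BertEmbeddingsUnsup prepends the pooled tabular-covariate embedding at
# index 0 (torch.cat((tabularVarembed, embeddings), dim=1)), the VAE head
# appears to pool the tabular summary slot, whereas the treatment/outcome
# heads pool the final sequence position.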
156 | first_token_tensor = hidden_states[:, 0] 157 | pooled_output = self.dense(first_token_tensor) 158 | pooled_output = self.activation(pooled_output) 159 | return pooled_output 160 | 161 | 162 | 163 | class TBEHRT(Bert.modeling.BertPreTrainedModel): 164 | def __init__(self, config, num_labels): 165 | super(TBEHRT, self).__init__(config) 166 | if 'cuda' in config.device: 167 | 168 | self.device = int(config.device[-1]) 169 | 170 | self.otherDevice = self.device 171 | else: 172 | self.device = 'cpu' 173 | self.otherDevice = self.device 174 | self.bert = SimpleBEHRT (config) 175 | 176 | # self.bert = BertModel(config) 177 | self.treatmentC = nn.Linear(config.poolingSize, config.num_treatment) 178 | self.OutcomeT1_1 = nn.Linear(config.poolingSize, config.poolingSize) 179 | self.OutcomeT2_1 = nn.Linear(config.poolingSize, config.poolingSize) 180 | self.OutcomeT3_1 = nn.Linear(config.poolingSize, 2) 181 | 182 | self.OutcomeT1_2 = nn.Linear(config.poolingSize, config.poolingSize) 183 | self.OutcomeT2_2 = nn.Linear(config.poolingSize, config.poolingSize) 184 | self.OutcomeT3_2 = nn.Linear(config.poolingSize, 2) 185 | self.config = config 186 | self.dropoutVAE = nn.Dropout(config.hidden_dropout_prob) 187 | 188 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 189 | self.treatmentC.to(self.otherDevice) 190 | 191 | self.OutcomeT1_1.to(self.otherDevice) 192 | self.OutcomeT2_1.to(self.otherDevice) 193 | 194 | self.OutcomeT3_1.to(self.otherDevice) 195 | 196 | 197 | self.OutcomeT1_2.to(self.otherDevice) 198 | self.OutcomeT2_2.to(self.otherDevice) 199 | 200 | self.OutcomeT3_2.to(self.otherDevice) 201 | self.gelu = nn.ELU() 202 | self.num_labels = num_labels 203 | self.num_treatment = config.num_treatment 204 | self.logS = nn.LogSoftmax() 205 | 206 | self.VAEpooler = BertVAEPooler(config) 207 | self.VAEpooler.to(self.otherDevice) 208 | self.VAE = VAE(config) 209 | self.VAE.to(self.otherDevice) 210 | self.treatmentW = 1.0 211 | self.MEM = False 212 | self.config = config 213 | if config.MEM is True: 214 | self.MEM = True 215 | print('turning on the MEM....') 216 | self.cls = Bert.modeling.BertOnlyMLMHead(config, self.bert.bert.embeddings.word_embeddings.weight) 217 | self.cls.to(self.otherDevice) 218 | self.apply(self.init_bert_weights) 219 | print("full init completed...") 220 | 221 | def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids = None, attention_mask=None, masked_lm_labels=None, 222 | outcomeT=None, treatmentCLabel=None, fullEval=False, vaelabel = None): 223 | batchs = input_ids.shape[0] 224 | embed4MLM, pooled_out = self.bert(input_ids, age_ids, seg_ids, posi_ids , year_ids, attention_mask, 225 | output_all_encoded_layers=False, fullmask = None) 226 | 227 | treatmentCLabel = treatmentCLabel.to(self.otherDevice) 228 | pooled_outVAE = self.VAEpooler(embed4MLM) 229 | pooled_outVAE = self.dropoutVAE(pooled_outVAE) 230 | 231 | outcomeT = outcomeT.to(self.otherDevice) 232 | outputVAE = self.VAE(pooled_outVAE, vaelabel) 233 | 234 | if self.MEM==True: 235 | prediction_scores = self.cls(embed4MLM[:,1:]) 236 | 237 | if masked_lm_labels is not None: 238 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 239 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 240 | 241 | else: 242 | prediction_scores = torch.ones([batchs, self.config.vocab_size]).to(self.otherDevice) 243 | masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice) 244 | 245 | treatmentOut = self.treatmentC(pooled_out) 246 | tloss_fuct = 
nn.NLLLoss(reduction='mean') 247 | lossT = tloss_fuct(self.logS(treatmentOut).view(-1, self.config.num_treatment), treatmentCLabel.view(-1)) 248 | pureSoft = nn.Softmax(dim=1) 249 | 250 | treatmentOut = pureSoft(treatmentOut) 251 | 252 | out1 = self.gelu(self.OutcomeT1_1(pooled_out)) 253 | out2 = self.gelu(self.OutcomeT2_1(out1)) 254 | logits0 = (self.OutcomeT3_1(out2)) 255 | 256 | out12 = self.gelu(self.OutcomeT1_2(pooled_out)) 257 | out22 = self.gelu(self.OutcomeT2_2(out12)) 258 | logits1 = (self.OutcomeT3_2(out22)) 259 | 260 | 261 | 262 | outcome1loss = nn.CrossEntropyLoss(reduction='none') 263 | outcome0loss = nn.CrossEntropyLoss(reduction='none') 264 | lossRaw0 = outcome0loss(logits0,outcomeT.squeeze(-1)) 265 | lossRaw1 = outcome1loss(logits1, outcomeT.squeeze(-1)) 266 | 267 | trueLoss1 = torch.mean(lossRaw1*(treatmentCLabel.type(torch.FloatTensor).squeeze(-1)).to(self.otherDevice)) 268 | trueLoss0 = torch.mean(lossRaw0*(1-treatmentCLabel.type(torch.FloatTensor).squeeze(-1)).to(self.otherDevice)) 269 | 270 | tloss = trueLoss0 +trueLoss1 + self.treatmentW*lossT 271 | 272 | 273 | outlog1 = pureSoft(logits1)[:,1] 274 | outlog0 = pureSoft(logits0)[:,1] 275 | outlogits = torch.cat((outlog0.view(-1, 1).unsqueeze(0), outlog1.view(-1, 1).unsqueeze(0)), dim=0 ) 276 | fulTout = treatmentOut 277 | outputTreatIndex = treatmentCLabel 278 | outlabelsfull = outcomeT 279 | 280 | 281 | # vae loss 282 | vaeloss = self.VAE.loss_function(outputVAE) 283 | vaeloss_total = vaeloss['loss'] 284 | masked_lm_loss= masked_lm_loss+ vaeloss_total 285 | return masked_lm_loss, tloss, prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.contiguous().view( -1), fulTout, outputTreatIndex, outlogits, outlabelsfull, outputTreatIndex, 0, vaeloss 286 | 287 | 288 | class TARNET_MEM(Bert.modeling.BertPreTrainedModel): 289 | def __init__(self, config, num_labels): 290 | super(TARNET_MEM, self).__init__(config) 291 | if 'cuda' in config.device: 292 | 293 | self.device = int(config.device[-1]) 294 | 295 | self.otherDevice = self.device 296 | else: 297 | self.device = 'cpu' 298 | self.otherDevice = self.device 299 | self.bert = SimpleBEHRT (config) 300 | 301 | # self.bert = BertModel(config) 302 | self.treatmentC = nn.Linear(config.poolingSize, config.num_treatment) 303 | self.OutcomeT1_1 = nn.Linear(config.poolingSize, config.poolingSize) 304 | self.OutcomeT2_1 = nn.Linear(config.poolingSize, config.poolingSize) 305 | self.OutcomeT3_1 = nn.Linear(config.poolingSize, 2) 306 | 307 | 308 | self.OutcomeT1_2 = nn.Linear(config.poolingSize, config.poolingSize) 309 | self.OutcomeT2_2 = nn.Linear(config.poolingSize, config.poolingSize) 310 | self.OutcomeT3_2 = nn.Linear(config.poolingSize, 2) 311 | self.config = config 312 | self.dropoutVAE = nn.Dropout(config.hidden_dropout_prob) 313 | 314 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 315 | self.treatmentC.to(self.otherDevice) 316 | 317 | self.OutcomeT1_1.to(self.otherDevice) 318 | self.OutcomeT2_1.to(self.otherDevice) 319 | 320 | self.OutcomeT3_1.to(self.otherDevice) 321 | 322 | 323 | self.OutcomeT1_2.to(self.otherDevice) 324 | self.OutcomeT2_2.to(self.otherDevice) 325 | 326 | self.OutcomeT3_2.to(self.otherDevice) 327 | self.gelu = nn.ELU() 328 | self.num_labels = num_labels 329 | self.num_treatment = config.num_treatment 330 | self.logS = nn.LogSoftmax() 331 | 332 | 333 | 334 | 335 | 336 | self.VAEpooler = BertVAEPooler(config) 337 | self.VAEpooler.to(self.otherDevice) 338 | self.VAE = VAE(config) 339 | self.VAE.to(self.otherDevice) 340 | self.treatmentW = 
1.0 341 | self.MEM = False 342 | self.config = config 343 | if config.MEM is True: 344 | self.MEM = True 345 | print('turning on the MEM....') 346 | self.cls = Bert.modeling.BertOnlyMLMHead(config, self.bert.bert.embeddings.word_embeddings.weight) 347 | self.cls.to(self.otherDevice) 348 | self.apply(self.init_bert_weights) 349 | print("full init completed...") 350 | 351 | def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids = None, attention_mask=None, masked_lm_labels=None, 352 | outcomeT=None, treatmentCLabel=None, fullEval=False, vaelabel = None): 353 | batchs = input_ids.shape[0] 354 | embed4MLM, pooled_out = self.bert(input_ids, age_ids, seg_ids, posi_ids , year_ids, attention_mask, 355 | output_all_encoded_layers=False, fullmask = None) 356 | 357 | treatmentCLabel = treatmentCLabel.to(self.otherDevice) 358 | # 359 | pooled_outVAE = self.VAEpooler(embed4MLM) 360 | pooled_outVAE = self.dropoutVAE(pooled_outVAE) 361 | 362 | outcomeT = outcomeT.to(self.otherDevice) 363 | outputVAE = self.VAE(pooled_outVAE, vaelabel) 364 | 365 | masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice) 366 | 367 | if self.MEM==True: 368 | prediction_scores = self.cls(embed4MLM[:,1:]) 369 | 370 | if masked_lm_labels is not None: 371 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 372 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 373 | 374 | else: 375 | prediction_scores = torch.ones([batchs, self.config.vocab_size]).to(self.otherDevice) 376 | masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice) 377 | 378 | 379 | treatmentOut = self.treatmentC(pooled_out) 380 | tloss_fuct = nn.NLLLoss(reduction='mean') 381 | pureSoft = nn.Softmax(dim=1) 382 | 383 | treatmentOut = pureSoft(treatmentOut) 384 | 385 | out1 = self.gelu(self.OutcomeT1_1(pooled_out)) 386 | out2 = self.gelu(self.OutcomeT2_1(out1)) 387 | logits0 = (self.OutcomeT3_1(out2)) 388 | 389 | out12 = self.gelu(self.OutcomeT1_2(pooled_out)) 390 | out22 = self.gelu(self.OutcomeT2_2(out12)) 391 | logits1 = (self.OutcomeT3_2(out22)) 392 | 393 | 394 | outcome1loss = nn.CrossEntropyLoss(reduction='none') 395 | outcome0loss = nn.CrossEntropyLoss(reduction='none') 396 | lossRaw0 = outcome0loss(logits0,outcomeT.squeeze(-1)) 397 | lossRaw1 = outcome1loss(logits1, outcomeT.squeeze(-1)) 398 | 399 | trueLoss1 = torch.mean(lossRaw1*(treatmentCLabel.type(torch.FloatTensor).squeeze(-1)).to(self.otherDevice)) 400 | trueLoss0 = torch.mean(lossRaw0*(1-treatmentCLabel.type(torch.FloatTensor).squeeze(-1)).to(self.otherDevice)) 401 | 402 | tloss = trueLoss0 +trueLoss1 403 | outlog1 = pureSoft(logits1)[:,1] 404 | 405 | outlog0 = pureSoft(logits0)[:,1] 406 | 407 | 408 | outlogits = torch.cat((outlog0.view(-1, 1).unsqueeze(0), outlog1.view(-1, 1).unsqueeze(0)), dim=0 ) 409 | fulTout = treatmentOut 410 | 411 | outputTreatIndex = treatmentCLabel 412 | outlabelsfull = outcomeT 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | # vae loss 421 | vaeloss = self.VAE.loss_function(outputVAE) 422 | vaeloss_total = vaeloss['loss'] 423 | masked_lm_loss= masked_lm_loss+ vaeloss_total 424 | return masked_lm_loss, tloss, prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.contiguous().view( -1), fulTout, outputTreatIndex, outlogits, outlabelsfull, outputTreatIndex, 0, vaeloss 425 | 426 | 427 | 428 | 429 | class SimpleBEHRT(Bert.modeling.BertPreTrainedModel): 430 | def __init__(self, config, num_labels=1): 431 | super(SimpleBEHRT, self).__init__(config) 432 | 433 | self.bert = 
class SimpleBEHRT(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config, num_labels=1):
        super(SimpleBEHRT, self).__init__(config)
        self.bert = BEHRTBASE(config)
        if 'cuda' in config.device:
            self.device = int(config.device[-1])
            self.otherDevice = self.device
        else:
            self.device = 'cpu'
            self.otherDevice = self.device

        self.bert.to(self.device)
        self.num_labels = num_labels
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.pooler = BertLastPooler(config)
        self.pooler.to(self.otherDevice)

        self.config = config
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None, attention_mask=None,
                output_all_encoded_layers=False, fullmask=None):
        batchS = input_ids.shape[0]
        # BEHRTBASE returns the encoder sequence output in the second slot
        sequence_output, embedding_outputLSTM, embedding_outputLSTM2, attention_maskLSTM = \
            self.bert(input_ids, age_ids, seg_ids, posi_ids, year_ids, attention_mask, fullmask=fullmask)

        pooled_out = self.pooler(embedding_outputLSTM)
        pooled_out = self.dropout(pooled_out)
        return embedding_outputLSTM, pooled_out
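# --- illustrative sketch (not part of the original source) -------------------
# SimpleBEHRT condenses the encoder's sequence output into a single vector via
# BertLastPooler (defined elsewhere in the repo). Assuming it follows the
# standard BERT pooler pattern but reads the *last* position instead of the
# first ([CLS]) one, a minimal sketch could look like this; this is a
# hypothetical reconstruction, so consult the actual BertLastPooler for the
# real behaviour:
def _demo_last_pooler(hidden_states):
    import torch
    import torch.nn as nn
    hidden = hidden_states.size(-1)
    dense = nn.Linear(hidden, hidden)
    # take the final token's hidden state for every sequence in the batch
    last_token = hidden_states[:, -1]
    return torch.tanh(dense(last_token))
# -----------------------------------------------------------------------------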
class BEHRTBASE(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config):
        super(BEHRTBASE, self).__init__(config)
        self.embeddings = BertEmbeddingsUnsup(config=config)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.config = config
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None, attention_mask=None,
                fullmask=None):
        # note: seg_ids and year_ids have no default handling and are assumed
        # to be provided by the caller
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(input_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal
        # attention used in OpenAI GPT; we just need to prepare the broadcast
        # dimension here. Since attention_mask is 1.0 for positions we want to
        # attend to and 0.0 for masked positions, this operation creates a
        # tensor which is 0.0 for positions we want to attend to and -10000.0
        # for masked positions. Because it is added to the raw scores before
        # the softmax, this is effectively the same as removing them entirely.
        encodermask = attention_mask
        encodermask = encodermask.unsqueeze(1).unsqueeze(2)
        encodermask = encodermask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        encodermask = (1.0 - encodermask) * -10000.0

        embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids, year_ids)
        encoded_layers = self.encoder(embedding_output,
                                      encodermask,
                                      output_all_encoded_layers=False)
        sequenceOut = encoded_layers[-1]
        # placeholder slots keep the four-tuple interface expected by SimpleBEHRT
        return [0], sequenceOut, [0], [0]
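# --- illustrative sketch (not part of the original source) -------------------
# The mask arithmetic in BEHRTBASE.forward turns a 2D padding mask into an
# additive attention bias: kept positions become 0.0 and padded positions
# -10000.0, so adding it to the raw attention scores before the softmax
# effectively removes the padded positions. A standalone demonstration with a
# toy mask (`_demo_additive_mask` is a hypothetical helper for illustration):
def _demo_additive_mask():
    import torch
    attention_mask = torch.tensor([[1, 1, 1, 0]])         # 1 = real token, 0 = pad
    m = attention_mask.unsqueeze(1).unsqueeze(2).float()  # [batch, 1, 1, seq]
    m = (1.0 - m) * -10000.0
    return m  # -> 0.0 for kept positions, -10000.0 for the padded one
# -----------------------------------------------------------------------------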
class DRAGONNET(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config, num_labels):
        super(DRAGONNET, self).__init__(config)
        if 'cuda' in config.device:
            self.device = int(config.device[-1])
            self.otherDevice = self.device
        else:
            self.device = 'cpu'
            self.otherDevice = self.device
        self.bert = SimpleBEHRT(config)

        # treatment classifier and two outcome heads (one per treatment arm)
        self.treatmentC = nn.Linear(config.poolingSize, config.num_treatment)
        self.OutcomeT1_1 = nn.Linear(config.poolingSize, config.poolingSize)
        self.OutcomeT2_1 = nn.Linear(config.poolingSize, config.poolingSize)
        self.OutcomeT3_1 = nn.Linear(config.poolingSize, 2)

        self.OutcomeT1_2 = nn.Linear(config.poolingSize, config.poolingSize)
        self.OutcomeT2_2 = nn.Linear(config.poolingSize, config.poolingSize)
        self.OutcomeT3_2 = nn.Linear(config.poolingSize, 2)
        self.config = config
        self.dropoutVAE = nn.Dropout(config.hidden_dropout_prob)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.treatmentC.to(self.otherDevice)
        self.OutcomeT1_1.to(self.otherDevice)
        self.OutcomeT2_1.to(self.otherDevice)
        self.OutcomeT3_1.to(self.otherDevice)
        self.OutcomeT1_2.to(self.otherDevice)
        self.OutcomeT2_2.to(self.otherDevice)
        self.OutcomeT3_2.to(self.otherDevice)

        self.gelu = nn.ELU()  # note: named 'gelu' but actually an ELU activation
        self.num_labels = num_labels
        self.num_treatment = config.num_treatment
        self.logS = nn.LogSoftmax(dim=1)

        self.VAEpooler = BertVAEPooler(config)
        self.VAEpooler.to(self.otherDevice)
        self.VAE = VAE(config)
        self.VAE.to(self.otherDevice)
        self.treatmentW = 1.0
        self.MEM = False
        self.config = config
        if config.MEM is True:
            self.MEM = True
            print('turning on the MEM....')
            self.cls = Bert.modeling.BertOnlyMLMHead(config, self.bert.bert.embeddings.word_embeddings.weight)
            self.cls.to(self.otherDevice)
        self.apply(self.init_bert_weights)
        print("full init completed...")

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, year_ids=None, attention_mask=None,
                masked_lm_labels=None, outcomeT=None, treatmentCLabel=None, fullEval=False, vaelabel=None):
        batchs = input_ids.shape[0]
        embed4MLM, pooled_out = self.bert(input_ids, age_ids, seg_ids, posi_ids, year_ids, attention_mask,
                                          output_all_encoded_layers=False, fullmask=None)

        treatmentCLabel = treatmentCLabel.to(self.otherDevice)
        pooled_outVAE = self.VAEpooler(embed4MLM)
        pooled_outVAE = self.dropoutVAE(pooled_outVAE)

        outcomeT = outcomeT.to(self.otherDevice)
        outputVAE = self.VAE(pooled_outVAE, vaelabel)

        # fixed: initialise the MLM loss up front; it was previously unset when
        # MEM was on and no masked_lm_labels were given
        masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice)

        if self.MEM == True:
            prediction_scores = self.cls(embed4MLM[:, 1:])  # fixed: was 'embed4MEM', an undefined name
            if masked_lm_labels is not None:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
                masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size),
                                          masked_lm_labels.view(-1))
        else:
            prediction_scores = torch.ones([batchs, self.config.vocab_size]).to(self.otherDevice)
            masked_lm_loss = torch.tensor([0.0]).to(self.otherDevice)

        treatmentOut = self.treatmentC(pooled_out)
        tloss_fuct = nn.NLLLoss(reduction='mean')
        lossT = tloss_fuct(self.logS(treatmentOut).view(-1, self.config.num_treatment), treatmentCLabel.view(-1))
        pureSoft = nn.Softmax(dim=1)
        treatmentOut = pureSoft(treatmentOut)

        out1 = self.gelu(self.OutcomeT1_1(pooled_out))
        out2 = self.gelu(self.OutcomeT2_1(out1))
        logits0 = self.OutcomeT3_1(out2)

        out12 = self.gelu(self.OutcomeT1_2(pooled_out))
        out22 = self.gelu(self.OutcomeT2_2(out12))
        logits1 = self.OutcomeT3_2(out22)

        outcome1loss = nn.CrossEntropyLoss(reduction='none')
        outcome0loss = nn.CrossEntropyLoss(reduction='none')
        lossRaw0 = outcome0loss(logits0, outcomeT.squeeze(-1))
        lossRaw1 = outcome1loss(logits1, outcomeT.squeeze(-1))
        trueLoss1 = torch.mean(lossRaw1 * (treatmentCLabel.type(torch.FloatTensor).squeeze(-1)).to(self.otherDevice))
        trueLoss0 = torch.mean(lossRaw0 * (1 - treatmentCLabel.type(torch.FloatTensor).squeeze(-1)).to(self.otherDevice))

        # Dragonnet: outcome losses plus the weighted treatment-classification term
        tloss = trueLoss0 + trueLoss1 + self.treatmentW * lossT

        outlog1 = pureSoft(logits1)[:, 1]
        outlog0 = pureSoft(logits0)[:, 1]
        outlogits = torch.cat((outlog0.view(-1, 1).unsqueeze(0), outlog1.view(-1, 1).unsqueeze(0)), dim=0)
        fulTout = treatmentOut
        outputTreatIndex = treatmentCLabel
        outlabelsfull = outcomeT

        # vae loss
        vaeloss = self.VAE.loss_function(outputVAE)
        vaeloss_total = vaeloss['loss']
        masked_lm_loss = masked_lm_loss + vaeloss_total
        return masked_lm_loss, tloss, prediction_scores.view(-1, self.config.vocab_size), \
            masked_lm_labels.contiguous().view(-1), fulTout, outputTreatIndex, outlogits, \
            outlabelsfull, outputTreatIndex, 0, vaeloss
--------------------------------------------------------------------------------