├── LICENSE ├── MANIFEST.in ├── README.md ├── figs ├── .DS_Store ├── fig1.png ├── fig2.png ├── fig3a.png └── fig3b.png ├── requirements.txt ├── run_finetuning_CR.py ├── run_finetuning_QA.py ├── run_finetuning_SL.py ├── run_taskemb_CR.py ├── run_taskemb_QA.py ├── run_taskemb_SL.py ├── run_textemb_CR.py ├── run_textemb_QA.py ├── run_textemb_SL.py ├── setup.py ├── transformers ├── .DS_Store ├── __init__.py ├── __main__.py ├── configuration_auto.py ├── configuration_bert.py ├── configuration_utils.py ├── convert_bert_original_tf_checkpoint_to_pytorch.py ├── convert_bert_pytorch_checkpoint_to_original_tf.py ├── convert_pytorch_checkpoint_to_tf2.py ├── data │ ├── .DS_Store │ ├── __init__.py │ ├── metrics │ │ ├── .DS_Store │ │ └── __init__.py │ └── processors │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── glue.py │ │ └── utils.py ├── file_utils.py ├── modeling_auto.py ├── modeling_bert.py ├── modeling_task_embeddings.py ├── modeling_tf_auto.py ├── modeling_tf_bert.py ├── modeling_tf_pytorch_utils.py ├── modeling_tf_utils.py ├── modeling_utils.py ├── optimization.py ├── tests │ ├── .DS_Store │ ├── __init__.py │ ├── configuration_common_test.py │ ├── conftest.py │ ├── fixtures │ │ ├── input.txt │ │ ├── sample_text.txt │ │ └── test_sentencepiece.model │ ├── modeling_auto_test.py │ ├── modeling_bert_test.py │ ├── modeling_common_test.py │ ├── optimization_test.py │ ├── tokenization_auto_test.py │ ├── tokenization_bert_test.py │ ├── tokenization_tests_commons.py │ └── tokenization_utils_test.py ├── tokenization_auto.py ├── tokenization_bert.py └── tokenization_utils.py ├── utils_ner.py ├── utils_squad.py └── utils_squad_evaluate.py /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /figs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/figs/.DS_Store -------------------------------------------------------------------------------- /figs/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/figs/fig1.png -------------------------------------------------------------------------------- /figs/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/figs/fig2.png -------------------------------------------------------------------------------- /figs/fig3a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/figs/fig3a.png -------------------------------------------------------------------------------- /figs/fig3b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/figs/fig3b.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # progress bars in model download and training scripts 2 | tqdm 3 | # Accessing files from S3 directly. 
4 | boto3 5 | # Used for downloading models over HTTP 6 | requests 7 | # Other packages 8 | numpy 9 | torch 10 | tensorboardX 11 | scikit-learn 12 | seqeval 13 | -------------------------------------------------------------------------------- /run_textemb_CR.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Compute TextEmb for classification/regression tasks.""" 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import argparse 7 | import glob 8 | import logging 9 | import os 10 | import random 11 | import json 12 | 13 | import numpy as np 14 | import torch 15 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 16 | TensorDataset, Subset) 17 | from torch.utils.data.distributed import DistributedSampler 18 | 19 | try: 20 | from torch.utils.tensorboard import SummaryWriter 21 | except ImportError: 22 | from tensorboardX import SummaryWriter 23 | 24 | from tqdm import tqdm, trange 25 | 26 | from transformers import (WEIGHTS_NAME, BertConfig, BertModel, BertTokenizer) 27 | 28 | from transformers import AdamW, get_linear_schedule_with_warmup 29 | 30 | from transformers import glue_compute_metrics as compute_metrics 31 | from transformers import glue_output_modes as output_modes 32 | from transformers import glue_processors as processors 33 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, )), ()) 38 | 39 | MODEL_CLASSES = { 40 | 'bert': (BertConfig, BertModel, BertTokenizer) 41 | } 42 | 43 | 44 | def set_seed(args): 45 | random.seed(args.seed) 46 | np.random.seed(args.seed) 47 | torch.manual_seed(args.seed) 48 | if args.n_gpu > 0: 49 | torch.cuda.manual_seed_all(args.seed) 50 | 51 | 52 | def compute_textemb(args, train_dataset, model): 53 | """ Compute TextEmb """ 54 | tb_writer = SummaryWriter() 55 | 56 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 57 | train_sampler = SequentialSampler(train_dataset) 58 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 59 | 60 | # multi-gpu training (should be after apex fp16 initialization) 61 | if args.n_gpu > 1: 62 | model = torch.nn.DataParallel(model) 63 | 64 | logger.info("***** Compute TextEmb *****") 65 | logger.info("Num examples = %d", len(train_dataset)) 66 | logger.info("Batch size = %d", args.train_batch_size) 67 | 68 | model.zero_grad() 69 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=False) 70 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 71 | 72 | total_num_examples = 0 73 | global_feature_dict = {} 74 | for _ in train_iterator: 75 | num_examples = 0 76 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False) 77 | for step, batch in enumerate(epoch_iterator): 78 | model.eval() 79 | batch = tuple(t.to(args.device) for t in batch) 80 | with torch.no_grad(): 81 | inputs = {'input_ids': batch[0], 82 | 'attention_mask': batch[1], 83 | 'token_type_ids': batch[2]} 84 | input_mask = inputs['attention_mask'] 85 | outputs = model(**inputs) 86 | sequence_output = outputs[0] # batch_size x max_seq_length x hidden_size 87 | # pooled_output = outputs[1] # batch_size x hidden_size 88 | 89 | active_sequence_output = torch.einsum("ijk,ij->ijk", [sequence_output, input_mask])
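# The einsum above multiplies each token vector by its attention-mask entry,
# zeroing out padding positions; dividing each example's sum by its number of
# real tokens (next line) then gives a masked mean over the sequence, i.e.
# one fixed-size vector per input text.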
90 | avg_sequence_output = active_sequence_output.sum(1) / input_mask.sum(dim=1).view(input_mask.size(0), 1) 91 | 92 | if len(global_feature_dict) == 0: 93 | global_feature_dict["avg_sequence_output"] = avg_sequence_output.sum(dim=0).detach().cpu().numpy() 94 | # global_feature_dict["pooled_output"] = pooled_output.sum(dim=0).detach().cpu().numpy() 95 | else: 96 | global_feature_dict["avg_sequence_output"] += avg_sequence_output.sum(dim=0).detach().cpu().numpy() 97 | # global_feature_dict["pooled_output"] += pooled_output.sum(dim=0).detach().cpu().numpy() 98 | 99 | num_examples += input_mask.size(0) 100 | total_num_examples += num_examples 101 | 102 | # Normalize 103 | for key in global_feature_dict: 104 | global_feature_dict[key] = global_feature_dict[key] / total_num_examples 105 | 106 | # Save features 107 | for key in global_feature_dict: 108 | np.save(os.path.join(args.output_dir, '{}.npy'.format(key)), global_feature_dict[key]) 109 | 110 | tb_writer.close() 111 | 112 | 113 | def load_and_cache_examples(args, task, tokenizer, evaluate=False): 114 | processor = processors[task]() 115 | output_mode = output_modes[task] 116 | # Load data features from cache or dataset file 117 | cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format( 118 | 'dev' if evaluate else 'train', 119 | list(filter(None, args.model_name_or_path.split('/'))).pop(), 120 | str(args.max_seq_length), 121 | str(task))) 122 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 123 | logger.info("Loading features from cached file %s", cached_features_file) 124 | features = torch.load(cached_features_file) 125 | else: 126 | logger.info("Creating features from dataset file at %s", args.data_dir) 127 | label_list = processor.get_labels() 128 | examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples( 129 | args.data_dir) 130 | features = convert_examples_to_features(examples, 131 | tokenizer, 132 | label_list=label_list, 133 | max_length=args.max_seq_length, 134 | output_mode=output_mode, 135 | pad_on_left=False, 136 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 137 | pad_token_segment_id=0, 138 | ) 139 | logger.info("Saving features into cached file %s", cached_features_file) 140 | torch.save(features, cached_features_file) 141 | 142 | # Convert to Tensors and build dataset 143 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 144 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 145 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 146 | if output_mode == "classification": 147 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 148 | elif output_mode == "regression": 149 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float) 150 | 151 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 152 | return dataset 153 | 154 | 155 | def main(): 156 | parser = argparse.ArgumentParser() 157 | 158 | ## Required parameters 159 | parser.add_argument("--data_dir", default=None, type=str, required=True, 160 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 161 | parser.add_argument("--model_type", default=None, type=str, required=True, 162 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 163 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 164 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join( 165 | ALL_MODELS)) 166 | parser.add_argument("--task_name", default=None, type=str, required=True, 167 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) 168 | parser.add_argument("--output_dir", default=None, type=str, required=True, 169 | help="The output directory where the model predictions and checkpoints will be written.") 170 | 171 | ## Other parameters 172 | parser.add_argument("--config_name", default="", type=str, 173 | help="Pretrained config name or path if not the same as model_name") 174 | parser.add_argument("--tokenizer_name", default="", type=str, 175 | help="Pretrained tokenizer name or path if not the same as model_name") 176 | parser.add_argument("--cache_dir", default="", type=str, 177 | help="Where do you want to store the pre-trained models downloaded from s3") 178 | parser.add_argument("--train_data_subset", type=int, default=-1, 179 | help="If > 0: limit the training data to a subset of train_data_subset instances.") 180 | parser.add_argument("--max_seq_length", default=128, type=int, 181 | help="The maximum total input sequence length after tokenization. Sequences longer " 182 | "than this will be truncated, sequences shorter will be padded.") 183 | 184 | parser.add_argument("--do_lower_case", action='store_true', 185 | help="Set this flag if you are using an uncased model.") 186 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 187 | help="Batch size per GPU/CPU for training.") 188 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 189 | help="Total number of training epochs to perform.") 190 | 191 | parser.add_argument('--logging_steps', type=int, default=50, 192 | help="Log every X update steps.") 193 | parser.add_argument("--no_cuda", action='store_true', 194 | help="Avoid using CUDA when available") 195 | parser.add_argument('--overwrite_output_dir', action='store_true', 196 | help="Overwrite the content of the output directory") 197 | parser.add_argument('--overwrite_cache', action='store_true', 198 | help="Overwrite the cached training and evaluation sets") 199 | parser.add_argument('--seed', type=int, default=42, 200 | help="random seed for initialization") 201 | 202 | args = parser.parse_args() 203 | 204 | if os.path.exists(args.output_dir) and os.listdir( 205 | args.output_dir) and not args.overwrite_output_dir: 206 | raise ValueError( 207 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 208 | args.output_dir)) 209 | 210 | # Create output directory if needed 211 | if not os.path.exists(args.output_dir): 212 | os.makedirs(args.output_dir) 213 | with open(os.path.join(args.output_dir, 'run_args.txt'), 'w') as f: 214 | f.write(json.dumps(args.__dict__, indent=2)) 215 | f.close() 216 | 217 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 218 | args.n_gpu = torch.cuda.device_count() 219 | args.device = device 220 | 221 | # Setup logging 222 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 223 | datefmt='%m/%d/%Y %H:%M:%S', 224 | level=logging.INFO) 225 | logger.warning("Device: %s, n_gpu: %s", device, args.n_gpu) 226 | 227 | # Set seed 228 | set_seed(args) 229 | 230 | # Prepare GLUE task 231 | args.task_name = args.task_name.lower() 232 | if args.task_name not in processors: 233 | raise ValueError("Task not found: %s" % (args.task_name)) 234 | processor = processors[args.task_name]() 235 | args.output_mode = output_modes[args.task_name] 236 | label_list = processor.get_labels() 237 | num_labels = len(label_list) 238 | 239 | args.model_type = args.model_type.lower() 240 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 241 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, 242 | num_labels=num_labels, 243 | finetuning_task=args.task_name, 244 | cache_dir=args.cache_dir if args.cache_dir else None) 245 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 246 | do_lower_case=args.do_lower_case, 247 | cache_dir=args.cache_dir if args.cache_dir else None) 248 | model = model_class.from_pretrained(args.model_name_or_path, 249 | from_tf=bool('.ckpt' in args.model_name_or_path), 250 | config=config, 251 | cache_dir=args.cache_dir if args.cache_dir else None) 252 | 253 | model.to(args.device) 254 | 255 | logger.info("Training/evaluation parameters %s", args) 256 | 257 | 258 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) 259 | if args.train_data_subset > 0: 260 | train_dataset = Subset(train_dataset, list(range(min(args.train_data_subset, len(train_dataset))))) 261 | compute_textemb(args, train_dataset, model) 262 | 263 | if __name__ == "__main__": 264 | main() 265 | -------------------------------------------------------------------------------- /run_textemb_QA.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Compute TextEmb for question answering tasks.""" 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import argparse 7 | import logging 8 | import os 9 | import random 10 | import glob 11 | import timeit 12 | import json 13 | 14 | import numpy as np 15 | import torch 16 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 17 | TensorDataset, Subset) 18 | from torch.utils.data.distributed import DistributedSampler 19 | 20 | try: 21 | from torch.utils.tensorboard import SummaryWriter 22 | except ImportError: 23 | from tensorboardX import SummaryWriter 24 | 25 | from tqdm import tqdm, trange 26 | 27 | from transformers import (WEIGHTS_NAME, BertConfig, BertModel, BertTokenizer) 28 | 29 | from transformers import AdamW, get_linear_schedule_with_warmup 30 | 31 | from utils_squad import (read_squad_examples, convert_examples_to_features) 32 | 33 | # The following
import is the official SQuAD evaluation script (2.0). 34 | # You can remove it from the dependencies if you are using this script outside of the library 35 | # We've added it here for automated tests (see examples/test_examples.py file) 36 | from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \ 41 | for conf in (BertConfig, )), ()) 42 | 43 | MODEL_CLASSES = { 44 | 'bert': (BertConfig, BertModel, BertTokenizer) 45 | } 46 | 47 | 48 | def set_seed(args): 49 | random.seed(args.seed) 50 | np.random.seed(args.seed) 51 | torch.manual_seed(args.seed) 52 | if args.n_gpu > 0: 53 | torch.cuda.manual_seed_all(args.seed) 54 | 55 | 56 | def to_list(tensor): 57 | return tensor.detach().cpu().tolist() 58 | 59 | 60 | def compute_textemb(args, train_dataset, model): 61 | 62 | tb_writer = SummaryWriter() 63 | 64 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 65 | train_sampler = SequentialSampler(train_dataset) 66 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 67 | 68 | # multi-gpu training (should be after apex fp16 initialization) 69 | if args.n_gpu > 1: 70 | model = torch.nn.DataParallel(model) 71 | 72 | logger.info("***** Compute TextEmb *****") 73 | logger.info("Num examples = %d", len(train_dataset)) 74 | logger.info("Batch size = %d", args.train_batch_size) 75 | 76 | model.zero_grad() 77 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=False) 78 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 79 | 80 | total_num_examples = 0 81 | global_feature_dict = {} 82 | for _ in train_iterator: 83 | num_examples = 0 84 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False) 85 | for step, batch in enumerate(epoch_iterator): 86 | model.eval() 87 | batch = tuple(t.to(args.device) for t in batch) 88 | 89 | with torch.no_grad(): 90 | inputs = {'input_ids': batch[0], 91 | 'attention_mask': batch[1], 92 | 'token_type_ids': batch[2]} 93 | 94 | input_mask = inputs['attention_mask'] 95 | outputs = model(**inputs) 96 | sequence_output = outputs[0] # batch_size x max_seq_length x hidden_size 97 | # pooled_output = outputs[1] # batch_size x hidden_size 98 | 99 | active_sequence_output = torch.einsum("ijk,ij->ijk", [sequence_output, input_mask]) 100 | avg_sequence_output = active_sequence_output.sum(1) / input_mask.sum(dim=1).view(input_mask.size(0), 1) 101 | 102 | if len(global_feature_dict) == 0: 103 | global_feature_dict["avg_sequence_output"] = avg_sequence_output.sum(dim=0).detach().cpu().numpy() 104 | # global_feature_dict["pooled_output"] = pooled_output.sum(dim=0).detach().cpu().numpy() 105 | else: 106 | global_feature_dict["avg_sequence_output"] += avg_sequence_output.sum(dim=0).detach().cpu().numpy() 107 | # global_feature_dict["pooled_output"] += pooled_output.sum(dim=0).detach().cpu().numpy() 108 | 109 | num_examples += input_mask.size(0) 110 | total_num_examples += num_examples 111 | 112 | # Normalize 113 | for key in global_feature_dict: 114 | global_feature_dict[key] = global_feature_dict[key] / total_num_examples 115 | 116 | # Save features 117 | for key in global_feature_dict: 118 | np.save(os.path.join(args.output_dir, '{}.npy'.format(key)), global_feature_dict[key]) 119 | 120 | tb_writer.close() 121 | 122 | 123 | def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False): 124 | # Load data features from cache or dataset file
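# Contexts longer than max_seq_length are split by convert_examples_to_features into
# overlapping windows doc_stride tokens apart, so a single example can yield several
# features; the cache file name below encodes the split, model name, and sequence length
# so repeated runs skip re-tokenization.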
125 | input_file = args.train_file 126 | cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format( 127 | 'dev' if evaluate else 'train', 128 | list(filter(None, args.model_name_or_path.split('/'))).pop(), 129 | str(args.max_seq_length))) 130 | if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples: 131 | logger.info("Loading features from cached file %s", cached_features_file) 132 | features = torch.load(cached_features_file) 133 | else: 134 | logger.info("Creating features from dataset file at %s", input_file) 135 | examples = read_squad_examples(input_file=input_file, 136 | is_training=not evaluate, 137 | version_2_with_negative=args.version_2_with_negative) 138 | features = convert_examples_to_features(examples=examples, 139 | tokenizer=tokenizer, 140 | max_seq_length=args.max_seq_length, 141 | doc_stride=args.doc_stride, 142 | max_query_length=args.max_query_length, 143 | is_training=not evaluate, 144 | cls_token_segment_id=0, 145 | pad_token_segment_id=0, 146 | cls_token_at_end=False, 147 | sequence_a_is_doc=False) 148 | 149 | logger.info("Saving features into cached file %s", cached_features_file) 150 | torch.save(features, cached_features_file) 151 | 152 | # Convert to Tensors and build dataset 153 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 154 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 155 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 156 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) 157 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) 158 | if evaluate: 159 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 160 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 161 | all_example_index, all_cls_index, all_p_mask) 162 | else: 163 | all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) 164 | all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) 165 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 166 | all_start_positions, all_end_positions, 167 | all_cls_index, all_p_mask) 168 | 169 | if output_examples: 170 | return dataset, examples, features 171 | return dataset 172 | 173 | 174 | def main(): 175 | parser = argparse.ArgumentParser() 176 | 177 | parser.add_argument("--train_file", default=None, type=str, required=True, 178 | help="SQuAD json for training. 
E.g., train-v1.1.json") 179 | parser.add_argument("--model_type", default=None, type=str, required=True, 180 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 181 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 182 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) 183 | parser.add_argument("--output_dir", default=None, type=str, required=True, 184 | help="The output directory where the model checkpoints and predictions will be written.") 185 | 186 | ## Other parameters 187 | parser.add_argument("--config_name", default="", type=str, 188 | help="Pretrained config name or path if not the same as model_name") 189 | parser.add_argument("--tokenizer_name", default="", type=str, 190 | help="Pretrained tokenizer name or path if not the same as model_name") 191 | parser.add_argument("--cache_dir", default="", type=str, 192 | help="Where do you want to store the pre-trained models downloaded from s3") 193 | 194 | parser.add_argument('--version_2_with_negative', default=False, type=eval, 195 | help='If true, the SQuAD examples contain some that do not have an answer.') 196 | parser.add_argument('--null_score_diff_threshold', type=float, default=0.0, 197 | help="If null_score - best_non_null is greater than the threshold predict null.") 198 | parser.add_argument("--train_data_subset", type=int, default=-1, 199 | help="If > 0: limit the training data to a subset of train_data_subset instances.") 200 | parser.add_argument("--max_seq_length", default=384, type=int, 201 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 202 | "longer than this will be truncated, and sequences shorter than this will be padded.") 203 | parser.add_argument("--doc_stride", default=128, type=int, 204 | help="When splitting up a long document into chunks, how much stride to take between chunks.") 205 | parser.add_argument("--max_query_length", default=64, type=int, 206 | help="The maximum number of tokens for the question. Questions longer than this will " 207 | "be truncated to this length.") 208 | parser.add_argument("--do_lower_case", action='store_true', 209 | help="Set this flag if you are using an uncased model.") 210 | 211 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 212 | help="Batch size per GPU/CPU for training.") 213 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 214 | help="Total number of training epochs to perform.") 215 | parser.add_argument("--n_best_size", default=20, type=int, 216 | help="The total number of n-best predictions to generate in the nbest_predictions.json output file.") 217 | parser.add_argument("--max_answer_length", default=30, type=int, 218 | help="The maximum length of an answer that can be generated. This is needed because the start " 219 | "and end predictions are not conditioned on one another.") 220 | parser.add_argument("--verbose_logging", action='store_true', 221 | help="If true, all of the warnings related to data processing will be printed. 
" 222 | "A number of warnings are expected for a normal SQuAD evaluation.") 223 | 224 | parser.add_argument("--no_cuda", action='store_true', 225 | help="Whether not to use CUDA when available") 226 | parser.add_argument('--overwrite_output_dir', action='store_true', 227 | help="Overwrite the content of the output directory") 228 | parser.add_argument('--overwrite_cache', action='store_true', 229 | help="Overwrite the cached training and evaluation sets") 230 | parser.add_argument('--seed', type=int, default=42, 231 | help="random seed for initialization") 232 | 233 | args = parser.parse_args() 234 | 235 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir: 236 | raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) 237 | 238 | # Create output directory if needed 239 | if not os.path.exists(args.output_dir): 240 | os.makedirs(args.output_dir) 241 | with open(os.path.join(args.output_dir, 'run_args.txt'), 'w') as f: 242 | f.write(json.dumps(args.__dict__, indent=2)) 243 | f.close() 244 | 245 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 246 | args.n_gpu = torch.cuda.device_count() 247 | args.device = device 248 | 249 | # Setup logging 250 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 251 | datefmt='%m/%d/%Y %H:%M:%S', 252 | level=logging.INFO) 253 | logger.warning("Device: %s, n_gpu: %s", device, args.n_gpu) 254 | 255 | # Set seed 256 | set_seed(args) 257 | 258 | args.model_type = args.model_type.lower() 259 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 260 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, 261 | cache_dir=args.cache_dir if args.cache_dir else None) 262 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 263 | do_lower_case=args.do_lower_case, 264 | cache_dir=args.cache_dir if args.cache_dir else None) 265 | model = model_class.from_pretrained(args.model_name_or_path, 266 | from_tf=bool('.ckpt' in args.model_name_or_path), 267 | config=config, 268 | cache_dir=args.cache_dir if args.cache_dir else None) 269 | 270 | model.to(args.device) 271 | 272 | logger.info("Training/evaluation parameters %s", args) 273 | 274 | 275 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) 276 | if args.train_data_subset > 0: 277 | train_dataset = Subset(train_dataset, list(range(min(args.train_data_subset, len(train_dataset))))) 278 | compute_textemb(args, train_dataset, model) 279 | 280 | 281 | if __name__ == "__main__": 282 | main() 283 | -------------------------------------------------------------------------------- /run_textemb_SL.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Finetuning the library models for sequence labeling tasks (Bert, XLM, XLNet, RoBERTa).""" 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import argparse 7 | import glob 8 | import logging 9 | import os 10 | import random 11 | import json 12 | 13 | import numpy as np 14 | import torch 15 | from seqeval.metrics import precision_score, recall_score, f1_score 16 | from tensorboardX import SummaryWriter 17 | from torch.nn import CrossEntropyLoss 18 | from torch.utils.data import DataLoader, RandomSampler, 
SequentialSampler, TensorDataset, Subset 19 | from torch.utils.data.distributed import DistributedSampler 20 | from tqdm import tqdm, trange 21 | from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file 22 | 23 | from transformers import AdamW, get_linear_schedule_with_warmup 24 | from transformers import WEIGHTS_NAME, BertConfig, BertModel, BertTokenizer 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | ALL_MODELS = sum( 29 | (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), 30 | ()) 31 | 32 | MODEL_CLASSES = { 33 | "bert": (BertConfig, BertModel, BertTokenizer), 34 | } 35 | 36 | 37 | def set_seed(args): 38 | random.seed(args.seed) 39 | np.random.seed(args.seed) 40 | torch.manual_seed(args.seed) 41 | if args.n_gpu > 0: 42 | torch.cuda.manual_seed_all(args.seed) 43 | 44 | 45 | def run_feature_extractor(args, train_dataset, model): 46 | """ Compute TextEmb """ 47 | tb_writer = SummaryWriter() 48 | 49 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 50 | train_sampler = SequentialSampler(train_dataset) 51 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 52 | 53 | # multi-gpu training (should be after apex fp16 initialization) 54 | if args.n_gpu > 1: 55 | model = torch.nn.DataParallel(model) 56 | 57 | logger.info("***** Compute TextEmb *****") 58 | logger.info("Num examples = %d", len(train_dataset)) 59 | logger.info("Batch size = %d", args.train_batch_size) 60 | 61 | model.zero_grad() 62 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=False) 63 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 64 | 65 | total_num_examples = 0 66 | global_feature_dict = {} 67 | for _ in train_iterator: 68 | num_examples = 0 69 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False) 70 | for step, batch in enumerate(epoch_iterator): 71 | model.eval() 72 | batch = tuple(t.to(args.device) for t in batch) 73 | 74 | with torch.no_grad(): 75 | inputs = {"input_ids": batch[0], 76 | "attention_mask": batch[1], 77 | "token_type_ids": batch[2] 78 | } 79 | input_mask = inputs['attention_mask'] 80 | outputs = model(**inputs) 81 | sequence_output = outputs[0] # batch_size x max_seq_length x hidden_size 82 | pooled_output = outputs[1] # batch_size x hidden_size 83 | 84 | active_sequence_output = torch.einsum("ijk,ij->ijk", [sequence_output, input_mask]) 85 | avg_sequence_output = active_sequence_output.sum(1) / input_mask.sum(dim=1).view(input_mask.size(0), 1) 86 | 87 | if len(global_feature_dict) == 0: 88 | global_feature_dict["avg_sequence_output"] = avg_sequence_output.sum(dim=0).detach().cpu().numpy() 89 | global_feature_dict["pooled_output"] = pooled_output.sum(dim=0).detach().cpu().numpy() 90 | else: 91 | global_feature_dict["avg_sequence_output"] += avg_sequence_output.sum(dim=0).detach().cpu().numpy() 92 | global_feature_dict["pooled_output"] += pooled_output.sum(dim=0).detach().cpu().numpy() 93 | 94 | num_examples += input_mask.size(0) 95 | total_num_examples += num_examples 96 | 97 | # Normalize 98 | for key in global_feature_dict: 99 | global_feature_dict[key] = global_feature_dict[key] / total_num_examples 100 | 101 | # Save features 102 | for key in global_feature_dict: 103 | np.save(os.path.join(args.output_dir, '{}.npy'.format(key)), global_feature_dict[key]) 104 | 105 | tb_writer.close() 106 | 107 | 108 | def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode): 
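# Returns a TensorDataset of (input_ids, input_mask, segment_ids, label_ids); padded
# label positions carry pad_token_label_id (the CrossEntropyLoss ignore index set in
# main below), so only real labels would contribute to a token-classification loss.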
109 | # Load data features from cache or dataset file 110 | cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format(mode, 111 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 112 | str(args.max_seq_length))) 113 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 114 | logger.info("Loading features from cached file %s", cached_features_file) 115 | features = torch.load(cached_features_file) 116 | else: 117 | logger.info("Creating features from dataset file at %s", args.data_dir) 118 | examples = read_examples_from_file(args.data_dir, mode) 119 | features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer, 120 | cls_token_at_end=bool(args.model_type in ["xlnet"]), 121 | # xlnet has a cls token at the end 122 | cls_token=tokenizer.cls_token, 123 | cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0, 124 | sep_token=tokenizer.sep_token, 125 | sep_token_extra=bool(args.model_type in ["roberta"]), 126 | # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 127 | pad_on_left=bool(args.model_type in ["xlnet"]), 128 | # pad on the left for xlnet 129 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 130 | pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0, 131 | pad_token_label_id=pad_token_label_id 132 | ) 133 | logger.info("Saving features into cached file %s", cached_features_file) 134 | torch.save(features, cached_features_file) 135 | 136 | # Convert to Tensors and build dataset 137 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 138 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 139 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 140 | all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long) 141 | 142 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 143 | return dataset 144 | 145 | 146 | def main(): 147 | parser = argparse.ArgumentParser() 148 | 149 | ## Required parameters 150 | parser.add_argument("--data_dir", default=None, type=str, required=True, 151 | help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.") 152 | parser.add_argument("--model_type", default=None, type=str, required=True, 153 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 154 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 155 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) 156 | parser.add_argument("--output_dir", default=None, type=str, required=True, 157 | help="The output directory where the model predictions and checkpoints will be written.") 158 | 159 | ## Other parameters 160 | parser.add_argument("--labels", default="", type=str, 161 | help="Path to a file containing all labels. 
If not specified, CoNLL-2003 labels are used.") 162 | parser.add_argument("--config_name", default="", type=str, 163 | help="Pretrained config name or path if not the same as model_name") 164 | parser.add_argument("--tokenizer_name", default="", type=str, 165 | help="Pretrained tokenizer name or path if not the same as model_name") 166 | parser.add_argument("--cache_dir", default="", type=str, 167 | help="Where do you want to store the pre-trained models downloaded from s3") 168 | parser.add_argument("--train_data_subset", type=int, default=-1, 169 | help="If > 0: limit the training data to a subset of train_data_subset instances.") 170 | parser.add_argument("--eval_data_subset", type=int, default=-1, 171 | help="If > 0: limit the evaluation data to a subset of eval_data_subset instances.") 172 | parser.add_argument("--max_seq_length", default=128, type=int, 173 | help="The maximum total input sequence length after tokenization. Sequences longer " 174 | "than this will be truncated, sequences shorter will be padded.") 175 | parser.add_argument("--do_lower_case", action="store_true", 176 | help="Set this flag if you are using an uncased model.") 177 | 178 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 179 | help="Batch size per GPU/CPU for training.") 180 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 181 | help="Total number of training epochs to perform.") 182 | 183 | parser.add_argument("--no_cuda", action="store_true", 184 | help="Avoid using CUDA when available") 185 | parser.add_argument("--overwrite_output_dir", action="store_true", 186 | help="Overwrite the content of the output directory") 187 | parser.add_argument("--overwrite_cache", action="store_true", 188 | help="Overwrite the cached training and evaluation sets") 189 | parser.add_argument("--seed", type=int, default=42, 190 | help="random seed for initialization") 191 | 192 | args = parser.parse_args() 193 | 194 | if os.path.exists(args.output_dir) and os.listdir( 195 | args.output_dir) and not args.overwrite_output_dir: 196 | raise ValueError( 197 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 198 | args.output_dir)) 199 | 200 | # Create output directory if needed 201 | if not os.path.exists(args.output_dir): 202 | os.makedirs(args.output_dir) 203 | with open(os.path.join(args.output_dir, 'run_args.txt'), 'w') as f: 204 | f.write(json.dumps(args.__dict__, indent=2)) 205 | f.close() 206 | 207 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 208 | args.n_gpu = torch.cuda.device_count() 209 | args.device = device 210 | 211 | # Setup logging 212 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 213 | datefmt='%m/%d/%Y %H:%M:%S', 214 | level=logging.INFO) 215 | logger.warning("Device: %s, n_gpu: %s", device, args.n_gpu) 216 | 217 | # Set seed 218 | set_seed(args) 219 | 220 | # Prepare CONLL-2003 task 221 | labels = get_labels(args.labels) 222 | num_labels = len(labels) 223 | # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later 224 | pad_token_label_id = CrossEntropyLoss().ignore_index 225 | 226 | args.model_type = args.model_type.lower() 227 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 228 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, 229 | num_labels=num_labels, 230 | cache_dir=args.cache_dir if args.cache_dir else None) 231 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 232 | do_lower_case=args.do_lower_case, 233 | cache_dir=args.cache_dir if args.cache_dir else None) 234 | model = model_class.from_pretrained(args.model_name_or_path, 235 | from_tf=bool(".ckpt" in args.model_name_or_path), 236 | config=config, 237 | cache_dir=args.cache_dir if args.cache_dir else None) 238 | 239 | model.to(args.device) 240 | 241 | logger.info("Training/evaluation parameters %s", args) 242 | 243 | 244 | train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train") 245 | if args.train_data_subset > 0: 246 | train_dataset = Subset(train_dataset, list(range(min(args.train_data_subset, len(train_dataset))))) 247 | run_feature_extractor(args, train_dataset, model) 248 | 249 | 250 | if __name__ == "__main__": 251 | main() 252 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 3 | 4 | To create the package for pypi. 5 | 6 | 1. Change the version in __init__.py, setup.py as well as docs/source/conf.py. 7 | 8 | 2. Commit these changes with the message: "Release: VERSION" 9 | 10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 11 | Push the tag to git: git push --tags origin master 12 | 13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 14 | creating the wheel and the source distribution (obviously). 15 | 16 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory. 17 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x). 18 | 19 | For the sources, run: "python setup.py sdist" 20 | You should now have a /dist directory with both .whl and .tar.gz source versions. 21 | 22 | 5. 
Check that everything looks correct by uploading the package to the pypi test server: 23 | 24 | twine upload dist/* -r pypitest 25 | (PyPI suggests using twine, as other methods upload files via plaintext.) 26 | 27 | Check that you can install it in a virtualenv by running: 28 | pip install -i https://testpypi.python.org/pypi transformers 29 | 30 | 6. Upload the final version to actual pypi: 31 | twine upload dist/* -r pypi 32 | 33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 34 | 35 | """ 36 | from io import open 37 | from setuptools import find_packages, setup 38 | 39 | setup( 40 | name="transformers", 41 | version="2.1.1", 42 | author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors", 43 | author_email="thomas@huggingface.co", 44 | description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch", 45 | long_description=open("README.md", "r", encoding='utf-8').read(), 46 | long_description_content_type="text/markdown", 47 | keywords='NLP deep learning transformer pytorch tensorflow BERT', 48 | license='Apache', 49 | url="https://github.com/huggingface/transformers", 50 | packages=find_packages(exclude=["*.tests", "*.tests.*", 51 | "tests.*", "tests"]), 52 | install_requires=['numpy', 53 | 'boto3', 54 | 'requests', 55 | 'tqdm',], 56 | entry_points={ 57 | 'console_scripts': [ 58 | "transformers=transformers.__main__:main", 59 | ] 60 | }, 61 | # python_requires='>=3.5.0', 62 | tests_require=['pytest'], 63 | classifiers=[ 64 | 'Intended Audience :: Science/Research', 65 | 'License :: OSI Approved :: Apache Software License', 66 | 'Programming Language :: Python :: 3', 67 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 68 | ], 69 | ) 70 | -------------------------------------------------------------------------------- /transformers/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/transformers/.DS_Store -------------------------------------------------------------------------------- /transformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.1.1" 2 | 3 | # Workaround to update TensorFlow's absl.logging threshold, which alters the 4 | # default Python logging output behavior when present. 
5 | # see: https://github.com/abseil/abseil-py/issues/99 6 | # and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493 7 | try: 8 | import absl.logging 9 | absl.logging.set_verbosity('info') 10 | absl.logging.set_stderrthreshold('info') 11 | absl.logging._warn_preinit_stderr = False 12 | except: 13 | pass 14 | 15 | import logging 16 | 17 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 18 | 19 | # Files and general utilities 20 | from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, 21 | cached_path, add_start_docstrings, add_end_docstrings, 22 | WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME, 23 | is_tf_available, is_torch_available) 24 | 25 | from .data import (is_sklearn_available, 26 | InputExample, InputFeatures, DataProcessor, 27 | glue_output_modes, glue_convert_examples_to_features, 28 | glue_processors, glue_tasks_num_labels) 29 | 30 | if is_sklearn_available(): 31 | from .data import glue_compute_metrics 32 | 33 | # Tokenizers 34 | from .tokenization_utils import (PreTrainedTokenizer) 35 | from .tokenization_auto import AutoTokenizer 36 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer 37 | 38 | # Configurations 39 | from .configuration_utils import PretrainedConfig 40 | from .configuration_auto import AutoConfig 41 | from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 42 | 43 | # Modeling 44 | if is_torch_available(): 45 | from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D) 46 | from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, 47 | AutoModelWithLMHead) 48 | 49 | from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining, 50 | BertForMaskedLM, BertForNextSentencePrediction, 51 | BertForSequenceClassification, BertForMultipleChoice, 52 | BertForTokenClassification, BertForQuestionAnswering, 53 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP) 54 | 55 | # Task Embeddings 56 | from .modeling_task_embeddings import (BertConfig as BertConfig_TaskEmbeddings, 57 | BertForSequenceClassification as BertForSequenceClassification_TaskEmbeddings, 58 | BertForQuestionAnswering as BertForQuestionAnswering_TaskEmbeddings, 59 | BertForTokenClassification as BertForTokenClassification_TaskEmbeddings) 60 | 61 | # Optimization 62 | from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, 63 | get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup) 64 | 65 | 66 | # TensorFlow 67 | if is_tf_available(): 68 | from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary 69 | from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, 70 | TFAutoModelWithLMHead) 71 | 72 | from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings, 73 | TFBertModel, TFBertForPreTraining, 74 | TFBertForMaskedLM, TFBertForNextSentencePrediction, 75 | TFBertForSequenceClassification, TFBertForMultipleChoice, 76 | TFBertForTokenClassification, TFBertForQuestionAnswering, 77 | TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP) 78 | 79 | 80 | # TF 2.0 <=> PyTorch conversion utilities 81 | from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name, 82 | load_pytorch_checkpoint_in_tf2_model, 83 | load_pytorch_weights_in_tf2_model, 84 | 
load_pytorch_model_in_tf2_model, 85 | load_tf2_checkpoint_in_pytorch_model, 86 | load_tf2_weights_in_pytorch_model, 87 | load_tf2_model_in_pytorch_model) 88 | 89 | if not is_tf_available() and not is_torch_available(): 90 | logger.warning("Neither PyTorch nor TensorFlow >= 2.0 has been found. " 91 | "Models won't be available and only tokenizers, configuration " 92 | "and file/data utilities can be used.") 93 | 94 | -------------------------------------------------------------------------------- /transformers/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert"]: 5 | print( 6 | "This command line utility lets you convert an original (author-released) model checkpoint to PyTorch.\n" 7 | "It should be used as one of: \n" 8 | ">> transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT\n") 9 | else: 10 | if sys.argv[1] == "bert": 11 | try: 12 | from .convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 13 | except ImportError: 14 | print("transformers can only be used from the command line to convert TensorFlow models to PyTorch. " 15 | "In that case, it requires TensorFlow to be installed. Please see " 16 | "https://www.tensorflow.org/install/ for installation instructions.") 17 | raise 18 | 19 | if len(sys.argv) != 5: 20 | # pylint: disable=line-too-long 21 | print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 22 | else: 23 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 24 | TF_CONFIG = sys.argv.pop() 25 | TF_CHECKPOINT = sys.argv.pop() 26 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /transformers/configuration_auto.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Auto Config class. """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import logging 20 | 21 | from .configuration_bert import BertConfig 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class AutoConfig(object): 27 | r""":class:`~transformers.AutoConfig` is a generic configuration class 28 | that will be instantiated as one of the configuration classes of the library 29 | when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` 30 | class method. 31 | 32 | The `from_pretrained()` method takes care of returning the correct configuration class instance 33 | using pattern matching on the `pretrained_model_name_or_path` string.
34 | 35 | The configuration class to instantiate is selected as the first matching pattern 36 | in the `pretrained_model_name_or_path` string (in the following order): 37 | - contains `bert`: BertConfig (Bert model) 38 | This class cannot be instantiated using `__init__()` (it throws an error). 39 | """ 40 | def __init__(self): 41 | raise EnvironmentError("AutoConfig is designed to be instantiated " 42 | "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.") 43 | 44 | @classmethod 45 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 46 | r""" Instantiate one of the configuration classes of the library 47 | from a pre-trained model configuration. 48 | 49 | The configuration class to instantiate is selected as the first matching pattern 50 | in the `pretrained_model_name_or_path` string (in the following order): 51 | - contains `bert`: BertConfig (Bert model) 52 | Params: 53 | pretrained_model_name_or_path: either: 54 | 55 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. 56 | - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 57 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. 58 | 59 | cache_dir: (`optional`) string: 60 | Path to a directory in which a downloaded pre-trained model 61 | configuration should be cached if the standard cache should not be used. 62 | 63 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. 64 | 65 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. 66 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. 67 | 68 | force_download: (`optional`) boolean, default False: 69 | Force to (re-)download the model weights and configuration files and override the cached versions if they exist. 70 | 71 | proxies: (`optional`) dict, default None: 72 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 73 | The proxies are used on each request. 74 | 75 | return_unused_kwargs: (`optional`) bool: 76 | 77 | - If False, then this function returns just the final configuration object. 78 | - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e. the part of kwargs which has not been used to update `config` and is otherwise ignored. 79 | 80 | Examples:: 81 | 82 | config = AutoConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 83 | config = AutoConfig.from_pretrained('./test/bert_saved_model/') # E.g.
config (or model) was saved using `save_pretrained('./test/saved_model/')` 84 | config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json') 85 | config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False) 86 | assert config.output_attentions == True 87 | config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, 88 | foo=False, return_unused_kwargs=True) 89 | assert config.output_attentions == True 90 | assert unused_kwargs == {'foo': False} 91 | 92 | """ 93 | if 'bert' in pretrained_model_name_or_path: 94 | return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 95 | raise ValueError("Unrecognized model identifier in {}. Should contain one of " 96 | "'bert'.".format(pretrained_model_name_or_path)) 97 | -------------------------------------------------------------------------------- /transformers/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
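# A minimal usage sketch of the BertConfig defined below; the shortcut name and
# the explicit sizes are illustrative assumptions, not requirements:
#
#   from transformers import BertConfig
#
#   config = BertConfig.from_pretrained('bert-base-uncased')  # via the archive map below
#   tiny_config = BertConfig(vocab_size_or_config_json_file=30522,
#                            hidden_size=256,
#                            num_hidden_layers=4,
#                            num_attention_heads=4,
#                            intermediate_size=1024)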
16 | """ BERT model configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 43 | 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 44 | 'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 45 | } 46 | 47 | 48 | class BertConfig(PretrainedConfig): 49 | r""" 50 | :class:`~transformers.BertConfig` is the configuration class to store the configuration of a 51 | `BertModel`. 52 | 53 | 54 | Arguments: 55 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. 
Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 75 | layer_norm_eps: The epsilon used by LayerNorm. 76 | """ 77 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 78 | 79 | def __init__(self, 80 | vocab_size_or_config_json_file=30522, 81 | hidden_size=768, 82 | num_hidden_layers=12, 83 | num_attention_heads=12, 84 | intermediate_size=3072, 85 | hidden_act="gelu", 86 | hidden_dropout_prob=0.1, 87 | attention_probs_dropout_prob=0.1, 88 | max_position_embeddings=512, 89 | type_vocab_size=2, 90 | initializer_range=0.02, 91 | layer_norm_eps=1e-12, 92 | **kwargs): 93 | super(BertConfig, self).__init__(**kwargs) 94 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 95 | and isinstance(vocab_size_or_config_json_file, unicode)): 96 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 97 | json_config = json.loads(reader.read()) 98 | for key, value in json_config.items(): 99 | self.__dict__[key] = value 100 | elif isinstance(vocab_size_or_config_json_file, int): 101 | self.vocab_size = vocab_size_or_config_json_file 102 | self.hidden_size = hidden_size 103 | self.num_hidden_layers = num_hidden_layers 104 | self.num_attention_heads = num_attention_heads 105 | self.hidden_act = hidden_act 106 | self.intermediate_size = intermediate_size 107 | self.hidden_dropout_prob = hidden_dropout_prob 108 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 109 | self.max_position_embeddings = max_position_embeddings 110 | self.type_vocab_size = type_vocab_size 111 | self.initializer_range = initializer_range 112 | self.layer_norm_eps = layer_norm_eps 113 | else: 114 | raise ValueError("First argument must be either a vocabulary size (int)" 115 | " or the path to a pretrained model config file (str)") 116 | -------------------------------------------------------------------------------- /transformers/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Configuration base class and utilities.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import copy 22 | import json 23 | import logging 24 | import os 25 | from io import open 26 | 27 | from .file_utils import cached_path, CONFIG_NAME 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | class PretrainedConfig(object): 32 | r""" Base class for all configuration classes. 
33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. 34 | 35 | Note: 36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. 37 | It only affects the model's configuration. 38 | 39 | Class attributes (overridden by derived classes): 40 | - ``pretrained_config_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values. 41 | 42 | Parameters: 43 | ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. 44 | ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens). 45 | ``output_attentions``: boolean, default `False`. Whether the model should return attention weights. 46 | ``output_hidden_states``: boolean, default `False`. Whether the model should return all hidden states. 47 | ``torchscript``: boolean, default `False`. Whether the model is used with TorchScript. 48 | """ 49 | pretrained_config_archive_map = {} 50 | 51 | def __init__(self, **kwargs): 52 | self.finetuning_task = kwargs.pop('finetuning_task', None) 53 | self.num_labels = kwargs.pop('num_labels', 2) 54 | self.output_attentions = kwargs.pop('output_attentions', False) 55 | self.output_hidden_states = kwargs.pop('output_hidden_states', False) 56 | self.output_past = kwargs.pop('output_past', True) # Not used by all models 57 | self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models 58 | self.use_bfloat16 = kwargs.pop('use_bfloat16', False) 59 | self.pruned_heads = kwargs.pop('pruned_heads', {}) 60 | self.is_decoder = kwargs.pop('is_decoder', False) 61 | 62 | # attributes for task embeddings 63 | self.num_softmax_classifiers = kwargs.pop('num_softmax_classifiers', 1) 64 | self.retain_gradients = kwargs.pop('retain_gradients', True) 65 | self.do_pooling = kwargs.pop('do_pooling', True) 66 | 67 | def save_pretrained(self, save_directory): 68 | """ Save a configuration object to the directory `save_directory`, so that it 69 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. 70 | """ 71 | assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" 72 | 73 | # If we save using the predefined names, we can load using `from_pretrained` 74 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 75 | 76 | self.to_json_file(output_config_file) 77 | logger.info("Configuration saved in {}".format(output_config_file)) 78 | 79 | @classmethod 80 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 81 | r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. 82 | 83 | Parameters: 84 | pretrained_model_name_or_path: either: 85 | 86 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. 87 | - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 88 | - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
89 | 90 | cache_dir: (`optional`) string: 91 | Path to a directory in which a downloaded pre-trained model 92 | configuration should be cached if the standard cache should not be used. 93 | 94 | kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. 95 | 96 | - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. 97 | - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. 98 | 99 | force_download: (`optional`) boolean, default False: 100 | Force to (re-)download the model weights and configuration files and override the cached versions if they exist. 101 | 102 | proxies: (`optional`) dict, default None: 103 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. 104 | The proxies are used on each request. 105 | 106 | return_unused_kwargs: (`optional`) bool: 107 | 108 | - If False, then this function returns just the final configuration object. 109 | - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e. the part of kwargs which has not been used to update `config` and is otherwise ignored. 110 | 111 | Examples:: 112 | 113 | # We can't directly instantiate the base class `PretrainedConfig`, so we show the examples on a 114 | # derived class: BertConfig 115 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 116 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` 117 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 118 | config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False) 119 | assert config.output_attentions == True 120 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, 121 | foo=False, return_unused_kwargs=True) 122 | assert config.output_attentions == True 123 | assert unused_kwargs == {'foo': False} 124 | 125 | """ 126 | cache_dir = kwargs.pop('cache_dir', None) 127 | force_download = kwargs.pop('force_download', False) 128 | proxies = kwargs.pop('proxies', None) 129 | return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) 130 | 131 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 132 | config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] 133 | elif os.path.isdir(pretrained_model_name_or_path): 134 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 135 | else: 136 | config_file = pretrained_model_name_or_path 137 | # redirect to the cache, if necessary 138 | try: 139 | resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 140 | except EnvironmentError: 141 | if pretrained_model_name_or_path in cls.pretrained_config_archive_map: 142 | msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format( 143 | config_file) 144 | else: 145 | msg = "Model name '{}' was not found in model name list ({}).
" \ 146 | "We assumed '{}' was a path or url to a configuration file named {} or " \ 147 | "a directory containing such a file but couldn't find any such file at this path or url.".format( 148 | pretrained_model_name_or_path, 149 | ', '.join(cls.pretrained_config_archive_map.keys()), 150 | config_file, CONFIG_NAME) 151 | raise EnvironmentError(msg) 152 | 153 | if resolved_config_file == config_file: 154 | logger.info("loading configuration file {}".format(config_file)) 155 | else: 156 | logger.info("loading configuration file {} from cache at {}".format( 157 | config_file, resolved_config_file)) 158 | 159 | # Load config 160 | config = cls.from_json_file(resolved_config_file) 161 | 162 | if hasattr(config, 'pruned_heads'): 163 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) 164 | 165 | # Update config with kwargs if needed 166 | to_remove = [] 167 | for key, value in kwargs.items(): 168 | if hasattr(config, key): 169 | setattr(config, key, value) 170 | to_remove.append(key) 171 | for key in to_remove: 172 | kwargs.pop(key, None) 173 | 174 | logger.info("Model config %s", str(config)) 175 | if return_unused_kwargs: 176 | return config, kwargs 177 | else: 178 | return config 179 | 180 | @classmethod 181 | def from_dict(cls, json_object): 182 | """Constructs a `Config` from a Python dictionary of parameters.""" 183 | config = cls(vocab_size_or_config_json_file=-1) 184 | for key, value in json_object.items(): 185 | setattr(config, key, value) 186 | return config 187 | 188 | @classmethod 189 | def from_json_file(cls, json_file): 190 | """Constructs a `BertConfig` from a json file of parameters.""" 191 | with open(json_file, "r", encoding='utf-8') as reader: 192 | text = reader.read() 193 | return cls.from_dict(json.loads(text)) 194 | 195 | def __eq__(self, other): 196 | return self.__dict__ == other.__dict__ 197 | 198 | def __repr__(self): 199 | return str(self.to_json_string()) 200 | 201 | def to_dict(self): 202 | """Serializes this instance to a Python dictionary.""" 203 | output = copy.deepcopy(self.__dict__) 204 | return output 205 | 206 | def to_json_string(self): 207 | """Serializes this instance to a JSON string.""" 208 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 209 | 210 | def to_json_file(self, json_file_path): 211 | """ Save this instance to a json file.""" 212 | with open(json_file_path, "w", encoding='utf-8') as writer: 213 | writer.write(self.to_json_string()) 214 | 215 | -------------------------------------------------------------------------------- /transformers/convert_bert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /transformers/convert_bert_pytorch_checkpoint_to_original_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import os 19 | import argparse 20 | import torch 21 | import numpy as np 22 | import tensorflow as tf 23 | from transformers import BertModel 24 | 25 | 26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): 27 | 28 | """ 29 | :param model:BertModel Pytorch model instance to be converted 30 | :param ckpt_dir: Tensorflow model directory 31 | :param model_name: model name 32 | :return: 33 | 34 | Currently supported HF models: 35 | Y BertModel 36 | N BertForMaskedLM 37 | N BertForPreTraining 38 | N BertForMultipleChoice 39 | N BertForNextSentencePrediction 40 | N BertForSequenceClassification 41 | N BertForQuestionAnswering 42 | """ 43 | 44 | tensors_to_transpose = ( 45 | "dense.weight", 46 | "attention.self.query", 47 | "attention.self.key", 48 | "attention.self.value" 49 | ) 50 | 51 | var_map = ( 52 | ('layer.', 'layer_'), 53 | ('word_embeddings.weight', 'word_embeddings'), 54 | ('position_embeddings.weight', 'position_embeddings'), 55 | ('token_type_embeddings.weight', 'token_type_embeddings'), 56 | ('.', '/'), 57 | ('LayerNorm/weight', 'LayerNorm/gamma'), 58 | ('LayerNorm/bias', 'LayerNorm/beta'), 59 | ('weight', 'kernel') 60 | ) 61 | 62 | if not os.path.isdir(ckpt_dir): 63 | os.makedirs(ckpt_dir) 64 | 65 | state_dict = model.state_dict() 66 | 67 | def to_tf_var_name(name:str): 68 | for patt, repl in iter(var_map): 69 | name = name.replace(patt, repl) 70 | return 'bert/{}'.format(name) 71 | 72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session): 73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 75 | session.run(tf.variables_initializer([tf_var])) 76 | session.run(tf_var) 77 | return tf_var 78 | 79 | tf.reset_default_graph() 80 | with tf.Session() as session: 81 | for var_name in state_dict: 82 | tf_name = to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. 
bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /transformers/convert_pytorch_checkpoint_to_tf2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Convert pytorch checkpoints to TensorFlow """ 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import argparse 23 | import tensorflow as tf 24 | 25 | from transformers import is_torch_available, cached_path 26 | 27 | from transformers import (load_pytorch_checkpoint_in_tf2_model, 28 | BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, 29 | GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, 30 | XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, 31 | XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, 32 | TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, 33 | OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, 34 | RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, 35 | DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, 36 | CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) 37 | 38 | if is_torch_available(): 39 | import torch 40 | import numpy as np 41 | from transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 42 | GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, 43 | XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, 44 | XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, 45 | TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, 46 | OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, 47 | RobertaForMaskedLM, 
RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, 48 | DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, 49 | CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) 50 | else: 51 | (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 52 | GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, 53 | XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, 54 | XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, 55 | TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, 56 | OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, 57 | RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, 58 | DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, 59 | CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) = ( 60 | None, None, None, None, 61 | None, None, 62 | None, None, 63 | None, None, 64 | None, None, 65 | None, None, 66 | None, None, None, 67 | None, None, None, 68 | None, None) 69 | 70 | 71 | import logging 72 | logging.basicConfig(level=logging.INFO) 73 | 74 | MODEL_CLASSES = { 75 | 'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 76 | 'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 77 | 'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 78 | 'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 79 | 'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP), 80 | 'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP), 81 | 'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP), 82 | 'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP), 83 | 'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP), 84 | 'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 85 | 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 86 | 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 87 | 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 88 | 'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) 89 | } 90 
| 91 | def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True): 92 | if model_type not in MODEL_CLASSES: 93 | raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys()))) 94 | 95 | config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] 96 | 97 | # Initialise TF model 98 | if config_file in aws_config_map: 99 | config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models) 100 | config = config_class.from_json_file(config_file) 101 | config.output_hidden_states = True 102 | config.output_attentions = True 103 | print("Building TensorFlow model from configuration: {}".format(str(config))) 104 | tf_model = model_class(config) 105 | 106 | # Load weights from the PyTorch checkpoint 107 | if pytorch_checkpoint_path in aws_model_maps: 108 | pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models) 109 | # Load PyTorch checkpoint in tf2 model: 110 | tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path) 111 | 112 | if compare_with_pt_model: 113 | inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] 114 | tf_inputs = tf.constant(inputs_list) 115 | tfo = tf_model(tf_inputs, training=False) # build the network 116 | 117 | pt_model = pt_model_class.from_pretrained(None, 118 | config=config, 119 | state_dict=torch.load(pytorch_checkpoint_path, 120 | map_location='cpu')) 121 | pt_inputs = torch.tensor(inputs_list) 122 | with torch.no_grad(): 123 | pto = pt_model(pt_inputs) 124 | 125 | np_pt = pto[0].detach().numpy() 126 | np_tf = tfo[0].numpy() 127 | diff = np.amax(np.abs(np_pt - np_tf)) 128 | print("Max absolute difference between model outputs: {}".format(diff)) 129 | assert diff <= 2e-2, "Error, model absolute difference is >2e-2" 130 | 131 | # Save the TensorFlow model 132 | print("Save TensorFlow model to {}".format(tf_dump_path)) 133 | tf_model.save_weights(tf_dump_path, save_format='h5') 134 | 135 | 136 | def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None, 137 | compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False): 138 | assert os.path.isdir(tf_dump_path), "--tf_dump_path should be a directory" 139 | 140 | if args_model_type is None: 141 | model_types = list(MODEL_CLASSES.keys()) 142 | else: 143 | model_types = [args_model_type] 144 | 145 | for j, model_type in enumerate(model_types, start=1): 146 | print("=" * 100) 147 | print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type)) 148 | print("=" * 100) 149 | if model_type not in MODEL_CLASSES: 150 | raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys()))) 151 | 152 | config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type] 153 | 154 | if model_shortcut_names_or_path is None: 155 | model_shortcut_names_or_path = list(aws_model_maps.keys()) 156 | if config_shortcut_names_or_path is None: 157 | config_shortcut_names_or_path = model_shortcut_names_or_path 158 | 159 | for i, (model_shortcut_name, config_shortcut_name) in enumerate( 160 | zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1): 161 | print("-" * 100) 162 | if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or
'-mnli' in model_shortcut_name: 163 | if not only_convert_finetuned_models: 164 | print(" Skipping finetuned checkpoint {}".format(model_shortcut_name)) 165 | continue 166 | model_type = model_shortcut_name 167 | elif only_convert_finetuned_models: 168 | print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name)) 169 | continue 170 | print(" Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type)) 171 | print("-" * 100) 172 | 173 | if config_shortcut_name in aws_config_map: 174 | config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models) 175 | else: 176 | config_file = cached_path(config_shortcut_name, force_download=not use_cached_models) 177 | 178 | if model_shortcut_name in aws_model_maps: 179 | model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models) 180 | else: 181 | model_file = cached_path(model_shortcut_name, force_download=not use_cached_models) 182 | 183 | if os.path.isfile(model_shortcut_name): 184 | model_shortcut_name = 'converted_model' 185 | convert_pt_checkpoint_to_tf(model_type=model_type, 186 | pytorch_checkpoint_path=model_file, 187 | config_file=config_file, 188 | tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'), 189 | compare_with_pt_model=compare_with_pt_model) 190 | os.remove(config_file) 191 | os.remove(model_file) 192 | 193 | 194 | if __name__ == "__main__": 195 | parser = argparse.ArgumentParser() 196 | ## Required parameters 197 | parser.add_argument("--tf_dump_path", 198 | default = None, 199 | type = str, 200 | required = True, 201 | help = "Path to the output Tensorflow dump file.") 202 | parser.add_argument("--model_type", 203 | default = None, 204 | type = str, 205 | help = "Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys()))) 206 | parser.add_argument("--pytorch_checkpoint_path", 207 | default = None, 208 | type = str, 209 | help = "Path to the PyTorch checkpoint path or shortcut name to download from AWS. " 210 | "If not given, will download and convert all the checkpoints from AWS.") 211 | parser.add_argument("--config_file", 212 | default = None, 213 | type = str, 214 | help = "The config json file corresponding to the pre-trained model. \n" 215 | "This specifies the model architecture. 
If not given and " 216 | "--pytorch_checkpoint_path is not given or is a shortcut name, " 217 | "use the configuration associated with the shortcut name on AWS.") 218 | parser.add_argument("--compare_with_pt_model", 219 | action='store_true', 220 | help = "Compare TensorFlow and PyTorch model predictions.") 221 | parser.add_argument("--use_cached_models", 222 | action='store_true', 223 | help = "Use cached models if possible instead of updating to latest checkpoint versions.") 224 | parser.add_argument("--only_convert_finetuned_models", 225 | action='store_true', 226 | help = "Only convert finetuned models.") 227 | args = parser.parse_args() 228 | 229 | # if args.pytorch_checkpoint_path is not None: 230 | # convert_pt_checkpoint_to_tf(args.model_type.lower(), 231 | # args.pytorch_checkpoint_path, 232 | # args.config_file if args.config_file is not None else args.pytorch_checkpoint_path, 233 | # args.tf_dump_path, 234 | # compare_with_pt_model=args.compare_with_pt_model, 235 | # use_cached_models=args.use_cached_models) 236 | # else: 237 | convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None, 238 | args.tf_dump_path, 239 | model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None, 240 | config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None, 241 | compare_with_pt_model=args.compare_with_pt_model, 242 | use_cached_models=args.use_cached_models, 243 | only_convert_finetuned_models=args.only_convert_finetuned_models) 244 | -------------------------------------------------------------------------------- /transformers/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/transformers/data/.DS_Store -------------------------------------------------------------------------------- /transformers/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .processors import InputExample, InputFeatures, DataProcessor 2 | from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features 3 | 4 | from .metrics import is_sklearn_available 5 | if is_sklearn_available(): 6 | from .metrics import glue_compute_metrics 7 | -------------------------------------------------------------------------------- /transformers/data/metrics/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/transformers/data/metrics/.DS_Store -------------------------------------------------------------------------------- /transformers/data/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | try: 24 | from scipy.stats import pearsonr, spearmanr 25 | from sklearn.metrics import matthews_corrcoef, f1_score 26 | _has_sklearn = True 27 | except (AttributeError, ImportError) as e: 28 | logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html") 29 | _has_sklearn = False 30 | 31 | def is_sklearn_available(): 32 | return _has_sklearn 33 | 34 | if _has_sklearn: 35 | 36 | def simple_accuracy(preds, labels): 37 | return (preds == labels).mean() 38 | 39 | 40 | def acc_and_f1(preds, labels, average='binary'): 41 | acc = simple_accuracy(preds, labels) 42 | f1 = f1_score(y_true=labels, y_pred=preds, average=average) 43 | return { 44 | "acc": acc, 45 | "f1": f1, 46 | "acc_and_f1": (acc + f1) / 2, 47 | } 48 | 49 | 50 | def pearson_and_spearman(preds, labels): 51 | pearson_corr = pearsonr(preds, labels)[0] 52 | spearman_corr = spearmanr(preds, labels)[0] 53 | return { 54 | "pearson": pearson_corr, 55 | "spearmanr": spearman_corr, 56 | "corr": (pearson_corr + spearman_corr) / 2, 57 | } 58 | 59 | 60 | def glue_compute_metrics(task_name, preds, labels): 61 | assert len(preds) == len(labels) 62 | if task_name == "cola": 63 | return {"mcc": matthews_corrcoef(labels, preds)} 64 | elif task_name == "sst-2": 65 | return {"acc": simple_accuracy(preds, labels)} 66 | elif task_name == "mrpc": 67 | return acc_and_f1(preds, labels) 68 | elif task_name == "sts-b": 69 | return pearson_and_spearman(preds, labels) 70 | elif task_name == "qqp": 71 | return acc_and_f1(preds, labels) 72 | elif task_name == "mnli": 73 | return {"acc": simple_accuracy(preds, labels)} 74 | elif task_name == "mnli-mm": 75 | return {"acc": simple_accuracy(preds, labels)} 76 | elif task_name == "qnli": 77 | return {"acc": simple_accuracy(preds, labels)} 78 | elif task_name == "rte": 79 | return {"acc": simple_accuracy(preds, labels)} 80 | elif task_name == "wnli": 81 | return {"acc": simple_accuracy(preds, labels)} 82 | elif task_name == "snli": 83 | return {"acc": simple_accuracy(preds, labels)} 84 | elif task_name == "scitail": 85 | return {"acc": simple_accuracy(preds, labels)} 86 | else: 87 | raise KeyError(task_name) 88 | 89 | -------------------------------------------------------------------------------- /transformers/data/processors/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/transformers/data/processors/.DS_Store -------------------------------------------------------------------------------- /transformers/data/processors/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import InputExample, InputFeatures, DataProcessor 2 | from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features 3 | 4 | 
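# A minimal usage sketch of the exports above, assuming a local GLUE data
# directory and a BERT tokenizer (both illustrative):
#
#   from transformers import BertTokenizer
#   from transformers.data.processors import (glue_processors,
#                                             glue_convert_examples_to_features)
#
#   processor = glue_processors['mrpc']()
#   examples = processor.get_dev_examples('/path/to/glue/MRPC')
#   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#   features = glue_convert_examples_to_features(examples, tokenizer,
#                                                max_length=128, task='mrpc')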
-------------------------------------------------------------------------------- /transformers/data/processors/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import copy 20 | import json 21 | 22 | class InputExample(object): 23 | """ 24 | A single training/test example for simple sequence classification. 25 | 26 | Args: 27 | guid: Unique id for the example. 28 | text_a: string. The untokenized text of the first sequence. For single 29 | sequence tasks, only this sequence must be specified. 30 | text_b: (Optional) string. The untokenized text of the second sequence. 31 | Only must be specified for sequence pair tasks. 32 | label: (Optional) string. The label of the example. This should be 33 | specified for train and dev examples, but not for test examples. 34 | """ 35 | def __init__(self, guid, text_a, text_b=None, label=None): 36 | self.guid = guid 37 | self.text_a = text_a 38 | self.text_b = text_b 39 | self.label = label 40 | 41 | def __repr__(self): 42 | return str(self.to_json_string()) 43 | 44 | def to_dict(self): 45 | """Serializes this instance to a Python dictionary.""" 46 | output = copy.deepcopy(self.__dict__) 47 | return output 48 | 49 | def to_json_string(self): 50 | """Serializes this instance to a JSON string.""" 51 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 52 | 53 | 54 | class InputFeatures(object): 55 | """ 56 | A single set of features of data. 57 | 58 | Args: 59 | input_ids: Indices of input sequence tokens in the vocabulary. 60 | attention_mask: Mask to avoid performing attention on padding token indices. 61 | Mask values selected in ``[0, 1]``: 62 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 63 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 
64 | label: Label corresponding to the input 65 | """ 66 | 67 | def __init__(self, input_ids, attention_mask, token_type_ids, label): 68 | self.input_ids = input_ids 69 | self.attention_mask = attention_mask 70 | self.token_type_ids = token_type_ids 71 | self.label = label 72 | 73 | def __repr__(self): 74 | return str(self.to_json_string()) 75 | 76 | def to_dict(self): 77 | """Serializes this instance to a Python dictionary.""" 78 | output = copy.deepcopy(self.__dict__) 79 | return output 80 | 81 | def to_json_string(self): 82 | """Serializes this instance to a JSON string.""" 83 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 84 | 85 | 86 | class DataProcessor(object): 87 | """Base class for data converters for sequence classification data sets.""" 88 | 89 | def get_example_from_tensor_dict(self, tensor_dict): 90 | """Gets an example from a dict with tensorflow tensors 91 | 92 | Args: 93 | tensor_dict: Keys and values should match the corresponding Glue 94 | tensorflow_dataset examples. 95 | """ 96 | raise NotImplementedError() 97 | 98 | def get_train_examples(self, data_dir): 99 | """Gets a collection of `InputExample`s for the train set.""" 100 | raise NotImplementedError() 101 | 102 | def get_dev_examples(self, data_dir): 103 | """Gets a collection of `InputExample`s for the dev set.""" 104 | raise NotImplementedError() 105 | 106 | def get_labels(self): 107 | """Gets the list of labels for this data set.""" 108 | raise NotImplementedError() 109 | 110 | def tfds_map(self, example): 111 | """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. 112 | This method converts examples to the correct format.""" 113 | if len(self.get_labels()) > 1: 114 | example.label = self.get_labels()[int(example.label)] 115 | return example 116 | 117 | @classmethod 118 | def _read_tsv(cls, input_file, quotechar=None): 119 | """Reads a tab separated value file.""" 120 | with open(input_file, "r", encoding="utf-8-sig") as f: 121 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 122 | lines = [] 123 | for line in reader: 124 | if sys.version_info[0] == 2: 125 | line = list(unicode(cell, 'utf-8') for cell in line) 126 | lines.append(line) 127 | return lines 128 | -------------------------------------------------------------------------------- /transformers/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import six 13 | import shutil 14 | import tempfile 15 | import fnmatch 16 | from functools import wraps 17 | from hashlib import sha256 18 | from io import open 19 | 20 | import boto3 21 | from botocore.config import Config 22 | from botocore.exceptions import ClientError 23 | import requests 24 | from tqdm import tqdm 25 | 26 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 27 | 28 | try: 29 | import tensorflow as tf 30 | assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2 31 | _tf_available = True # pylint: disable=invalid-name 32 | logger.info("TensorFlow version {} available.".format(tf.__version__)) 33 | except (ImportError, AssertionError): 34 | _tf_available = False # pylint: disable=invalid-name 35 | 36 | try: 37 | import torch 38 | _torch_available = True # pylint: disable=invalid-name 39 | logger.info("PyTorch version {} available.".format(torch.__version__)) 40 | except ImportError: 41 | _torch_available = False # pylint: disable=invalid-name 42 | 43 | 44 | try: 45 | from torch.hub import _get_torch_home 46 | torch_cache_home = _get_torch_home() 47 | except ImportError: 48 | torch_cache_home = os.path.expanduser( 49 | os.getenv('TORCH_HOME', os.path.join( 50 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 51 | default_cache_path = os.path.join(torch_cache_home, 'transformers') 52 | 53 | try: 54 | from urllib.parse import urlparse 55 | except ImportError: 56 | from urlparse import urlparse 57 | 58 | try: 59 | from pathlib import Path 60 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 61 | os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))) 62 | except (AttributeError, ImportError): 63 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE', 64 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 65 | default_cache_path)) 66 | 67 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 68 | TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 69 | 70 | WEIGHTS_NAME = "pytorch_model.bin" 71 | TF2_WEIGHTS_NAME = 'tf_model.h5' 72 | TF_WEIGHTS_NAME = 'model.ckpt' 73 | CONFIG_NAME = "config.json" 74 | 75 | def is_torch_available(): 76 | return _torch_available 77 | 78 | def is_tf_available(): 79 | return _tf_available 80 | 81 | if not six.PY2: 82 | def add_start_docstrings(*docstr): 83 | def docstring_decorator(fn): 84 | fn.__doc__ = ''.join(docstr) + fn.__doc__ 85 | return fn 86 | return docstring_decorator 87 | 88 | def add_end_docstrings(*docstr): 89 | def docstring_decorator(fn): 90 | fn.__doc__ = fn.__doc__ + ''.join(docstr) 91 | return fn 92 | return docstring_decorator 93 | else: 94 | # Not possible to update class docstrings on python2 95 | def add_start_docstrings(*docstr): 96 | def docstring_decorator(fn): 97 | return fn 98 | return docstring_decorator 99 | 100 | def add_end_docstrings(*docstr): 101 | def docstring_decorator(fn): 102 | return fn 103 | return docstring_decorator 104 | 105 | def url_to_filename(url, etag=None): 106 | """ 107 | Convert `url` into a hashed filename in a repeatable way. 108 | If `etag` is specified, append its hash to the url's, delimited 109 | by a period. 
110 |     If the url ends with .h5 (Keras HDF5 weights), adds '.h5' to the name
111 |     so that TF 2.0 can identify it as an HDF5 file
112 |     (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
113 |     """
114 |     url_bytes = url.encode('utf-8')
115 |     url_hash = sha256(url_bytes)
116 |     filename = url_hash.hexdigest()
117 |
118 |     if etag:
119 |         etag_bytes = etag.encode('utf-8')
120 |         etag_hash = sha256(etag_bytes)
121 |         filename += '.' + etag_hash.hexdigest()
122 |
123 |     if url.endswith('.h5'):
124 |         filename += '.h5'
125 |
126 |     return filename
127 |
128 |
129 | def filename_to_url(filename, cache_dir=None):
130 |     """
131 |     Return the url and etag (which may be ``None``) stored for `filename`.
132 |     Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
133 |     """
134 |     if cache_dir is None:
135 |         cache_dir = TRANSFORMERS_CACHE
136 |     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
137 |         cache_dir = str(cache_dir)
138 |
139 |     cache_path = os.path.join(cache_dir, filename)
140 |     if not os.path.exists(cache_path):
141 |         raise EnvironmentError("file {} not found".format(cache_path))
142 |
143 |     meta_path = cache_path + '.json'
144 |     if not os.path.exists(meta_path):
145 |         raise EnvironmentError("file {} not found".format(meta_path))
146 |
147 |     with open(meta_path, encoding="utf-8") as meta_file:
148 |         metadata = json.load(meta_file)
149 |     url = metadata['url']
150 |     etag = metadata['etag']
151 |
152 |     return url, etag
153 |
154 |
155 | def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
156 |     """
157 |     Given something that might be a URL (or might be a local path),
158 |     determine which. If it's a URL, download the file and cache it, and
159 |     return the path to the cached file. If it's already a local path,
160 |     make sure the file exists and then return the path.
161 |     Args:
162 |         cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
163 |         force_download: if True, re-download the file even if it's already cached in the cache dir.
164 |     """
165 |     if cache_dir is None:
166 |         cache_dir = TRANSFORMERS_CACHE
167 |     if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
168 |         url_or_filename = str(url_or_filename)
169 |     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
170 |         cache_dir = str(cache_dir)
171 |
172 |     parsed = urlparse(url_or_filename)
173 |
174 |     if parsed.scheme in ('http', 'https', 's3'):
175 |         # URL, so get it from the cache (downloading if necessary)
176 |         return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
177 |     elif os.path.exists(url_or_filename):
178 |         # File, and it exists.
179 |         return url_or_filename
180 |     elif parsed.scheme == '':
181 |         # File, but it doesn't exist.
182 |         raise EnvironmentError("file {} not found".format(url_or_filename))
183 |     else:
184 |         # Something unknown
185 |         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
186 |
187 |
188 | def split_s3_path(url):
189 |     """Split a full s3 path into the bucket name and path."""
190 |     parsed = urlparse(url)
191 |     if not parsed.netloc or not parsed.path:
192 |         raise ValueError("bad s3 path {}".format(url))
193 |     bucket_name = parsed.netloc
194 |     s3_path = parsed.path
195 |     # Remove '/' at the beginning of the path.
196 | if s3_path.startswith("/"): 197 | s3_path = s3_path[1:] 198 | return bucket_name, s3_path 199 | 200 | 201 | def s3_request(func): 202 | """ 203 | Wrapper function for s3 requests in order to create more helpful error 204 | messages. 205 | """ 206 | 207 | @wraps(func) 208 | def wrapper(url, *args, **kwargs): 209 | try: 210 | return func(url, *args, **kwargs) 211 | except ClientError as exc: 212 | if int(exc.response["Error"]["Code"]) == 404: 213 | raise EnvironmentError("file {} not found".format(url)) 214 | else: 215 | raise 216 | 217 | return wrapper 218 | 219 | 220 | @s3_request 221 | def s3_etag(url, proxies=None): 222 | """Check ETag on S3 object.""" 223 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 224 | bucket_name, s3_path = split_s3_path(url) 225 | s3_object = s3_resource.Object(bucket_name, s3_path) 226 | return s3_object.e_tag 227 | 228 | 229 | @s3_request 230 | def s3_get(url, temp_file, proxies=None): 231 | """Pull a file directly from S3.""" 232 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 233 | bucket_name, s3_path = split_s3_path(url) 234 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 235 | 236 | 237 | def http_get(url, temp_file, proxies=None): 238 | req = requests.get(url, stream=True, proxies=proxies) 239 | content_length = req.headers.get('Content-Length') 240 | total = int(content_length) if content_length is not None else None 241 | progress = tqdm(unit="B", total=total) 242 | for chunk in req.iter_content(chunk_size=1024): 243 | if chunk: # filter out keep-alive new chunks 244 | progress.update(len(chunk)) 245 | temp_file.write(chunk) 246 | progress.close() 247 | 248 | 249 | def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10): 250 | """ 251 | Given a URL, look for the corresponding dataset in the local cache. 252 | If it's not there, download it. Then return the path to the cached file. 253 | """ 254 | if cache_dir is None: 255 | cache_dir = TRANSFORMERS_CACHE 256 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 257 | cache_dir = str(cache_dir) 258 | if sys.version_info[0] == 2 and not isinstance(cache_dir, str): 259 | cache_dir = str(cache_dir) 260 | 261 | if not os.path.exists(cache_dir): 262 | os.makedirs(cache_dir) 263 | 264 | # Get eTag to add to filename, if it exists. 
265 | if url.startswith("s3://"): 266 | etag = s3_etag(url, proxies=proxies) 267 | else: 268 | try: 269 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) 270 | if response.status_code != 200: 271 | etag = None 272 | else: 273 | etag = response.headers.get("ETag") 274 | except (EnvironmentError, requests.exceptions.Timeout): 275 | etag = None 276 | 277 | if sys.version_info[0] == 2 and etag is not None: 278 | etag = etag.decode('utf-8') 279 | filename = url_to_filename(url, etag) 280 | 281 | # get cache path to put the file 282 | cache_path = os.path.join(cache_dir, filename) 283 | 284 | # If we don't have a connection (etag is None) and can't identify the file 285 | # try to get the last downloaded one 286 | if not os.path.exists(cache_path) and etag is None: 287 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 288 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 289 | if matching_files: 290 | cache_path = os.path.join(cache_dir, matching_files[-1]) 291 | 292 | if not os.path.exists(cache_path) or force_download: 293 | # Download to temporary file, then copy to cache dir once finished. 294 | # Otherwise you get corrupt cache entries if the download gets interrupted. 295 | with tempfile.NamedTemporaryFile() as temp_file: 296 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) 297 | 298 | # GET file object 299 | if url.startswith("s3://"): 300 | s3_get(url, temp_file, proxies=proxies) 301 | else: 302 | http_get(url, temp_file, proxies=proxies) 303 | 304 | # we are copying the file before closing it, so flush to avoid truncation 305 | temp_file.flush() 306 | # shutil.copyfileobj() starts at the current position, so go to the start 307 | temp_file.seek(0) 308 | 309 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 310 | with open(cache_path, 'wb') as cache_file: 311 | shutil.copyfileobj(temp_file, cache_file) 312 | 313 | logger.info("creating metadata file for %s", cache_path) 314 | meta = {'url': url, 'etag': etag} 315 | meta_path = cache_path + '.json' 316 | with open(meta_path, 'w') as meta_file: 317 | output_string = json.dumps(meta) 318 | if sys.version_info[0] == 2 and isinstance(output_string, str): 319 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 320 | meta_file.write(output_string) 321 | 322 | logger.info("removing temp file %s", temp_file.name) 323 | 324 | return cache_path 325 | -------------------------------------------------------------------------------- /transformers/modeling_tf_pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
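Putting the file_utils.py pieces above together, here is a short usage sketch of `cached_path` and the sidecar metadata written by `get_from_cache`; the URL is illustrative, and network access (plus the requests/boto3 imports at the top of the file) is assumed:

import os
from transformers.file_utils import cached_path, filename_to_url

url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json"

# The first call downloads into the cache; subsequent calls return the cached path directly.
local_path = cached_path(url)

# The '<filename>.json' metadata file lets us map a cache entry back to its origin.
orig_url, etag = filename_to_url(os.path.basename(local_path))
assert orig_url == url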
16 | """ PyTorch - TF 2.0 general utilities.""" 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import logging 22 | import os 23 | import re 24 | import numpy 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''): 29 | """ Convert a TF 2.0 model variable name in a pytorch model weight name. 30 | 31 | Conventions for TF2.0 scopes -> PyTorch attribute names conversions: 32 | - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) 33 | - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) 34 | 35 | return tuple with: 36 | - pytorch model weight name 37 | - transpose: boolean indicating weither TF2.0 and PyTorch weights matrices are transposed with regards to each other 38 | """ 39 | tf_name = tf_name.replace(':0', '') # device ids 40 | tf_name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', tf_name) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) 41 | tf_name = tf_name.replace('_._', '/') # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) 42 | tf_name = re.sub(r'//+', '/', tf_name) # Remove empty levels at the end 43 | tf_name = tf_name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' separators 44 | tf_name = tf_name[1:] # Remove level zero 45 | 46 | # When should we transpose the weights 47 | transpose = bool(tf_name[-1] == 'kernel' or 'emb_projs' in tf_name or 'out_projs' in tf_name) 48 | 49 | # Convert standard TF2.0 names in PyTorch names 50 | if tf_name[-1] == 'kernel' or tf_name[-1] == 'embeddings' or tf_name[-1] == 'gamma': 51 | tf_name[-1] = 'weight' 52 | if tf_name[-1] == 'beta': 53 | tf_name[-1] = 'bias' 54 | 55 | # Remove prefix if needed 56 | tf_name = '.'.join(tf_name) 57 | if start_prefix_to_remove: 58 | tf_name = tf_name.replace(start_prefix_to_remove, '', 1) 59 | 60 | return tf_name, transpose 61 | 62 | 63 | ##################### 64 | ### PyTorch => TF 2.0 65 | 66 | def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False): 67 | """ Load pytorch checkpoints in a TF 2.0 model 68 | """ 69 | try: 70 | import tensorflow as tf 71 | import torch 72 | except ImportError as e: 73 | logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " 74 | "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") 75 | raise e 76 | 77 | pt_path = os.path.abspath(pytorch_checkpoint_path) 78 | logger.info("Loading PyTorch weights from {}".format(pt_path)) 79 | 80 | pt_state_dict = torch.load(pt_path, map_location='cpu') 81 | 82 | return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys) 83 | 84 | 85 | def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False): 86 | """ Load pytorch checkpoints in a TF 2.0 model 87 | """ 88 | pt_state_dict = pt_model.state_dict() 89 | 90 | return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys) 91 | 92 | 93 | def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False): 94 | """ Load pytorch state_dict in a TF 2.0 model. 
95 | """ 96 | try: 97 | import torch 98 | import tensorflow as tf 99 | from tensorflow.python.keras import backend as K 100 | except ImportError as e: 101 | logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " 102 | "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") 103 | raise e 104 | 105 | if tf_inputs is None: 106 | tf_inputs = tf_model.dummy_inputs 107 | 108 | if tf_inputs is not None: 109 | tfo = tf_model(tf_inputs, training=False) # Make sure model is built 110 | 111 | # Adapt state dict - TODO remove this and update the AWS weights files instead 112 | # Convert old format to new format if needed from a PyTorch state_dict 113 | old_keys = [] 114 | new_keys = [] 115 | for key in pt_state_dict.keys(): 116 | new_key = None 117 | if 'gamma' in key: 118 | new_key = key.replace('gamma', 'weight') 119 | if 'beta' in key: 120 | new_key = key.replace('beta', 'bias') 121 | if new_key: 122 | old_keys.append(key) 123 | new_keys.append(new_key) 124 | for old_key, new_key in zip(old_keys, new_keys): 125 | pt_state_dict[new_key] = pt_state_dict.pop(old_key) 126 | 127 | # Make sure we are able to load PyTorch base models as well as derived models (with heads) 128 | # TF models always have a prefix, some of PyTorch models (base ones) don't 129 | start_prefix_to_remove = '' 130 | if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()): 131 | start_prefix_to_remove = tf_model.base_model_prefix + '.' 132 | 133 | symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights 134 | 135 | weight_value_tuples = [] 136 | all_pytorch_weights = set(list(pt_state_dict.keys())) 137 | for symbolic_weight in symbolic_weights: 138 | sw_name = symbolic_weight.name 139 | name, transpose = convert_tf_weight_name_to_pt_weight_name(sw_name, start_prefix_to_remove=start_prefix_to_remove) 140 | 141 | # Find associated numpy array in pytorch model state dict 142 | assert name in pt_state_dict, "{} not found in PyTorch model".format(name) 143 | array = pt_state_dict[name].numpy() 144 | 145 | if transpose: 146 | array = numpy.transpose(array) 147 | 148 | if len(symbolic_weight.shape) < len(array.shape): 149 | array = numpy.squeeze(array) 150 | elif len(symbolic_weight.shape) > len(array.shape): 151 | array = numpy.expand_dims(array, axis=0) 152 | 153 | try: 154 | assert list(symbolic_weight.shape) == list(array.shape) 155 | except AssertionError as e: 156 | e.args += (symbolic_weight.shape, array.shape) 157 | raise e 158 | 159 | logger.info("Initialize TF weight {}".format(symbolic_weight.name)) 160 | 161 | weight_value_tuples.append((symbolic_weight, array)) 162 | all_pytorch_weights.discard(name) 163 | 164 | K.batch_set_value(weight_value_tuples) 165 | 166 | if tf_inputs is not None: 167 | tfo = tf_model(tf_inputs, training=False) # Make sure restore ops are run 168 | 169 | logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights)) 170 | 171 | return tf_model 172 | 173 | 174 | ##################### 175 | ### TF 2.0 => PyTorch 176 | 177 | def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False): 178 | """ Load TF 2.0 HDF5 checkpoint in a PyTorch model 179 | We use HDF5 to easily do transfer learning 180 | (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). 
181 | """ 182 | try: 183 | import tensorflow as tf 184 | import torch 185 | except ImportError as e: 186 | logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " 187 | "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") 188 | raise e 189 | 190 | import transformers 191 | 192 | tf_path = os.path.abspath(tf_checkpoint_path) 193 | logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path)) 194 | 195 | # Instantiate and load the associated TF 2.0 model 196 | tf_model_class_name = "TF" + pt_model.__class__.__name__ # Add "TF" at the beggining 197 | tf_model_class = getattr(transformers, tf_model_class_name) 198 | tf_model = tf_model_class(pt_model.config) 199 | 200 | if tf_inputs is None: 201 | tf_inputs = tf_model.dummy_inputs 202 | 203 | if tf_inputs is not None: 204 | tfo = tf_model(tf_inputs, training=False) # Make sure model is built 205 | 206 | tf_model.load_weights(tf_checkpoint_path, by_name=True) 207 | 208 | return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys) 209 | 210 | def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False): 211 | """ Load TF 2.0 model in a pytorch model 212 | """ 213 | weights = tf_model.weights 214 | 215 | return load_tf2_weights_in_pytorch_model(pt_model, weights, allow_missing_keys=allow_missing_keys) 216 | 217 | 218 | def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False): 219 | """ Load TF2.0 symbolic weights in a PyTorch model 220 | """ 221 | try: 222 | import tensorflow as tf 223 | import torch 224 | except ImportError as e: 225 | logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " 226 | "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") 227 | raise e 228 | 229 | new_pt_params_dict = {} 230 | current_pt_params_dict = dict(pt_model.named_parameters()) 231 | 232 | # Make sure we are able to load PyTorch base models as well as derived models (with heads) 233 | # TF models always have a prefix, some of PyTorch models (base ones) don't 234 | start_prefix_to_remove = '' 235 | if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()): 236 | start_prefix_to_remove = pt_model.base_model_prefix + '.' 
237 |
238 |     # Build a map from potential PyTorch weight names to TF 2.0 Variables
239 |     tf_weights_map = {}
240 |     for tf_weight in tf_weights:
241 |         pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(tf_weight.name, start_prefix_to_remove=start_prefix_to_remove)
242 |         tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
243 |
244 |     all_tf_weights = set(list(tf_weights_map.keys()))
245 |     loaded_pt_weights_data_ptr = {}
246 |     for pt_weight_name, pt_weight in current_pt_params_dict.items():
247 |         # Handle PyTorch shared weights (not duplicated in TF 2.0)
248 |         if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
249 |             new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
250 |             continue
251 |
252 |         # Find the associated numpy array in the TF 2.0 weights map
253 |         if pt_weight_name not in tf_weights_map:
254 |             raise ValueError("{} not found in TF 2.0 model".format(pt_weight_name))
255 |
256 |         array, transpose = tf_weights_map[pt_weight_name]
257 |
258 |         if transpose:
259 |             array = numpy.transpose(array)
260 |
261 |         if len(pt_weight.shape) < len(array.shape):
262 |             array = numpy.squeeze(array)
263 |         elif len(pt_weight.shape) > len(array.shape):
264 |             array = numpy.expand_dims(array, axis=0)
265 |
266 |         try:
267 |             assert list(pt_weight.shape) == list(array.shape)
268 |         except AssertionError as e:
269 |             e.args += (pt_weight.shape, array.shape)
270 |             raise e
271 |
272 |         logger.info("Initialize PyTorch weight {}".format(pt_weight_name))
273 |
274 |         new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
275 |         loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
276 |         all_tf_weights.discard(pt_weight_name)
277 |
278 |     missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
279 |
280 |     if len(missing_keys) > 0:
281 |         logger.info("Weights of {} not initialized from TF 2.0 model: {}".format(
282 |             pt_model.__class__.__name__, missing_keys))
283 |     if len(unexpected_keys) > 0:
284 |         logger.info("Weights from TF 2.0 model not used in {}: {}".format(
285 |             pt_model.__class__.__name__, unexpected_keys))
286 |
287 |     logger.info("Weights or buffers not loaded from TF 2.0 model: {}".format(all_tf_weights))
288 |
289 |     return pt_model
290 | -------------------------------------------------------------------------------- /transformers/optimization.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import logging
18 | import math
19 |
20 | import torch
21 | from torch.optim import Optimizer
22 | from torch.optim.lr_scheduler import LambdaLR
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def get_constant_schedule(optimizer, last_epoch=-1):
28 |     """ Create a schedule with a constant learning rate.
29 | """ 30 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 31 | 32 | 33 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 34 | """ Create a schedule with a constant learning rate preceded by a warmup 35 | period during which the learning rate increases linearly between 0 and 1. 36 | """ 37 | def lr_lambda(current_step): 38 | if current_step < num_warmup_steps: 39 | return float(current_step) / float(max(1.0, num_warmup_steps)) 40 | return 1. 41 | 42 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 43 | 44 | 45 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 46 | """ Create a schedule with a learning rate that decreases linearly after 47 | linearly increasing during a warmup period. 48 | """ 49 | def lr_lambda(current_step): 50 | if current_step < num_warmup_steps: 51 | return float(current_step) / float(max(1, num_warmup_steps)) 52 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 53 | 54 | return LambdaLR(optimizer, lr_lambda, last_epoch) 55 | 56 | 57 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1): 58 | """ Create a schedule with a learning rate that decreases following the 59 | values of the cosine function between 0 and `pi * cycles` after a warmup 60 | period during which it increases linearly between 0 and 1. 61 | """ 62 | def lr_lambda(current_step): 63 | if current_step < num_warmup_steps: 64 | return float(current_step) / float(max(1, num_warmup_steps)) 65 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 66 | return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress))) 67 | 68 | return LambdaLR(optimizer, lr_lambda, last_epoch) 69 | 70 | 71 | def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1): 72 | """ Create a schedule with a learning rate that decreases following the 73 | values of the cosine function with several hard restarts, after a warmup 74 | period during which it increases linearly between 0 and 1. 75 | """ 76 | def lr_lambda(current_step): 77 | if current_step < num_warmup_steps: 78 | return float(current_step) / float(max(1, num_warmup_steps)) 79 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 80 | if progress >= 1.: 81 | return 0. 82 | return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.)))) 83 | 84 | return LambdaLR(optimizer, lr_lambda, last_epoch) 85 | 86 | 87 | class AdamW(Optimizer): 88 | """ Implements Adam algorithm with weight decay fix. 89 | 90 | Parameters: 91 | lr (float): learning rate. Default 1e-3. 92 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 93 | eps (float): Adams epsilon. Default: 1e-6 94 | weight_decay (float): Weight decay. Default: 0.0 95 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 
96 | """ 97 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 98 | if lr < 0.0: 99 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 104 | if not 0.0 <= eps: 105 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 106 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 107 | correct_bias=correct_bias) 108 | super(AdamW, self).__init__(params, defaults) 109 | 110 | def step(self, closure=None): 111 | """Performs a single optimization step. 112 | 113 | Arguments: 114 | closure (callable, optional): A closure that reevaluates the model 115 | and returns the loss. 116 | """ 117 | loss = None 118 | if closure is not None: 119 | loss = closure() 120 | 121 | for group in self.param_groups: 122 | for p in group['params']: 123 | if p.grad is None: 124 | continue 125 | grad = p.grad.data 126 | if grad.is_sparse: 127 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 128 | 129 | state = self.state[p] 130 | 131 | # State initialization 132 | if len(state) == 0: 133 | state['step'] = 0 134 | # Exponential moving average of gradient values 135 | state['exp_avg'] = torch.zeros_like(p.data) 136 | # Exponential moving average of squared gradient values 137 | state['exp_avg_sq'] = torch.zeros_like(p.data) 138 | 139 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 140 | beta1, beta2 = group['betas'] 141 | 142 | state['step'] += 1 143 | 144 | # Decay the first and second moment running average coefficient 145 | # In-place operations to update the averages at the same time 146 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 147 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 148 | denom = exp_avg_sq.sqrt().add_(group['eps']) 149 | 150 | step_size = group['lr'] 151 | if group['correct_bias']: # No bias correction for Bert 152 | bias_correction1 = 1.0 - beta1 ** state['step'] 153 | bias_correction2 = 1.0 - beta2 ** state['step'] 154 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 155 | 156 | p.data.addcdiv_(-step_size, exp_avg, denom) 157 | 158 | # Just adding the square of the weights to the loss function is *not* 159 | # the correct way of using L2 regularization/weight decay with Adam, 160 | # since that will interact with the m and v parameters in strange ways. 161 | # 162 | # Instead we want to decay the weights in a manner that doesn't interact 163 | # with the m/v parameters. This is equivalent to adding the square 164 | # of the weights to the loss with plain (non-momentum) SGD. 
165 | # Add weight decay at the end (fixed version) 166 | if group['weight_decay'] > 0.0: 167 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 168 | 169 | return loss 170 | -------------------------------------------------------------------------------- /transformers/tests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/transformers/tests/.DS_Store -------------------------------------------------------------------------------- /transformers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/transformers/tests/__init__.py -------------------------------------------------------------------------------- /transformers/tests/configuration_common_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import copy 20 | import os 21 | import shutil 22 | import json 23 | import random 24 | import uuid 25 | 26 | import unittest 27 | import logging 28 | 29 | 30 | class ConfigTester(object): 31 | def __init__(self, parent, config_class=None, **kwargs): 32 | self.parent = parent 33 | self.config_class = config_class 34 | self.inputs_dict = kwargs 35 | 36 | def create_and_test_config_common_properties(self): 37 | config = self.config_class(**self.inputs_dict) 38 | self.parent.assertTrue(hasattr(config, 'vocab_size')) 39 | self.parent.assertTrue(hasattr(config, 'hidden_size')) 40 | self.parent.assertTrue(hasattr(config, 'num_attention_heads')) 41 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers')) 42 | 43 | def create_and_test_config_to_json_string(self): 44 | config = self.config_class(**self.inputs_dict) 45 | obj = json.loads(config.to_json_string()) 46 | for key, value in self.inputs_dict.items(): 47 | self.parent.assertEqual(obj[key], value) 48 | 49 | def create_and_test_config_to_json_file(self): 50 | config_first = self.config_class(**self.inputs_dict) 51 | json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json") 52 | config_first.to_json_file(json_file_path) 53 | config_second = self.config_class.from_json_file(json_file_path) 54 | os.remove(json_file_path) 55 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) 56 | 57 | def run_common_tests(self): 58 | self.create_and_test_config_common_properties() 59 | self.create_and_test_config_to_json_string() 60 | self.create_and_test_config_to_json_file() 61 | 62 | if __name__ == "__main__": 63 | unittest.main() 
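Tying optimization.py together, here is a minimal sketch of how AdamW and a warmup schedule are typically combined in a training loop; the toy linear model, dummy loss, and step counts are placeholders, not part of this repository:

import torch
from transformers import AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(10, 2)  # stand-in for a real BERT model
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)

for step in range(1000):
    loss = model(torch.randn(8, 10)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # in this code base the schedule is stepped once per batch
    optimizer.zero_grad()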
-------------------------------------------------------------------------------- /transformers/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | parser.addoption( 11 | "--use_cuda", action="store_true", default=False, help="run tests on gpu" 12 | ) 13 | 14 | 15 | def pytest_configure(config): 16 | config.addinivalue_line("markers", "slow: mark test as slow to run") 17 | 18 | 19 | def pytest_collection_modifyitems(config, items): 20 | if config.getoption("--runslow"): 21 | # --runslow given in cli: do not skip slow tests 22 | return 23 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 24 | for item in items: 25 | if "slow" in item.keywords: 26 | item.add_marker(skip_slow) 27 | 28 | @pytest.fixture 29 | def use_cuda(request): 30 | """ Run test on gpu """ 31 | return request.config.getoption("--use_cuda") 32 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuvuumass/task-transferability/88ac7e11b7d2befb6e049d1276f275c8a23ae3a0/transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from transformers import is_torch_available 25 | 26 | if is_torch_available(): 27 | from transformers import (AutoConfig, BertConfig, 28 | AutoModel, BertModel, 29 | AutoModelWithLMHead, BertForMaskedLM, 30 | AutoModelForSequenceClassification, BertForSequenceClassification, 31 | AutoModelForQuestionAnswering, BertForQuestionAnswering) 32 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 33 | 34 | from .modeling_common_test import (CommonTestCases, ids_tensor) 35 | from .configuration_common_test import ConfigTester 36 | else: 37 | pytestmark = pytest.mark.skip("Require Torch") 38 | 39 | 40 | class AutoModelTest(unittest.TestCase): 41 | @pytest.mark.slow 42 | def test_model_from_pretrained(self): 43 | logging.basicConfig(level=logging.INFO) 44 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 45 | config = AutoConfig.from_pretrained(model_name) 46 | self.assertIsNotNone(config) 47 | self.assertIsInstance(config, BertConfig) 48 | 49 | model = AutoModel.from_pretrained(model_name) 50 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 51 | self.assertIsNotNone(model) 52 | self.assertIsInstance(model, BertModel) 53 | for value in loading_info.values(): 54 | self.assertEqual(len(value), 0) 55 | 56 | @pytest.mark.slow 57 | def test_lmhead_model_from_pretrained(self): 58 | logging.basicConfig(level=logging.INFO) 59 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 60 | config = AutoConfig.from_pretrained(model_name) 61 | self.assertIsNotNone(config) 62 | self.assertIsInstance(config, BertConfig) 63 | 64 | model = AutoModelWithLMHead.from_pretrained(model_name) 65 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True) 66 | self.assertIsNotNone(model) 67 | self.assertIsInstance(model, BertForMaskedLM) 68 | 69 | @pytest.mark.slow 70 | def test_sequence_classification_model_from_pretrained(self): 71 | logging.basicConfig(level=logging.INFO) 72 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 73 | config = AutoConfig.from_pretrained(model_name) 74 | self.assertIsNotNone(config) 75 | self.assertIsInstance(config, BertConfig) 76 | 77 | model = AutoModelForSequenceClassification.from_pretrained(model_name) 78 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True) 79 | self.assertIsNotNone(model) 80 | self.assertIsInstance(model, BertForSequenceClassification) 81 | 82 | @pytest.mark.slow 83 | def test_question_answering_model_from_pretrained(self): 84 | logging.basicConfig(level=logging.INFO) 85 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 86 | config = AutoConfig.from_pretrained(model_name) 87 | self.assertIsNotNone(config) 88 | self.assertIsInstance(config, BertConfig) 89 | 90 | model = 
AutoModelForQuestionAnswering.from_pretrained(model_name) 91 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True) 92 | self.assertIsNotNone(model) 93 | self.assertIsInstance(model, BertForQuestionAnswering) 94 | 95 | 96 | if __name__ == "__main__": 97 | unittest.main() 98 | -------------------------------------------------------------------------------- /transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import os 21 | import pytest 22 | 23 | from transformers import is_torch_available 24 | 25 | if is_torch_available(): 26 | import torch 27 | 28 | from transformers import (AdamW, 29 | get_constant_schedule, 30 | get_constant_schedule_with_warmup, 31 | get_cosine_schedule_with_warmup, 32 | get_cosine_with_hard_restarts_schedule_with_warmup, 33 | get_linear_schedule_with_warmup) 34 | else: 35 | pytestmark = pytest.mark.skip("Require Torch") 36 | 37 | from .tokenization_tests_commons import TemporaryDirectory 38 | 39 | 40 | def unwrap_schedule(scheduler, num_steps=10): 41 | lrs = [] 42 | for _ in range(num_steps): 43 | scheduler.step() 44 | lrs.append(scheduler.get_lr()) 45 | return lrs 46 | 47 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10): 48 | lrs = [] 49 | for step in range(num_steps): 50 | scheduler.step() 51 | lrs.append(scheduler.get_lr()) 52 | if step == num_steps // 2: 53 | with TemporaryDirectory() as tmpdirname: 54 | file_name = os.path.join(tmpdirname, 'schedule.bin') 55 | torch.save(scheduler.state_dict(), file_name) 56 | 57 | state_dict = torch.load(file_name) 58 | scheduler.load_state_dict(state_dict) 59 | return lrs 60 | 61 | class OptimizationTest(unittest.TestCase): 62 | 63 | def assertListAlmostEqual(self, list1, list2, tol): 64 | self.assertEqual(len(list1), len(list2)) 65 | for a, b in zip(list1, list2): 66 | self.assertAlmostEqual(a, b, delta=tol) 67 | 68 | def test_adam_w(self): 69 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 70 | target = torch.tensor([0.4, 0.2, -0.5]) 71 | criterion = torch.nn.MSELoss() 72 | # No warmup, constant schedule, no gradient clipping 73 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 74 | for _ in range(100): 75 | loss = criterion(w, target) 76 | loss.backward() 77 | optimizer.step() 78 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 79 | w.grad.zero_() 80 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 81 | 82 | 83 | class ScheduleInitTest(unittest.TestCase): 84 | m = torch.nn.Linear(50, 50) if is_torch_available() else None 85 | optimizer = AdamW(m.parameters(), lr=10.) 
if is_torch_available() else None 86 | num_steps = 10 87 | 88 | def assertListAlmostEqual(self, list1, list2, tol): 89 | self.assertEqual(len(list1), len(list2)) 90 | for a, b in zip(list1, list2): 91 | self.assertAlmostEqual(a, b, delta=tol) 92 | 93 | def test_constant_scheduler(self): 94 | scheduler = get_constant_schedule(self.optimizer) 95 | lrs = unwrap_schedule(scheduler, self.num_steps) 96 | expected_learning_rates = [10.] * self.num_steps 97 | self.assertEqual(len(lrs[0]), 1) 98 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 99 | 100 | scheduler = get_constant_schedule(self.optimizer) 101 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 102 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 103 | 104 | def test_warmup_constant_scheduler(self): 105 | scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) 106 | lrs = unwrap_schedule(scheduler, self.num_steps) 107 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 108 | self.assertEqual(len(lrs[0]), 1) 109 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 110 | 111 | scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) 112 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 113 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 114 | 115 | def test_warmup_linear_scheduler(self): 116 | scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 117 | lrs = unwrap_schedule(scheduler, self.num_steps) 118 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 119 | self.assertEqual(len(lrs[0]), 1) 120 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 121 | 122 | scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 123 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 124 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 125 | 126 | def test_warmup_cosine_scheduler(self): 127 | scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 128 | lrs = unwrap_schedule(scheduler, self.num_steps) 129 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 130 | self.assertEqual(len(lrs[0]), 1) 131 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 132 | 133 | scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 134 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 135 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 136 | 137 | def test_warmup_cosine_hard_restart_scheduler(self): 138 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) 139 | lrs = unwrap_schedule(scheduler, self.num_steps) 140 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 141 | self.assertEqual(len(lrs[0]), 1) 142 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 143 | 144 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) 145 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 146 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 147 | 148 | 
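The save/reload pattern these tests exercise is the standard PyTorch state_dict protocol, so any of the schedules above can be checkpointed mid-training. A minimal sketch, assuming `scheduler` is one of the LambdaLR schedules and with an illustrative file name:

torch.save(scheduler.state_dict(), "schedule.bin")
# ... later, after rebuilding the optimizer and scheduler exactly as before ...
scheduler.load_state_dict(torch.load("schedule.bin"))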
149 | if __name__ == "__main__":
150 |     unittest.main()
151 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_auto_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import shutil
21 | import pytest
22 | import logging
23 |
24 | from transformers import AutoTokenizer, BertTokenizer, GPT2Tokenizer
25 | from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
26 |
27 |
28 | class AutoTokenizerTest(unittest.TestCase):
29 |     @pytest.mark.slow
30 |     def test_tokenizer_from_pretrained(self):
31 |         logging.basicConfig(level=logging.INFO)
32 |         for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
33 |             tokenizer = AutoTokenizer.from_pretrained(model_name)
34 |             self.assertIsNotNone(tokenizer)
35 |             self.assertIsInstance(tokenizer, BertTokenizer)
36 |             self.assertGreater(len(tokenizer), 0)
37 |
38 |         for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
39 |             tokenizer = AutoTokenizer.from_pretrained(model_name)
40 |             self.assertIsNotNone(tokenizer)
41 |             self.assertIsInstance(tokenizer, GPT2Tokenizer)
42 |             self.assertGreater(len(tokenizer), 0)
43 |
44 |
45 | if __name__ == "__main__":
46 |     unittest.main()
47 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_bert_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
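In plain usage, the dispatch exercised by the AutoTokenizer test above looks like this; network access is assumed, 'bert-base-uncased' is the standard shortcut name, and the token ids shown hold for its vocabulary:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # name contains 'bert' -> BertTokenizer
ids = tokenizer.encode("Hello, world!", add_special_tokens=True)
assert ids[0] == 101 and ids[-1] == 102  # [CLS] ... [SEP] in the bert-base-uncased vocab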
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | from io import open 21 | 22 | from transformers.tokenization_bert import (BasicTokenizer, 23 | BertTokenizer, 24 | WordpieceTokenizer, 25 | _is_control, _is_punctuation, 26 | _is_whitespace, VOCAB_FILES_NAMES) 27 | 28 | from .tokenization_tests_commons import CommonTestCases 29 | 30 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 31 | 32 | tokenizer_class = BertTokenizer 33 | 34 | def setUp(self): 35 | super(BertTokenizationTest, self).setUp() 36 | 37 | vocab_tokens = [ 38 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 39 | "##ing", ",", "low", "lowest", 40 | ] 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 43 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 44 | 45 | def get_tokenizer(self, **kwargs): 46 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 47 | 48 | def get_input_output_texts(self): 49 | input_text = u"UNwant\u00E9d,running" 50 | output_text = u"unwanted, running" 51 | return input_text, output_text 52 | 53 | def test_full_tokenizer(self): 54 | tokenizer = self.tokenizer_class(self.vocab_file) 55 | 56 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 57 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 58 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 59 | 60 | def test_chinese(self): 61 | tokenizer = BasicTokenizer() 62 | 63 | self.assertListEqual( 64 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 65 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 66 | 67 | def test_basic_tokenizer_lower(self): 68 | tokenizer = BasicTokenizer(do_lower_case=True) 69 | 70 | self.assertListEqual( 71 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 72 | ["hello", "!", "how", "are", "you", "?"]) 73 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 74 | 75 | def test_basic_tokenizer_no_lower(self): 76 | tokenizer = BasicTokenizer(do_lower_case=False) 77 | 78 | self.assertListEqual( 79 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 80 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 81 | 82 | def test_wordpiece_tokenizer(self): 83 | vocab_tokens = [ 84 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 85 | "##ing" 86 | ] 87 | 88 | vocab = {} 89 | for (i, token) in enumerate(vocab_tokens): 90 | vocab[token] = i 91 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 92 | 93 | self.assertListEqual(tokenizer.tokenize(""), []) 94 | 95 | self.assertListEqual( 96 | tokenizer.tokenize("unwanted running"), 97 | ["un", "##want", "##ed", "runn", "##ing"]) 98 | 99 | self.assertListEqual( 100 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 101 | 102 | def test_is_whitespace(self): 103 | self.assertTrue(_is_whitespace(u" ")) 104 | self.assertTrue(_is_whitespace(u"\t")) 105 | self.assertTrue(_is_whitespace(u"\r")) 106 | self.assertTrue(_is_whitespace(u"\n")) 107 | self.assertTrue(_is_whitespace(u"\u00A0")) 108 | 109 | self.assertFalse(_is_whitespace(u"A")) 110 | self.assertFalse(_is_whitespace(u"-")) 111 | 112 | def test_is_control(self): 113 | self.assertTrue(_is_control(u"\u0005")) 114 | 115 | self.assertFalse(_is_control(u"A")) 116 | self.assertFalse(_is_control(u" ")) 117 | self.assertFalse(_is_control(u"\t")) 118 | self.assertFalse(_is_control(u"\r")) 119 | 120 | def test_is_punctuation(self): 121 | self.assertTrue(_is_punctuation(u"-")) 122 | self.assertTrue(_is_punctuation(u"$")) 123 | self.assertTrue(_is_punctuation(u"`")) 124 | self.assertTrue(_is_punctuation(u".")) 125 | 126 | self.assertFalse(_is_punctuation(u"A")) 127 | self.assertFalse(_is_punctuation(u" ")) 128 | 129 | @pytest.mark.slow 130 | def test_sequence_builders(self): 131 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") 132 | 133 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 134 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 135 | 136 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 137 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 138 | 139 | assert encoded_sentence == [101] + text + [102] 140 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 141 | 142 | if __name__ == '__main__': 143 | unittest.main() 144 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
/transformers/tokenization_auto.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class. """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .tokenization_bert import BertTokenizer

logger = logging.getLogger(__name__)

class AutoTokenizer(object):
    r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
    that will be instantiated as one of the tokenizer classes of the library
    when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
    class method.

    The `from_pretrained()` method takes care of returning the correct tokenizer class instance
    using pattern matching on the `pretrained_model_name_or_path` string.

    The tokenizer class to instantiate is selected by the first pattern that matches
    the `pretrained_model_name_or_path` string (in the following order):
        - contains `bert`: BertTokenizer (Bert model)

    This class cannot be instantiated using `__init__()` (it throws an error).
    """
    def __init__(self):
        raise EnvironmentError("AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r""" Instantiate one of the tokenizer classes of the library
        from a pre-trained model vocabulary.

        The tokenizer class to instantiate is selected by the first pattern that matches
        the `pretrained_model_name_or_path` string (in the following order):
            - contains `bert`: BertTokenizer (Bert model)

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

            cache_dir: (`optional`) string:
                Path to a directory in which the downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the vocabulary files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')        # Download vocabulary from S3 and cache.
            tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')  # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`

        """
        if 'bert' in pretrained_model_name_or_path:
            return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        raise ValueError("Unrecognized model identifier in {}. Should contain "
                         "'bert'.".format(pretrained_model_name_or_path))
--------------------------------------------------------------------------------
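A short usage sketch (editorial; the non-BERT identifier is only there to trigger the error path) of the dispatch above: any name containing "bert" routes to BertTokenizer, and everything else is rejected.

from transformers.tokenization_auto import AutoTokenizer

# Contains "bert", so this returns a BertTokenizer instance.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# This pruned AutoTokenizer only knows BERT; other identifiers raise ValueError.
try:
    AutoTokenizer.from_pretrained('gpt2')
except ValueError as err:
    print(err)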
""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import logging 21 | import os 22 | from io import open 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class InputExample(object): 28 | """A single training/test example for token classification.""" 29 | 30 | def __init__(self, guid, words, labels): 31 | """Constructs a InputExample. 32 | Args: 33 | guid: Unique id for the example. 34 | words: list. The words of the sequence. 35 | labels: (Optional) list. The labels for each word of the sequence. This should be 36 | specified for train and dev examples, but not for test examples. 37 | """ 38 | self.guid = guid 39 | self.words = words 40 | self.labels = labels 41 | 42 | 43 | class InputFeatures(object): 44 | """A single set of features of data.""" 45 | 46 | def __init__(self, input_ids, input_mask, segment_ids, label_ids): 47 | self.input_ids = input_ids 48 | self.input_mask = input_mask 49 | self.segment_ids = segment_ids 50 | self.label_ids = label_ids 51 | 52 | 53 | def read_examples_from_file(data_dir, mode): 54 | file_path = os.path.join(data_dir, "{}.txt".format(mode)) 55 | guid_index = 1 56 | examples = [] 57 | with open(file_path, encoding="utf-8") as f: 58 | words = [] 59 | labels = [] 60 | for line in f: 61 | if line.startswith("-DOCSTART-") or line == "" or line == "\n": 62 | if words: 63 | examples.append(InputExample(guid="{}-{}".format(mode, guid_index), 64 | words=words, 65 | labels=labels)) 66 | guid_index += 1 67 | words = [] 68 | labels = [] 69 | else: 70 | splits = line.split(" ") 71 | words.append(splits[0]) 72 | if len(splits) > 1: 73 | labels.append(splits[-1].replace("\n", "")) 74 | else: 75 | # Examples could have no label for mode = "test" 76 | labels.append("O") 77 | if words: 78 | examples.append(InputExample(guid="%s-%d".format(mode, guid_index), 79 | words=words, 80 | labels=labels)) 81 | return examples 82 | 83 | 84 | def convert_examples_to_features(examples, 85 | label_list, 86 | max_seq_length, 87 | tokenizer, 88 | cls_token_at_end=False, 89 | cls_token="[CLS]", 90 | cls_token_segment_id=1, 91 | sep_token="[SEP]", 92 | sep_token_extra=False, 93 | pad_on_left=False, 94 | pad_token=0, 95 | pad_token_segment_id=0, 96 | pad_token_label_id=-1, 97 | sequence_a_segment_id=0, 98 | mask_padding_with_zero=True): 99 | """ Loads a data file into a list of `InputBatch`s 100 | `cls_token_at_end` define the location of the CLS token: 101 | - False: [CLS] + A + [SEP] + B + [SEP] 102 | - True: A + [SEP] + B + [SEP] + [CLS] 103 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT) 104 | """ 105 | 106 | label_map = {label: i for i, label in enumerate(label_list)} 107 | 108 | features = [] 109 | for (ex_index, example) in enumerate(examples): 110 | if ex_index % 10000 == 0: 111 | logger.info("Writing example %d of %d", ex_index, len(examples)) 112 | 113 | tokens = [] 114 | label_ids = [] 115 | for word, label in zip(example.words, example.labels): 116 | word_tokens = tokenizer.tokenize(word) 117 | tokens.extend(word_tokens) 118 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 119 | label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)) 120 | 121 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 

def convert_examples_to_features(examples,
                                 label_list,
                                 max_seq_length,
                                 tokenizer,
                                 cls_token_at_end=False,
                                 cls_token="[CLS]",
                                 cls_token_segment_id=1,
                                 sep_token="[SEP]",
                                 sep_token_extra=False,
                                 pad_on_left=False,
                                 pad_token=0,
                                 pad_token_segment_id=0,
                                 pad_token_label_id=-1,
                                 sequence_a_segment_id=0,
                                 mask_padding_with_zero=True):
    """ Loads a data file into a list of `InputFeatures`
        `cls_token_at_end` defines the location of the CLS token:
            - False: [CLS] + A + [SEP] + B + [SEP]
            - True: A + [SEP] + B + [SEP] + [CLS]
        `cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT)
    """

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        tokens = []
        label_ids = []
        for word, label in zip(example.words, example.labels):
            word_tokens = tokenizer.tokenize(word)
            tokens.extend(word_tokens)
            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
            label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
        special_tokens_count = 3 if sep_token_extra else 2
        if len(tokens) > max_seq_length - special_tokens_count:
            tokens = tokens[:(max_seq_length - special_tokens_count)]
            label_ids = label_ids[:(max_seq_length - special_tokens_count)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0   0   0   0  0     0   0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens += [sep_token]
        label_ids += [pad_token_label_id]
        if sep_token_extra:
            # roberta uses an extra separator b/w pairs of sentences
            tokens += [sep_token]
            label_ids += [pad_token_label_id]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if cls_token_at_end:
            tokens += [cls_token]
            label_ids += [pad_token_label_id]
            segment_ids += [cls_token_segment_id]
        else:
            tokens = [cls_token] + tokens
            label_ids = [pad_token_label_id] + label_ids
            segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            label_ids = ([pad_token_label_id] * padding_length) + label_ids
        else:
            input_ids += ([pad_token] * padding_length)
            input_mask += ([0 if mask_padding_with_zero else 1] * padding_length)
            segment_ids += ([pad_token_segment_id] * padding_length)
            label_ids += ([pad_token_label_id] * padding_length)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
            logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_ids=label_ids))
    return features


def get_labels(path):
    if path:
        with open(path, "r") as f:
            labels = f.read().splitlines()
        return labels
    else:
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
--------------------------------------------------------------------------------
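A minimal end-to-end sketch of convert_examples_to_features (editorial; the one-wordpiece-per-word tokenizer stub is hypothetical and only stands in for BertTokenizer), showing where the [CLS]/[SEP] label padding and the attention mask come from:

from utils_ner import InputExample, convert_examples_to_features

class OneWordpieceTokenizer(object):
    """Stub tokenizer: every word becomes a single 'wordpiece'; ids are fake."""
    def tokenize(self, word):
        return [word.lower()]
    def convert_tokens_to_ids(self, tokens):
        return [abs(hash(t)) % 1000 for t in tokens]

features = convert_examples_to_features(
    examples=[InputExample(guid="train-1", words=["EU", "rejects"], labels=["B-ORG", "O"])],
    label_list=["O", "B-ORG"],
    max_seq_length=8,
    tokenizer=OneWordpieceTokenizer(),
    cls_token_segment_id=0)  # 0 for BERT, per the docstring

f = features[0]
# Layout: [CLS] eu rejects [SEP] plus 4 padding positions.
assert f.input_mask == [1, 1, 1, 1, 0, 0, 0, 0]
# label_ids carry -1 on [CLS], [SEP] and padding; real ids elsewhere ("B-ORG"=1, "O"=0).
assert f.label_ids == [-1, 1, 0, -1, -1, -1, -1, -1]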
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | class EVAL_OPTS(): 18 | def __init__(self, data_file, pred_file, out_file="", 19 | na_prob_file="na_prob.json", na_prob_thresh=1.0, 20 | out_image_dir=None, verbose=False): 21 | self.data_file = data_file 22 | self.pred_file = pred_file 23 | self.out_file = out_file 24 | self.na_prob_file = na_prob_file 25 | self.na_prob_thresh = na_prob_thresh 26 | self.out_image_dir = out_image_dir 27 | self.verbose = verbose 28 | 29 | OPTS = None 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 33 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 34 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 35 | parser.add_argument('--out-file', '-o', metavar='eval.json', 36 | help='Write accuracy metrics to file (default is stdout).') 37 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 38 | help='Model estimates of probability of no answer.') 39 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 40 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 41 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 42 | help='Save precision-recall curves to directory.') 43 | parser.add_argument('--verbose', '-v', action='store_true') 44 | if len(sys.argv) == 1: 45 | parser.print_help() 46 | sys.exit(1) 47 | return parser.parse_args() 48 | 49 | def make_qid_to_has_ans(dataset): 50 | qid_to_has_ans = {} 51 | for article in dataset: 52 | for p in article['paragraphs']: 53 | for qa in p['qas']: 54 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 55 | return qid_to_has_ans 56 | 57 | def normalize_answer(s): 58 | """Lower text and remove punctuation, articles and extra whitespace.""" 59 | def remove_articles(text): 60 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 61 | return re.sub(regex, ' ', text) 62 | def white_space_fix(text): 63 | return ' '.join(text.split()) 64 | def remove_punc(text): 65 | exclude = set(string.punctuation) 66 | return ''.join(ch for ch in text if ch not in exclude) 67 | def lower(text): 68 | return text.lower() 69 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 70 | 71 | def get_tokens(s): 72 | if not s: return [] 73 | return normalize_answer(s).split() 74 | 75 | def compute_exact(a_gold, a_pred): 76 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 77 | 78 | def compute_f1(a_gold, a_pred): 79 | gold_toks = get_tokens(a_gold) 80 | pred_toks = get_tokens(a_pred) 81 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 82 | num_same = sum(common.values()) 83 | if len(gold_toks) == 0 or len(pred_toks) == 0: 84 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 85 | return int(gold_toks == pred_toks) 86 | if num_same == 0: 87 | return 0 88 | precision = 1.0 * num_same / len(pred_toks) 89 | recall = 1.0 * num_same / len(gold_toks) 90 | f1 = (2 * precision * recall) / (precision + recall) 91 | return f1 92 | 93 | def get_raw_scores(dataset, preds): 94 | exact_scores = {} 95 | f1_scores = {} 96 | for article in dataset: 97 | for p in article['paragraphs']: 98 | for qa in p['qas']: 99 | qid = qa['id'] 100 | gold_answers = [a['text'] for a in qa['answers'] 101 | if normalize_answer(a['text'])] 102 | if not gold_answers: 103 | # For 

def get_raw_scores(dataset, preds):
    exact_scores = {}
    f1_scores = {}
    for article in dataset:
        for p in article['paragraphs']:
            for qa in p['qas']:
                qid = qa['id']
                gold_answers = [a['text'] for a in qa['answers']
                                if normalize_answer(a['text'])]
                if not gold_answers:
                    # For unanswerable questions, the only correct answer is the empty string
                    gold_answers = ['']
                if qid not in preds:
                    print('Missing prediction for %s' % qid)
                    continue
                a_pred = preds[qid]
                # Take max over all gold answers
                exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
                f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
    return exact_scores, f1_scores

def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores

def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores.values()) / total),
            ('f1', 100.0 * sum(f1_scores.values()) / total),
            ('total', total),
        ])
    else:
        total = len(qid_list)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores[k] for k in qid_list if k in exact_scores) / total),
            ('f1', 100.0 * sum(f1_scores[k] for k in qid_list if k in f1_scores) / total),
            ('total', total),
        ])
        # return collections.OrderedDict([
        #     ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        #     ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        #     ('total', total),
        # ])

def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval['%s_%s' % (prefix, k)] = new_eval[k]

def plot_pr_curve(precisions, recalls, out_image, title):
    plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
    plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.title(title)
    plt.savefig(out_image)
    plt.clf()

def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
                               out_image=None, title=None):
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    true_pos = 0.0
    cur_p = 1.0
    cur_r = 0.0
    precisions = [1.0]
    recalls = [0.0]
    avg_prec = 0.0
    for i, qid in enumerate(qid_list):
        if qid_to_has_ans[qid]:
            true_pos += scores[qid]
        cur_p = true_pos / float(i+1)
        cur_r = true_pos / float(num_true_pos)
        if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
            # i.e., if we can put a threshold after this point
            avg_prec += cur_p * (cur_r - recalls[-1])
            precisions.append(cur_p)
            recalls.append(cur_r)
    if out_image:
        plot_pr_curve(precisions, recalls, out_image, title)
    return {'ap': 100.0 * avg_prec}
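
# A toy walkthrough (editorial sketch) of the average-precision sum above:
# qids sorted by ascending na_prob give a sweep over candidate thresholds.
# With three questions, all answerable (num_true_pos = 3) and scores
# {"q1": 1, "q2": 0, "q3": 1} in na_prob order (distinct na_prob values):
#   i=0: true_pos=1, p=1/1, r=1/3 -> avg_prec += 1.0 * (1/3 - 0)   = 1/3
#   i=1: true_pos=1, p=1/2, r=1/3 -> avg_prec += 0.5 * (1/3 - 1/3) = 0
#   i=2: true_pos=2, p=2/3, r=2/3 -> avg_prec += 2/3 * (2/3 - 1/3) = 2/9
# giving ap = 100 * (1/3 + 2/9) = 55.6.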

def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
                                  qid_to_has_ans, out_image_dir):
    if out_image_dir and not os.path.exists(out_image_dir):
        os.makedirs(out_image_dir)
    num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
    if num_true_pos == 0:
        return
    pr_exact = make_precision_recall_eval(
        exact_raw, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_exact.png'),
        title='Precision-Recall curve for Exact Match score')
    pr_f1 = make_precision_recall_eval(
        f1_raw, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_f1.png'),
        title='Precision-Recall curve for F1 score')
    oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
    pr_oracle = make_precision_recall_eval(
        oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
        title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
    merge_eval(main_eval, pr_exact, 'pr_exact')
    merge_eval(main_eval, pr_f1, 'pr_f1')
    merge_eval(main_eval, pr_oracle, 'pr_oracle')

def histogram_na_prob(na_probs, qid_list, image_dir, name):
    if not qid_list:
        return
    x = [na_probs[k] for k in qid_list]
    weights = np.ones_like(x) / float(len(x))
    plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
    plt.xlabel('Model probability of no-answer')
    plt.ylabel('Proportion of dataset')
    plt.title('Histogram of no-answer probability: %s' % name)
    plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
    plt.clf()

def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for i, qid in enumerate(qid_list):
        if qid not in scores: continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh

def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for i, qid in enumerate(qid_list):
        if qid not in scores: continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]

    has_ans_score, has_ans_cnt = 0, 0
    for qid in qid_list:
        if not qid_to_has_ans[qid]: continue
        has_ans_cnt += 1

        if qid not in scores: continue
        has_ans_score += scores[qid]

    return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt

def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh
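
# A toy walkthrough (editorial sketch) of the sweep in find_best_thresh above:
# q1 is answerable with exact score 1; q2 is unanswerable with a non-empty prediction.
#   scores   = {"q1": 1, "q2": 0}
#   na_probs = {"q1": 0.1, "q2": 0.8}
#   preds    = {"q1": "foo", "q2": "bar"}
# Start from cur_score = num_no_ans = 1 (the "blank everything" baseline).
#   q1 (na_prob 0.1): has an answer, diff = scores["q1"] = 1 -> cur_score = 2,
#       a new best, so best_thresh = 0.1 (answer q1, blank q2: both correct).
#   q2 (na_prob 0.8): no answer but preds["q2"] is non-empty, diff = -1 -> cur_score = 1.
# Returns (100.0 * 2 / 2, 0.1) = (100.0, 0.1).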

def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh
    main_eval['has_ans_exact'] = has_ans_exact
    main_eval['has_ans_f1'] = has_ans_f1

def main(OPTS):
    with open(OPTS.data_file) as f:
        dataset_json = json.load(f)
        dataset = dataset_json['data']
    with open(OPTS.pred_file) as f:
        preds = json.load(f)
    if OPTS.na_prob_file:
        with open(OPTS.na_prob_file) as f:
            na_probs = json.load(f)
    else:
        na_probs = {k: 0.0 for k in preds}
    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    exact_raw, f1_raw = get_raw_scores(dataset, preds)
    exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
                                          OPTS.na_prob_thresh)
    f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
                                       OPTS.na_prob_thresh)
    out_eval = make_eval_dict(exact_thresh, f1_thresh)
    if has_ans_qids:
        has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
        merge_eval(out_eval, has_ans_eval, 'HasAns')
    if no_ans_qids:
        no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
        merge_eval(out_eval, no_ans_eval, 'NoAns')
    if OPTS.na_prob_file:
        find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
    if OPTS.na_prob_file and OPTS.out_image_dir:
        run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
                                      qid_to_has_ans, OPTS.out_image_dir)
        histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
        histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
    if OPTS.out_file:
        with open(OPTS.out_file, 'w') as f:
            json.dump(out_eval, f)
    else:
        print(json.dumps(out_eval, indent=2))
    return out_eval

if __name__ == '__main__':
    OPTS = parse_args()
    if OPTS.out_image_dir:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    main(OPTS)
--------------------------------------------------------------------------------
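A hedged usage sketch (file names hypothetical): besides the argparse command line, the evaluator can be driven programmatically through EVAL_OPTS, as a training script might do after writing its prediction files.

from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad

opts = EVAL_OPTS(data_file="dev-v2.0.json",      # SQuAD 2.0 dev set
                 pred_file="predictions.json",   # qid -> predicted answer text
                 na_prob_file="na_prob.json")    # qid -> P(unanswerable)
results = evaluate_on_squad(opts)                # also prints the metrics dict
print(results["exact"], results["f1"], results["best_f1_thresh"])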