├── README.md ├── extract_ontonotes_all.py ├── modeling_gated_gpt2.py ├── modeling_gpt2_dp.py ├── modeling_gpt2_pp.py ├── requirements.txt ├── run_clm.py ├── run_dp.py ├── run_pp.py ├── trainer_pp.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # probing-via-prompting 2 | This repository accompanies the paper: [Probing via Prompting](https://arxiv.org/abs/2207.01736). 3 | 4 | ## Dependencies 5 | - python 3.8.5 6 | - pytorch 1.7.1+cu110 7 | 8 | ## Setup 9 | Install required packages: 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Data Processing 15 | 1. Process your OntoNotes data with the [script](https://github.com/yuchenlin/OntoNotes-5.0-NER-BIO) 16 | 2. Extract all tasks: 17 | ``` 18 | python extract_ontonotes_all.py --ontonotes /path/to/conll-formatted-ontonotes-5.0 -o ontonotes 19 | ``` 20 | This will create two folders under `ontonotes/`, one for diagnostic probing (DP) and one for probing via prompting (PP). 21 | 22 | ## Probing via Prompting 23 | ``` 24 | export task= 25 | python run_pp.py \ 26 | --num_train_epochs 1.0 \ 27 | --do_train \ 28 | --do_eval \ 29 | --per_device_train_batch_size 4 \ 30 | --per_device_eval_batch_size 4 \ 31 | --gpt2_name_or_path gpt2 \ 32 | --data_dir ontonotes/pp/ \ 33 | --task $task \ 34 | --output_dir outputs/pp/$task/ \ 35 | --overwrite_output_dir \ 36 | --use_fast_tokenizer False \ 37 | --cache_dir cache/ \ 38 | --save_strategy no \ 39 | --prefix_len 200 40 | ``` 41 | `task` can be any one of `["pos", "const", "coref", "ner", "srl", "pos_control"]`. 42 | 43 | If you want to experiment with a randomly initialized model, replace `--gpt2_name_or_path gpt2` with 44 | ``` 45 | --config_name gpt2 \ 46 | --tokenizer_name gpt2 \ 47 | ``` 48 | 49 | To prune attention heads for analysis, use `--do_prune`: 50 | ``` 51 | export task= 52 | python run_pp.py \ 53 | --num_train_epochs 1.0 \ 54 | --do_train \ 55 | --do_eval \ 56 | --per_device_train_batch_size 4 \ 57 | --per_device_eval_batch_size 4 \ 58 | --gpt2_name_or_path gpt2 \ 59 | --data_dir ontonotes/pp/ \ 60 | --task $task \ 61 | --output_dir outputs/pp/pruning/$task/ \ 62 | --overwrite_output_dir \ 63 | --use_fast_tokenizer False \ 64 | --cache_dir cache/ \ 65 | --save_strategy no \ 66 | --prefix_len 200 \ 67 | --do_prune \ 68 | --num_of_heads 96 \ 69 | --pruning_lr 0.1 \ 70 | --seed 0 71 | ``` 72 | 73 | ## Diagnostic Probing 74 | Multi-layer perceptron (MLP) probe: 75 | ``` 76 | export task= 77 | python run_dp.py \ 78 | --num_train_epochs 1.0 \ 79 | --do_train \ 80 | --do_eval \ 81 | --per_device_train_batch_size 32 \ 82 | --per_device_eval_batch_size 32 \ 83 | --gpt2_name_or_path gpt2 \ 84 | --data_dir ontonotes/dp/ \ 85 | --task $task \ 86 | --output_dir outputs/dp/mlp/$task/ \ 87 | --overwrite_output_dir \ 88 | --cache_dir cache/ \ 89 | --save_strategy no 90 | ``` 91 | Please note that DP (MLP) does not support multiple GPUs due to the incompatibility between `nn.ParameterList` in AllenNLP's `ScalarMix` and `DataParallel`.
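For reference, extraction writes one JSON object per line under `ontonotes/dp/<task>/` and `ontonotes/pp/<task>/`. A DP record pairs a sentence with labeled spans, while a PP pattern is a fully serialized prompt ending in `<|endoftext|>` followed by the label's verbalization (see `extract_ontonotes_all.py` below; verbalizations come from `LABEL_DICT` in `utils.py`). A minimal, illustrative sketch of inspecting one example of each:
```
import json

# Peek at one diagnostic-probing (DP) record.
with open("ontonotes/dp/ner/train.json") as f:
    dp_record = json.loads(next(f))
print(dp_record["text"])  # the raw sentence
for target in dp_record["targets"]:
    # span1 is an inclusive [start, end] token span; span2 appears only for
    # binary tasks such as coref and srl.
    print(target["span1"], target.get("span2"), target["label"])

# Peek at one probing-via-prompting (PP) pattern:
# sentence + target span text(s) + "<|endoftext|>" + label verbalization.
with open("ontonotes/pp/ner/train.json") as f:
    print(json.loads(next(f))["text"])
```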
92 | 93 | You can use linear regression (LR) probe instead by setting `--use_mlp False`: 94 | ``` 95 | export task= 96 | python run_dp.py \ 97 | --num_train_epochs 1.0 \ 98 | --do_train \ 99 | --do_eval \ 100 | --per_device_train_batch_size 32 \ 101 | --per_device_eval_batch_size 32 \ 102 | --gpt2_name_or_path gpt2 \ 103 | --data_dir ontonotes/dp/ \ 104 | --task $task \ 105 | --output_dir outputs/dp/lr/$task/ \ 106 | --overwrite_output_dir \ 107 | --cache_dir cache/\ 108 | --save_strategy no \ 109 | --mlp_dropout 0.1 \ 110 | --use_mlp False 111 | ``` 112 | 113 | DP (LR) also supports head pruning: 114 | ``` 115 | export task= 116 | python run_dp.py \ 117 | --num_train_epochs 1.0 \ 118 | --do_train \ 119 | --do_eval \ 120 | --per_device_train_batch_size 32 \ 121 | --per_device_eval_batch_size 32 \ 122 | --gpt2_name_or_path gpt2 \ 123 | --data_dir ontonotes/dp/ \ 124 | --task $task \ 125 | --output_dir outputs/dp/lr/pruning/$task/ \ 126 | --overwrite_output_dir \ 127 | --cache_dir cache/\ 128 | --save_strategy no \ 129 | --mlp_dropout 0.1 \ 130 | --use_mlp False \ 131 | --do_prune \ 132 | --num_of_heads 96 \ 133 | --pruning_lr 0.1 \ 134 | --seed 0 135 | ``` 136 | 137 | ## Amnesic Probing 138 | To evaluate language modeling loss when the essential heads stored in `/path/to/head_mask` are pruned, run 139 | ``` 140 | python run_clm.py \ 141 | --model_name_or_path gpt2 \ 142 | --dataset_name wikitext \ 143 | --dataset_config_name wikitext-103-raw-v1 \ 144 | --do_eval \ 145 | --output_dir outputs/lm/ \ 146 | --overwrite_output_dir \ 147 | --per_device_eval_batch_size 32 \ 148 | --cache_dir cache/ \ 149 | --head_mask_path /path/to/head_mask 150 | ``` 151 | -------------------------------------------------------------------------------- /extract_ontonotes_all.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import logging as log 4 | import os 5 | import sys 6 | from typing import Dict, List, Tuple 7 | import random 8 | 9 | import numpy as np 10 | from allennlp.data.dataset_readers.dataset_utils import Ontonotes 11 | from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans 12 | from tqdm import tqdm 13 | 14 | from utils import LABEL_DICT 15 | 16 | CONTROL_MAPPING = {} 17 | 18 | def _span_to_string(sentence, span: Tuple[int, int]): 19 | return " ".join(sentence.split(" ")[span[0]:span[1] + 1]) 20 | 21 | def _make_target(label: List[str], span1: Tuple[int, int], span2: Tuple[int, int] = None): 22 | t = {"span1": span1, "label": label} 23 | if span2 is not None: 24 | t["span2"] = span2 25 | return t 26 | 27 | 28 | def make_record(spans, sentence): 29 | record = {} 30 | record["text"] = " ".join(sentence.words) 31 | record["targets"] = [_make_target(*s) for s in spans] 32 | return record 33 | 34 | def constituents_to_record(parse_tree): 35 | """Function converting Tree object to dictionary compatible with common JSON format 36 | copied from ptb_process.py so it doesn't have dependencies 37 | """ 38 | punctuations = ["-LRB-", "-RRB-", "-LCB-", "-RCB-", "-LSB-", "-RSB-"] 39 | 40 | pos_record = {} 41 | pos_record["text"] = " ".join(parse_tree.flatten()) 42 | pos_record["targets"] = [] 43 | 44 | non_record = {} 45 | non_record["text"] = " ".join(parse_tree.flatten()) 46 | non_record["targets"] = [] 47 | 48 | pos_control_record = {} 49 | pos_control_record["text"] = " ".join(parse_tree.flatten()) 50 | pos_control_record["targets"] = [] 51 | labels = list(LABEL_DICT['pos'].keys()) 52 | num_labels = 
len(labels) 53 | 54 | for i, leaf in enumerate(parse_tree.subtrees(lambda t: t.height() == 2)): 55 | # modify the leafs by adding their index in the parse_tree 56 | leaf[0] = (leaf[0], str(i)) 57 | 58 | for subtree in parse_tree.subtrees(): 59 | assoc_words = subtree.leaves() 60 | assoc_words = [(i, int(j)) for i, j in assoc_words] 61 | assoc_words.sort(key=lambda elem: elem[1]) 62 | indices = [int(assoc_words[0][1]), int(assoc_words[-1][1])] 63 | span = " ".join([word[0] for word in assoc_words]) 64 | 65 | tmp_tag_list = subtree.label().replace("=", "-").replace("|", "-").split("-") 66 | label = tmp_tag_list[0] 67 | # Special cases: 68 | if len(tmp_tag_list) > 1 and tmp_tag_list[1] == "S": # Case when we have 'PRP-S' or 'WP-S' 69 | label = tmp_tag_list[0] + "-" + tmp_tag_list[1] 70 | if ( 71 | subtree.label() in punctuations 72 | ): # Case when we have one of the strange punctions, such as round brackets 73 | label = subtree.label() 74 | target = {"span1": indices, "label": label} 75 | 76 | if subtree.height() == 2: 77 | pos_record["targets"].append(target) 78 | 79 | if span not in CONTROL_MAPPING: 80 | CONTROL_MAPPING[span] = labels[random.randint(0, num_labels - 1)] 81 | control_label = CONTROL_MAPPING[span] 82 | control_target = {"span1": indices, "label": control_label} 83 | 84 | pos_control_record["targets"].append(control_target) 85 | 86 | elif subtree.height() > 2: 87 | non_record['targets'].append(target) 88 | 89 | return pos_record, pos_control_record, non_record 90 | 91 | 92 | def get_frames(sentence): 93 | for frame, bio_tags in sentence.srl_frames: 94 | frame_targets = [] 95 | spans = bio_tags_to_spans(bio_tags) 96 | head_span = None 97 | other_spans = [] 98 | for (tag, indices) in spans: 99 | if tag == "V": 100 | head_span = indices 101 | else: 102 | other_spans.append((tag, indices)) 103 | if head_span is None: 104 | print(frame, bio_tags) 105 | for span2_tag, span2 in other_spans: 106 | frame_targets.append((span2_tag, head_span, span2)) 107 | yield frame_targets 108 | 109 | def find_links(span_list): 110 | pairs = [] 111 | for i, span1 in enumerate(span_list): 112 | for span2 in span_list[i + 1 :]: 113 | pairs.append((str(span1[0] == span2[0]), span1[1], span2[1])) 114 | return pairs 115 | 116 | def process_ontonotes(ontonotes_reader): 117 | records = {} 118 | records['ner'], records['pos'], records['pos_control'], records['const'], records['coref'], records['srl'] = [], [], [], [], [], [] 119 | for sentence in ontonotes_reader: 120 | # NER 121 | spans = bio_tags_to_spans(sentence.named_entities) 122 | if spans: 123 | records['ner'].append(make_record(spans, sentence)) 124 | 125 | # POS and constituent 126 | if sentence.parse_tree is not None: 127 | pos_record, pos_control_record, const_record = constituents_to_record(sentence.parse_tree) 128 | records['pos'].append(pos_record) 129 | records['pos_control'].append(pos_control_record) 130 | records['const'].append(const_record) 131 | 132 | # coreference 133 | spans = find_links(list(sentence.coref_spans)) 134 | if spans: 135 | records['coref'].append(make_record(spans, sentence)) 136 | 137 | # SRL 138 | for frame_spans in get_frames(sentence): 139 | if frame_spans: 140 | records['srl'].append(make_record(frame_spans, sentence)) 141 | 142 | return records 143 | 144 | def make_patterns(records, label_dict): 145 | patterns = [] 146 | for record in records: 147 | sentence = record['text'] 148 | prompt = f'{sentence}' 149 | for target in record['targets']: 150 | span1 = _span_to_string(sentence, target['span1']) 151 | temp = 
prompt + f"{span1}" 152 | if 'span2' in target: 153 | span2 = _span_to_string(sentence, target['span2']) 154 | temp += f"{span2}" 155 | patterns.append({"text": temp + f"<|endoftext|>{label_dict[target['label']]}"}) 156 | return patterns 157 | 158 | def write_json_data(fname, lines): 159 | with open(fname, 'w') as fd: 160 | for line in lines: 161 | fd.write(json.dumps(line)) 162 | fd.write("\n") 163 | 164 | 165 | def main(args): 166 | import argparse 167 | 168 | parser = argparse.ArgumentParser() 169 | parser.add_argument( 170 | "--ontonotes", 171 | type=str, 172 | default="conll-formatted-ontonotes-5.0", 173 | help="Path to OntoNotes, e.g. /path/to/conll-formatted-ontonotes-5.0", 174 | ) 175 | parser.add_argument( 176 | "--tasks", 177 | default=["pos", "const", "coref", "ner", "srl", "pos_control"], 178 | type=str, nargs="+", help="Tasks, one or more of {pos, const, coref, ner, srl}." 179 | ) 180 | parser.add_argument( 181 | "--splits", 182 | type=str, 183 | nargs="+", 184 | default=["train", "development", "test", "conll-2012-test"], 185 | help="Splits, one or more of {train, development, test, conll-2012-test}.", 186 | ) 187 | parser.add_argument( 188 | "-o", dest="output_dir", type=str, default="ontonotes/", help="Output directory for JSON files." 189 | ) 190 | args = parser.parse_args(args) 191 | 192 | if not os.path.isdir(args.output_dir): 193 | os.mkdir(args.output_dir) 194 | 195 | ontonotes = Ontonotes() 196 | for split in args.splits: 197 | source_path = os.path.join(args.ontonotes, "data", split) 198 | ontonotes_reader = ontonotes.dataset_iterator(file_path=source_path) 199 | converted_records = process_ontonotes(tqdm(ontonotes_reader)) 200 | for task in args.tasks: 201 | pp_task_dir = os.path.join(args.output_dir, "pp", task) 202 | dp_task_dir = os.path.join(args.output_dir, "dp", task) 203 | if not os.path.isdir(pp_task_dir): 204 | os.makedirs(pp_task_dir) 205 | if not os.path.isdir(dp_task_dir): 206 | os.makedirs(dp_task_dir) 207 | 208 | write_json_data(os.path.join(dp_task_dir, f"{split}.json"), converted_records[task]) 209 | 210 | if 'control' in task: 211 | label_dict = LABEL_DICT[task.replace("_control", "")] 212 | else: 213 | label_dict = LABEL_DICT[task] 214 | patterns = make_patterns(converted_records[task], label_dict) 215 | write_json_data(os.path.join(pp_task_dir, f"{split}.json"), patterns) 216 | 217 | 218 | 219 | if __name__ == "__main__": 220 | main(sys.argv[1:]) 221 | sys.exit(0) -------------------------------------------------------------------------------- /modeling_gated_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
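# ---------------------------------------------------------------------------
# Illustrative note: GatedGPT2LMHeadModel (defined below) keeps a learnable
# score per attention head, `self.w` of shape [n_layer, n_head]. After
# `apply_dsp(num_of_heads)` is called (used for the `--do_prune` analyses in
# the README), each forward pass turns those scores into a binary head mask
# via STEFunction from utils.py, and GPT2Attention._attn simply multiplies the
# attention weights by that mask. The toy gate below is a guess at such a
# top-k straight-through mechanism, not the repo's actual STEFunction.
def _toy_topk_gate(scores, k):
    """Keep the k highest-scoring heads; pass gradients straight through."""
    import torch  # local import; the file's own imports appear further down
    hard = torch.zeros_like(scores)
    hard[scores.topk(k).indices] = 1.0      # hard 0/1 mask in the forward pass
    return hard + scores - scores.detach()  # identity gradient w.r.t. scores
# ---------------------------------------------------------------------------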
16 | """PyTorch OpenAI GPT-2 model.""" 17 | 18 | import os 19 | from dataclasses import dataclass 20 | from typing import Optional, Tuple 21 | 22 | import torch 23 | import torch.utils.checkpoint 24 | from torch import nn 25 | from torch.nn import CrossEntropyLoss, MSELoss 26 | 27 | from transformers.activations import ACT2FN 28 | from transformers.file_utils import ( 29 | add_code_sample_docstrings, 30 | add_start_docstrings, 31 | add_start_docstrings_to_model_forward, 32 | ModelOutput 33 | ) 34 | from transformers.modeling_outputs import ( 35 | BaseModelOutputWithPastAndCrossAttentions, 36 | ) 37 | from transformers.modeling_utils import ( 38 | Conv1D, 39 | PreTrainedModel, 40 | find_pruneable_heads_and_indices, 41 | prune_conv1d_layer, 42 | ) 43 | from transformers.utils import logging 44 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 45 | from transformers import GPT2Config 46 | from utils import STEFunction 47 | 48 | logger = logging.get_logger(__name__) 49 | 50 | _CHECKPOINT_FOR_DOC = "gpt2" 51 | _CONFIG_FOR_DOC = "GPT2Config" 52 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer" 53 | 54 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ 55 | "gpt2", 56 | "gpt2-medium", 57 | "gpt2-large", 58 | "gpt2-xl", 59 | "distilgpt2", 60 | # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 61 | ] 62 | 63 | def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): 64 | """Load tf checkpoints in a pytorch model""" 65 | try: 66 | import re 67 | 68 | import tensorflow as tf 69 | except ImportError: 70 | logger.error( 71 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions." 73 | ) 74 | raise 75 | tf_path = os.path.abspath(gpt2_checkpoint_path) 76 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 77 | # Load weights from TF model 78 | init_vars = tf.train.list_variables(tf_path) 79 | names = [] 80 | arrays = [] 81 | for name, shape in init_vars: 82 | logger.info(f"Loading TF weight {name} with shape {shape}") 83 | array = tf.train.load_variable(tf_path, name) 84 | names.append(name) 85 | arrays.append(array.squeeze()) 86 | 87 | for name, array in zip(names, arrays): 88 | name = name[6:] # skip "model/" 89 | name = name.split("/") 90 | pointer = model 91 | for m_name in name: 92 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 93 | scope_names = re.split(r"(\d+)", m_name) 94 | else: 95 | scope_names = [m_name] 96 | if scope_names[0] == "w" or scope_names[0] == "g": 97 | pointer = getattr(pointer, "weight") 98 | elif scope_names[0] == "b": 99 | pointer = getattr(pointer, "bias") 100 | elif scope_names[0] == "wpe" or scope_names[0] == "wte": 101 | pointer = getattr(pointer, scope_names[0]) 102 | pointer = getattr(pointer, "weight") 103 | else: 104 | pointer = getattr(pointer, scope_names[0]) 105 | if len(scope_names) >= 2: 106 | num = int(scope_names[1]) 107 | pointer = pointer[num] 108 | try: 109 | assert ( 110 | pointer.shape == array.shape 111 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 112 | except AssertionError as e: 113 | e.args += (pointer.shape, array.shape) 114 | raise 115 | logger.info(f"Initialize PyTorch weight {name}") 116 | pointer.data = torch.from_numpy(array) 117 | return model 118 | 119 | @dataclass 120 | class ProbingViaPromptingOutputs(ModelOutput): 121 | """ 122 | Base class for causal language model (or autoregressive) outputs. 
123 | 124 | Args: 125 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): 126 | Language modeling loss (for next-token prediction). 127 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): 128 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 129 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): 130 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 131 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 132 | 133 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 134 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 135 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, 136 | sequence_length, sequence_length)`. 137 | 138 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 139 | heads. 140 | cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): 141 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, 142 | sequence_length, sequence_length)`. 143 | 144 | Cross attentions weights after the attention softmax, used to compute the weighted average in the 145 | cross-attention heads. 146 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): 147 | Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the 148 | cached key, value states of the self-attention and the cross-attention layers if model is used in 149 | encoder-decoder setting. Only relevant if ``config.is_decoder = True``. 150 | 151 | Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see 152 | :obj:`past_key_values` input) to speed up sequential decoding. 
153 | """ 154 | 155 | loss: Optional[torch.FloatTensor] = None 156 | accuracy: Optional[torch.FloatTensor] = None 157 | logits: torch.FloatTensor = None 158 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 159 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 160 | attentions: Optional[Tuple[torch.FloatTensor]] = None 161 | cross_attentions: Optional[Tuple[torch.FloatTensor]] = None 162 | 163 | class GPT2Attention(nn.Module): 164 | def __init__(self, config, is_cross_attention=False): 165 | super().__init__() 166 | 167 | max_positions = config.max_position_embeddings 168 | self.register_buffer( 169 | "bias", 170 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( 171 | 1, 1, max_positions, max_positions 172 | ), 173 | ) 174 | self.register_buffer("masked_bias", torch.tensor(-1e4)) 175 | 176 | self.embed_dim = config.hidden_size 177 | self.num_heads = config.num_attention_heads 178 | self.head_dim = self.embed_dim // self.num_heads 179 | self.split_size = self.embed_dim 180 | if self.head_dim * self.num_heads != self.embed_dim: 181 | raise ValueError( 182 | f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 183 | ) 184 | 185 | self.scale_attn_weights = config.scale_attn_weights 186 | self.is_cross_attention = is_cross_attention 187 | 188 | if self.is_cross_attention: 189 | self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) 190 | self.q_attn = Conv1D(self.embed_dim, self.embed_dim) 191 | else: 192 | self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) 193 | self.c_proj = Conv1D(self.embed_dim, self.embed_dim) 194 | 195 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 196 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 197 | 198 | self.pruned_heads = set() 199 | 200 | self.head_mask = None 201 | 202 | def prune_heads(self, heads): 203 | if len(heads) == 0: 204 | return 205 | heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) 206 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 207 | 208 | # Prune conv1d layers 209 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 210 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 211 | 212 | # Update hyper params 213 | self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) 214 | self.num_heads = self.num_heads - len(heads) 215 | self.pruned_heads = self.pruned_heads.union(heads) 216 | 217 | def _attn(self, query, key, value, attention_mask=None): 218 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 219 | 220 | if self.scale_attn_weights: 221 | attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) 222 | 223 | if not self.is_cross_attention: 224 | # if only "normal" attention layer implements causal mask 225 | query_length, key_length = query.size(-2), key.size(-2) 226 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() 227 | attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) 228 | 229 | if attention_mask is not None: 230 | # Apply the attention mask 231 | attn_weights = attn_weights + attention_mask 232 | 233 | attn_weights = nn.Softmax(dim=-1)(attn_weights) 234 | attn_weights = self.attn_dropout(attn_weights) 235 | 236 | # Mask heads if we want to 237 | if self.head_mask is not None: 238 | head_mask = self.head_mask.to(attn_weights.device) 239 | attn_weights = 
attn_weights * head_mask 240 | 241 | attn_output = torch.matmul(attn_weights, value) 242 | 243 | return attn_output, attn_weights 244 | 245 | def _split_heads(self, tensor, num_heads, attn_head_size): 246 | """ 247 | Splits hidden_size dim into attn_head_size and num_heads 248 | """ 249 | new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) 250 | tensor = tensor.view(*new_shape) 251 | return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 252 | 253 | def _merge_heads(self, tensor, num_heads, attn_head_size): 254 | """ 255 | Merges attn_head_size dim and num_attn_heads dim into hidden_size 256 | """ 257 | tensor = tensor.permute(0, 2, 1, 3).contiguous() 258 | new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) 259 | return tensor.view(new_shape) 260 | 261 | def forward( 262 | self, 263 | hidden_states, 264 | layer_past=None, 265 | attention_mask=None, 266 | head_mask=None, 267 | encoder_hidden_states=None, 268 | encoder_attention_mask=None, 269 | use_cache=False, 270 | output_attentions=False, 271 | ): 272 | if encoder_hidden_states is not None: 273 | if not hasattr(self, "q_attn"): 274 | raise ValueError( 275 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 276 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 277 | ) 278 | 279 | query = self.q_attn(hidden_states) 280 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 281 | attention_mask = encoder_attention_mask 282 | else: 283 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 284 | 285 | query = self._split_heads(query, self.num_heads, self.head_dim) 286 | key = self._split_heads(key, self.num_heads, self.head_dim) 287 | value = self._split_heads(value, self.num_heads, self.head_dim) 288 | 289 | if layer_past is not None: 290 | past_key, past_value = layer_past 291 | key = torch.cat((past_key, key), dim=-2) 292 | value = torch.cat((past_value, value), dim=-2) 293 | 294 | if use_cache is True: 295 | present = (key, value) 296 | else: 297 | present = None 298 | 299 | attn_output, attn_weights = self._attn(query, key, value, attention_mask) 300 | 301 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 302 | attn_output = self.c_proj(attn_output) 303 | attn_output = self.resid_dropout(attn_output) 304 | 305 | outputs = (attn_output, present) 306 | if output_attentions: 307 | outputs += (attn_weights,) 308 | 309 | return outputs # a, present, (attentions) 310 | 311 | def apply_masks(self, head_mask): 312 | self.head_mask = head_mask 313 | 314 | def get_masks(self): 315 | if self.head_mask is not None: 316 | return self.head_mask.flatten() 317 | else: 318 | return None 319 | 320 | 321 | class GPT2MLP(nn.Module): 322 | def __init__(self, intermediate_size, config): 323 | super().__init__() 324 | embed_dim = config.hidden_size 325 | self.c_fc = Conv1D(intermediate_size, embed_dim) 326 | self.c_proj = Conv1D(embed_dim, intermediate_size) 327 | self.act = ACT2FN[config.activation_function] 328 | self.dropout = nn.Dropout(config.resid_pdrop) 329 | 330 | def forward(self, hidden_states): 331 | hidden_states = self.c_fc(hidden_states) 332 | hidden_states = self.act(hidden_states) 333 | hidden_states = self.c_proj(hidden_states) 334 | hidden_states = self.dropout(hidden_states) 335 | return hidden_states 336 | 337 | 338 | class GPT2Block(nn.Module): 339 | def __init__(self, config): 340 | super().__init__() 341 | hidden_size = config.hidden_size 342 | 
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 343 | 344 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 345 | self.attn = GPT2Attention(config) 346 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 347 | 348 | if config.add_cross_attention: 349 | self.crossattention = GPT2Attention(config, is_cross_attention=True) 350 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 351 | 352 | self.mlp = GPT2MLP(inner_dim, config) 353 | 354 | def forward( 355 | self, 356 | hidden_states, 357 | layer_past=None, 358 | attention_mask=None, 359 | head_mask=None, 360 | encoder_hidden_states=None, 361 | encoder_attention_mask=None, 362 | use_cache=False, 363 | output_attentions=False, 364 | ): 365 | residual = hidden_states 366 | hidden_states = self.ln_1(hidden_states) 367 | attn_outputs = self.attn( 368 | hidden_states, 369 | layer_past=layer_past, 370 | attention_mask=attention_mask, 371 | head_mask=head_mask, 372 | use_cache=use_cache, 373 | output_attentions=output_attentions, 374 | ) 375 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 376 | outputs = attn_outputs[1:] 377 | # residual connection 378 | hidden_states = attn_output + residual 379 | 380 | if encoder_hidden_states is not None: 381 | # add one self-attention block for cross-attention 382 | if not hasattr(self, "crossattention"): 383 | raise ValueError( 384 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " 385 | "cross-attention layers by setting `config.add_cross_attention=True`" 386 | ) 387 | residual = hidden_states 388 | hidden_states = self.ln_cross_attn(hidden_states) 389 | cross_attn_outputs = self.crossattention( 390 | hidden_states, 391 | attention_mask=attention_mask, 392 | head_mask=head_mask, 393 | encoder_hidden_states=encoder_hidden_states, 394 | encoder_attention_mask=encoder_attention_mask, 395 | output_attentions=output_attentions, 396 | ) 397 | attn_output = cross_attn_outputs[0] 398 | # residual connection 399 | hidden_states = residual + attn_output 400 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 401 | 402 | residual = hidden_states 403 | hidden_states = self.ln_2(hidden_states) 404 | feed_forward_hidden_states = self.mlp(hidden_states) 405 | # residual connection 406 | hidden_states = residual + feed_forward_hidden_states 407 | 408 | if use_cache: 409 | outputs = (hidden_states,) + outputs 410 | else: 411 | outputs = (hidden_states,) + outputs[1:] 412 | 413 | return outputs # hidden_states, present, (attentions, cross_attentions) 414 | 415 | def apply_masks(self, head_mask): 416 | self.attn.apply_masks(head_mask) 417 | 418 | def get_masks(self): 419 | return self.attn.get_masks() 420 | 421 | 422 | class GPT2PreTrainedModel(PreTrainedModel): 423 | """ 424 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 425 | models. 
426 | """ 427 | 428 | config_class = GPT2Config 429 | load_tf_weights = load_tf_weights_in_gpt2 430 | base_model_prefix = "transformer" 431 | is_parallelizable = True 432 | 433 | def __init__(self, *inputs, **kwargs): 434 | super().__init__(*inputs, **kwargs) 435 | 436 | def _init_weights(self, module): 437 | """Initialize the weights.""" 438 | if isinstance(module, (nn.Linear, Conv1D)): 439 | # Slightly different from the TF version which uses truncated_normal for initialization 440 | # cf https://github.com/pytorch/pytorch/pull/5617 441 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 442 | if module.bias is not None: 443 | module.bias.data.zero_() 444 | elif isinstance(module, nn.Embedding): 445 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 446 | if module.padding_idx is not None: 447 | module.weight.data[module.padding_idx].zero_() 448 | elif isinstance(module, nn.LayerNorm): 449 | module.bias.data.zero_() 450 | module.weight.data.fill_(1.0) 451 | 452 | GPT2_START_DOCSTRING = r""" 453 | 454 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 455 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 456 | pruning heads etc.) 457 | 458 | This model is also a PyTorch `torch.nn.Module `__ 459 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 460 | general usage and behavior. 461 | 462 | Parameters: 463 | config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model. 464 | Initializing with a config file does not load the weights associated with the model, only the 465 | configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model 466 | weights. 467 | """ 468 | 469 | GPT2_INPUTS_DOCSTRING = r""" 470 | Args: 471 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`): 472 | :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else 473 | ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input 474 | sequence tokens in the vocabulary. 475 | 476 | If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be 477 | passed as ``input_ids``. 478 | 479 | Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See 480 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 481 | details. 482 | 483 | `What are input IDs? <../glossary.html#input-ids>`__ 484 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`): 485 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see 486 | :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which 487 | have their past given to this model should not be passed as ``input_ids`` as they have already been 488 | computed. 489 | attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 490 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 491 | 492 | - 1 for tokens that are **not masked**, 493 | - 0 for tokens that are **masked**. 494 | 495 | `What are attention masks? 
<../glossary.html#attention-mask>`__ 496 | token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`): 497 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 498 | 1]``: 499 | 500 | - 0 corresponds to a `sentence A` token, 501 | - 1 corresponds to a `sentence B` token. 502 | 503 | `What are token type IDs? <../glossary.html#token-type-ids>`_ 504 | position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 505 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, 506 | config.max_position_embeddings - 1]``. 507 | 508 | `What are position IDs? <../glossary.html#position-ids>`_ 509 | head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): 510 | Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: 511 | 512 | - 1 indicates the head is **not masked**, 513 | - 0 indicates the head is **masked**. 514 | 515 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 516 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 517 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 518 | vectors than the model's internal embedding lookup matrix. 519 | 520 | If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see 521 | :obj:`past_key_values`). 522 | use_cache (:obj:`bool`, `optional`): 523 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 524 | decoding (see :obj:`past_key_values`). 525 | output_attentions (:obj:`bool`, `optional`): 526 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 527 | tensors for more detail. 528 | output_hidden_states (:obj:`bool`, `optional`): 529 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 530 | more detail. 531 | return_dict (:obj:`bool`, `optional`): 532 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 533 | """ 534 | PARALLELIZE_DOCSTRING = r""" 535 | This is an experimental feature and is a subject to change at a moment's notice. 536 | 537 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given, 538 | it will evenly distribute blocks across all devices. 539 | 540 | Args: 541 | device_map (:obj:`Dict[int, list]`, optional, defaults to None): 542 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always 543 | automatically mapped to the first device (for esoteric reasons). That means that the first device should 544 | have fewer attention modules mapped to it than other devices. 
For reference, the gpt2 models have the 545 | following number of attention modules: 546 | 547 | - gpt2: 12 548 | - gpt2-medium: 24 549 | - gpt2-large: 36 550 | - gpt2-xl: 48 551 | 552 | Example:: 553 | 554 | # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: 555 | model = GPT2LMHeadModel.from_pretrained('gpt2-xl') 556 | device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8], 557 | 558 | 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], 559 | 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], 560 | 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]} 561 | model.parallelize(device_map) 562 | """ 563 | DEPARALLELIZE_DOCSTRING = r""" 564 | Moves the model to cpu from a model parallel state. 565 | 566 | Example:: 567 | 568 | # On a 4 GPU machine with gpt2-large: 569 | model = GPT2LMHeadModel.from_pretrained('gpt2-large') 570 | device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7], 571 | 572 | 1: [8, 9, 10, 11, 12, 13, 14, 15], 573 | 2: [16, 17, 18, 19, 20, 21, 22, 23], 574 | 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]} 575 | model.parallelize(device_map) # Splits the model across several devices 576 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() 577 | """ 578 | 579 | 580 | @add_start_docstrings( 581 | "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", 582 | GPT2_START_DOCSTRING, 583 | ) 584 | class GPT2Model(GPT2PreTrainedModel): 585 | _keys_to_ignore_on_load_missing = ["attn.masked_bias"] 586 | 587 | def __init__(self, config): 588 | super().__init__(config) 589 | 590 | self.embed_dim = config.hidden_size 591 | 592 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 593 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 594 | 595 | self.drop = nn.Dropout(config.embd_pdrop) 596 | self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)]) 597 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 598 | 599 | self.init_weights() 600 | 601 | # Model parallel 602 | self.model_parallel = False 603 | self.device_map = None 604 | 605 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 606 | def parallelize(self, device_map=None): 607 | # Check validity of device_map 608 | self.device_map = ( 609 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map 610 | ) 611 | assert_device_map(self.device_map, len(self.h)) 612 | self.model_parallel = True 613 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 614 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 615 | self.wte = self.wte.to(self.first_device) 616 | self.wpe = self.wpe.to(self.first_device) 617 | # Load onto devices 618 | for k, v in self.device_map.items(): 619 | for block in v: 620 | cuda_device = "cuda:" + str(k) 621 | self.h[block] = self.h[block].to(cuda_device) 622 | # ln_f to last 623 | self.ln_f = self.ln_f.to(self.last_device) 624 | 625 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 626 | def deparallelize(self): 627 | self.model_parallel = False 628 | self.device_map = None 629 | self.first_device = "cpu" 630 | self.last_device = "cpu" 631 | self.wte = self.wte.to("cpu") 632 | self.wpe = self.wpe.to("cpu") 633 | for index in range(len(self.h)): 634 | self.h[index] = self.h[index].to("cpu") 635 | self.ln_f = self.ln_f.to("cpu") 636 | torch.cuda.empty_cache() 637 | 638 | 
def get_input_embeddings(self): 639 | return self.wte 640 | 641 | def set_input_embeddings(self, new_embeddings): 642 | self.wte = new_embeddings 643 | 644 | def _prune_heads(self, heads_to_prune): 645 | """ 646 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 647 | """ 648 | for layer, heads in heads_to_prune.items(): 649 | self.h[layer].attn.prune_heads(heads) 650 | 651 | def apply_masks(self, head_mask): 652 | for layer, module in enumerate(self.h): 653 | module.apply_masks(head_mask[layer]) 654 | 655 | def get_masks(self): 656 | return [module.get_masks() for module in self.h] 657 | 658 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 659 | @add_code_sample_docstrings( 660 | processor_class=_TOKENIZER_FOR_DOC, 661 | checkpoint=_CHECKPOINT_FOR_DOC, 662 | output_type=BaseModelOutputWithPastAndCrossAttentions, 663 | config_class=_CONFIG_FOR_DOC, 664 | ) 665 | def forward( 666 | self, 667 | input_ids=None, 668 | past_key_values=None, 669 | attention_mask=None, 670 | token_type_ids=None, 671 | position_ids=None, 672 | head_mask=None, 673 | inputs_embeds=None, 674 | encoder_hidden_states=None, 675 | encoder_attention_mask=None, 676 | use_cache=None, 677 | output_attentions=None, 678 | output_hidden_states=None, 679 | return_dict=None, 680 | ): 681 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 682 | output_hidden_states = ( 683 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 684 | ) 685 | use_cache = use_cache if use_cache is not None else self.config.use_cache 686 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 687 | 688 | if input_ids is not None and inputs_embeds is not None: 689 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 690 | elif input_ids is not None: 691 | input_shape = input_ids.size() 692 | input_ids = input_ids.view(-1, input_shape[-1]) 693 | batch_size = input_ids.shape[0] 694 | elif inputs_embeds is not None: 695 | input_shape = inputs_embeds.size()[:-1] 696 | batch_size = inputs_embeds.shape[0] 697 | else: 698 | raise ValueError("You have to specify either input_ids or inputs_embeds") 699 | 700 | device = input_ids.device if input_ids is not None else inputs_embeds.device 701 | 702 | if token_type_ids is not None: 703 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 704 | if position_ids is not None: 705 | position_ids = position_ids.view(-1, input_shape[-1]) 706 | 707 | if past_key_values is None: 708 | past_length = 0 709 | past_key_values = tuple([None] * len(self.h)) 710 | else: 711 | past_length = past_key_values[0][0].size(-2) 712 | if position_ids is None: 713 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 714 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 715 | 716 | # GPT2Attention mask. 717 | if attention_mask is not None: 718 | if batch_size <= 0: 719 | raise ValueError("batch_size has to be defined and > 0") 720 | attention_mask = attention_mask.view(batch_size, -1) 721 | # We create a 3D attention mask from a 2D tensor mask. 
722 | # Sizes are [batch_size, 1, 1, to_seq_length] 723 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 724 | # this attention mask is more simple than the triangular masking of causal attention 725 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 726 | attention_mask = attention_mask[:, None, None, :] 727 | 728 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 729 | # masked positions, this operation will create a tensor which is 0.0 for 730 | # positions we want to attend and -10000.0 for masked positions. 731 | # Since we are adding it to the raw scores before the softmax, this is 732 | # effectively the same as removing these entirely. 733 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 734 | attention_mask = (1.0 - attention_mask) * -10000.0 735 | 736 | # If a 2D ou 3D attention mask is provided for the cross-attention 737 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 738 | if self.config.add_cross_attention and encoder_hidden_states is not None: 739 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 740 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 741 | if encoder_attention_mask is None: 742 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 743 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 744 | else: 745 | encoder_attention_mask = None 746 | 747 | # Prepare head mask if needed 748 | # 1.0 in head_mask indicate we keep the head 749 | # attention_probs has shape bsz x n_heads x N x N 750 | # head_mask has shape n_layer x batch x n_heads x N x N 751 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 752 | 753 | if inputs_embeds is None: 754 | inputs_embeds = self.wte(input_ids) 755 | position_embeds = self.wpe(position_ids) 756 | hidden_states = inputs_embeds + position_embeds 757 | 758 | if token_type_ids is not None: 759 | token_type_embeds = self.wte(token_type_ids) 760 | hidden_states = hidden_states + token_type_embeds 761 | 762 | hidden_states = self.drop(hidden_states) 763 | 764 | output_shape = input_shape + (hidden_states.size(-1),) 765 | 766 | presents = () if use_cache else None 767 | all_self_attentions = () if output_attentions else None 768 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 769 | all_hidden_states = () if output_hidden_states else None 770 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 771 | 772 | # Model parallel 773 | if self.model_parallel: 774 | torch.cuda.set_device(hidden_states.device) 775 | # Ensure layer_past is on same device as hidden_states (might not be correct) 776 | if layer_past is not None: 777 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 778 | # Ensure that attention_mask is always on the same device as hidden_states 779 | if attention_mask is not None: 780 | attention_mask = attention_mask.to(hidden_states.device) 781 | if isinstance(head_mask, torch.Tensor): 782 | head_mask = head_mask.to(hidden_states.device) 783 | if output_hidden_states: 784 | all_hidden_states = all_hidden_states + (hidden_states,) 785 | 786 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 787 | 788 | if use_cache: 789 | logger.warning( 790 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. 
Setting " 791 | "`use_cache=False`..." 792 | ) 793 | use_cache = False 794 | 795 | def create_custom_forward(module): 796 | def custom_forward(*inputs): 797 | # None for past_key_value 798 | return module(*inputs, use_cache, output_attentions) 799 | 800 | return custom_forward 801 | 802 | outputs = torch.utils.checkpoint.checkpoint( 803 | create_custom_forward(block), 804 | hidden_states, 805 | None, 806 | attention_mask, 807 | head_mask[i], 808 | encoder_hidden_states, 809 | encoder_attention_mask, 810 | ) 811 | else: 812 | outputs = block( 813 | hidden_states, 814 | layer_past=layer_past, 815 | attention_mask=attention_mask, 816 | head_mask=head_mask[i], 817 | encoder_hidden_states=encoder_hidden_states, 818 | encoder_attention_mask=encoder_attention_mask, 819 | use_cache=use_cache, 820 | output_attentions=output_attentions, 821 | ) 822 | 823 | hidden_states = outputs[0] 824 | if use_cache is True: 825 | presents = presents + (outputs[1],) 826 | 827 | if output_attentions: 828 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 829 | if self.config.add_cross_attention: 830 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 831 | 832 | # Model Parallel: If it's the last layer for that device, put things on the next device 833 | if self.model_parallel: 834 | for k, v in self.device_map.items(): 835 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 836 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 837 | 838 | hidden_states = self.ln_f(hidden_states) 839 | 840 | hidden_states = hidden_states.view(*output_shape) 841 | # Add last hidden state 842 | if output_hidden_states: 843 | all_hidden_states = all_hidden_states + (hidden_states,) 844 | 845 | if not return_dict: 846 | return tuple( 847 | v 848 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 849 | if v is not None 850 | ) 851 | 852 | return BaseModelOutputWithPastAndCrossAttentions( 853 | last_hidden_state=hidden_states, 854 | past_key_values=presents, 855 | hidden_states=all_hidden_states, 856 | attentions=all_self_attentions, 857 | cross_attentions=all_cross_attentions, 858 | ) 859 | 860 | 861 | @add_start_docstrings( 862 | """ 863 | The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input 864 | embeddings). 
865 | """, 866 | GPT2_START_DOCSTRING, 867 | ) 868 | class GatedGPT2LMHeadModel(GPT2PreTrainedModel): 869 | _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] 870 | 871 | def __init__(self, config): 872 | super().__init__(config) 873 | self.transformer = GPT2Model(config) 874 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 875 | 876 | self.init_weights() 877 | 878 | # Model parallel 879 | self.model_parallel = False 880 | self.device_map = None 881 | 882 | self.w = nn.Parameter(torch.empty([config.num_hidden_layers, config.num_attention_heads])) 883 | nn.init.xavier_uniform(self.w) 884 | self.num_of_heads = None 885 | self.use_dsp = False 886 | 887 | self.num_labels = config.num_labels 888 | self.eval_acc = False 889 | 890 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 891 | def parallelize(self, device_map=None): 892 | self.device_map = ( 893 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 894 | if device_map is None 895 | else device_map 896 | ) 897 | assert_device_map(self.device_map, len(self.transformer.h)) 898 | self.transformer.parallelize(self.device_map) 899 | self.lm_head = self.lm_head.to(self.transformer.first_device) 900 | self.model_parallel = True 901 | 902 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 903 | def deparallelize(self): 904 | self.transformer.deparallelize() 905 | self.transformer = self.transformer.to("cpu") 906 | self.lm_head = self.lm_head.to("cpu") 907 | self.model_parallel = False 908 | torch.cuda.empty_cache() 909 | 910 | def get_output_embeddings(self): 911 | return self.lm_head 912 | 913 | def set_output_embeddings(self, new_embeddings): 914 | self.lm_head = new_embeddings 915 | 916 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 917 | token_type_ids = kwargs.get("token_type_ids", None) 918 | # only last token for inputs_ids if past is defined in kwargs 919 | if past: 920 | input_ids = input_ids[:, -1].unsqueeze(-1) 921 | if token_type_ids is not None: 922 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 923 | 924 | attention_mask = kwargs.get("attention_mask", None) 925 | position_ids = kwargs.get("position_ids", None) 926 | 927 | if attention_mask is not None and position_ids is None: 928 | # create position_ids on the fly for batch generation 929 | position_ids = attention_mask.long().cumsum(-1) - 1 930 | position_ids.masked_fill_(attention_mask == 0, 1) 931 | if past: 932 | position_ids = position_ids[:, -1].unsqueeze(-1) 933 | else: 934 | position_ids = None 935 | return { 936 | "input_ids": input_ids, 937 | "past_key_values": past, 938 | "use_cache": kwargs.get("use_cache"), 939 | "position_ids": position_ids, 940 | "attention_mask": attention_mask, 941 | "token_type_ids": token_type_ids, 942 | } 943 | 944 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 945 | @add_code_sample_docstrings( 946 | processor_class=_TOKENIZER_FOR_DOC, 947 | checkpoint=_CHECKPOINT_FOR_DOC, 948 | output_type=ProbingViaPromptingOutputs, 949 | config_class=_CONFIG_FOR_DOC, 950 | ) 951 | def forward( 952 | self, 953 | input_ids=None, 954 | past_key_values=None, 955 | attention_mask=None, 956 | token_type_ids=None, 957 | position_ids=None, 958 | head_mask=None, 959 | inputs_embeds=None, 960 | encoder_hidden_states=None, 961 | encoder_attention_mask=None, 962 | labels=None, 963 | use_cache=None, 964 | output_attentions=None, 965 | output_hidden_states=None, 966 | return_dict=None, 967 | ): 968 | r""" 969 | labels (:obj:`torch.LongTensor` 
of shape :obj:`(batch_size, sequence_length)`, `optional`): 970 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 971 | ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to 972 | ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` 973 | """ 974 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 975 | 976 | if self.use_dsp: 977 | head_mask = STEFunction.apply(self.w.view(-1), self.num_of_heads).view_as(self.w) 978 | self.apply_masks(head_mask) 979 | 980 | transformer_outputs = self.transformer( 981 | input_ids, 982 | past_key_values=past_key_values, 983 | attention_mask=attention_mask, 984 | token_type_ids=token_type_ids, 985 | position_ids=position_ids, 986 | head_mask=head_mask, 987 | inputs_embeds=inputs_embeds, 988 | encoder_hidden_states=encoder_hidden_states, 989 | encoder_attention_mask=encoder_attention_mask, 990 | use_cache=use_cache, 991 | output_attentions=output_attentions, 992 | output_hidden_states=output_hidden_states, 993 | return_dict=return_dict, 994 | ) 995 | hidden_states = transformer_outputs[0] 996 | 997 | # Set device for model parallelism 998 | if self.model_parallel: 999 | torch.cuda.set_device(self.transformer.first_device) 1000 | hidden_states = hidden_states.to(self.lm_head.weight.device) 1001 | 1002 | lm_logits = self.lm_head(hidden_states) 1003 | 1004 | loss = None 1005 | if labels is not None: 1006 | # Shift so that tokens < n predict n 1007 | shift_logits = lm_logits[..., :-1, :].contiguous() 1008 | shift_labels = labels[..., 1:].contiguous() 1009 | # Flatten the tokens 1010 | loss_fct = CrossEntropyLoss() 1011 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1012 | 1013 | if self.eval_acc: 1014 | # Token prediction 1015 | # Take out the tokens for prediction 1016 | # There is only one token per example that has labels not equal to -100 1017 | indices = labels != -100 1018 | selected_labels = labels[indices] 1019 | # print('selected_labels', selected_labels) 1020 | num_tokens = lm_logits.shape[-1] 1021 | correct_labels = selected_labels - num_tokens + self.num_labels # convert to 0-based to match with pred_labels 1022 | # print('correct_labels', correct_labels) 1023 | selected_logits = lm_logits[:,:-1,:][indices[:,1:]] # we need to shift one position backward for logits 1024 | pred_labels = selected_logits[:,-self.num_labels:].argmax(-1) # predict only from candidate tokens 1025 | # print('pred_labels', pred_labels) 1026 | accuracy = (correct_labels == pred_labels).float() 1027 | else: 1028 | accuracy = None 1029 | 1030 | if not return_dict: 1031 | output = (lm_logits,) + transformer_outputs[1:] 1032 | return ((loss, accuracy) + output) if loss is not None else output 1033 | 1034 | return ProbingViaPromptingOutputs( 1035 | loss=loss, 1036 | accuracy=accuracy, 1037 | logits=lm_logits, 1038 | past_key_values=transformer_outputs.past_key_values, 1039 | hidden_states=transformer_outputs.hidden_states, 1040 | attentions=transformer_outputs.attentions, 1041 | cross_attentions=transformer_outputs.cross_attentions, 1042 | ) 1043 | 1044 | @staticmethod 1045 | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: 1046 | """ 1047 | This function is used to re-order the :obj:`past_key_values` cache if 1048 | :meth:`~transformers.PreTrainedModel.beam_search` or 
:meth:`~transformers.PreTrainedModel.beam_sample` is 1049 | called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 1050 | """ 1051 | return tuple( 1052 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 1053 | for layer_past in past 1054 | ) 1055 | 1056 | def apply_masks(self, head_mask): 1057 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 1058 | self.transformer.apply_masks(head_mask) 1059 | 1060 | def get_masks(self): 1061 | return torch.stack(self.transformer.get_masks()) 1062 | 1063 | def apply_dsp(self, num_of_heads): 1064 | self.num_of_heads = num_of_heads 1065 | self.use_dsp = True 1066 | 1067 | 1068 | -------------------------------------------------------------------------------- /modeling_gpt2_dp.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2PreTrainedModel 2 | from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor 3 | from allennlp.modules import scalar_mix 4 | from torch import nn 5 | import torch 6 | from torch.nn import CrossEntropyLoss 7 | import torch.nn.functional as F 8 | from transformers.file_utils import ModelOutput 9 | from typing import Optional 10 | from utils import STEFunction 11 | 12 | class DiagnosticProbingOutputs(ModelOutput): 13 | loss: Optional[torch.FloatTensor] = None 14 | logits: torch.FloatTensor = None 15 | 16 | class GPT2ForDiagnosticProbing(GPT2PreTrainedModel): 17 | def __init__(self, config, gpt2): 18 | super().__init__(config) 19 | self.transformer = gpt2 20 | for param in self.transformer.parameters(): 21 | param.requires_grad = False 22 | 23 | # Model parallel 24 | self.model_parallel = False 25 | self.device_map = None 26 | 27 | self.unary = config.unary 28 | self.num_labels = config.num_labels 29 | self.mlp_dropout = config.mlp_dropout 30 | self.mlp_dim = config.mlp_dim 31 | self.use_mlp = config.use_mlp 32 | 33 | self.scalar_mix = scalar_mix.ScalarMix(config.n_layer, do_layer_norm=False) 34 | 35 | self.proj1 = nn.Conv1d( 36 | config.n_embd, 37 | config.mlp_dim, 38 | kernel_size=1, 39 | stride=1, 40 | padding=0, 41 | dilation=1, 42 | groups=1, 43 | bias=True, 44 | ) 45 | self.span_extractor1 = SelfAttentiveSpanExtractor(config.mlp_dim) 46 | self.d_inp = self.span_extractor1.get_output_dim() 47 | if not self.unary: 48 | self.proj2 = nn.Conv1d( 49 | config.n_embd, 50 | config.mlp_dim, 51 | kernel_size=1, 52 | stride=1, 53 | padding=0, 54 | dilation=1, 55 | groups=1, 56 | bias=True, 57 | ) 58 | self.span_extractor2 = SelfAttentiveSpanExtractor(config.mlp_dim) 59 | self.d_inp += self.span_extractor2.get_output_dim() 60 | 61 | if not self.use_mlp: 62 | self.classifier = nn.Sequential( 63 | nn.Dropout(self.mlp_dropout), 64 | nn.Linear(self.d_inp, self.num_labels) 65 | ) 66 | else: 67 | self.classifier = nn.Sequential( 68 | nn.Linear(self.d_inp, self.mlp_dim), 69 | nn.Tanh(), 70 | nn.LayerNorm(self.mlp_dim), 71 | nn.Dropout(self.mlp_dropout), 72 | nn.Linear(self.mlp_dim, self.num_labels), 73 | ) 74 | 75 | self.w = nn.Parameter(torch.empty([config.num_hidden_layers, config.num_attention_heads])) 76 | nn.init.xavier_uniform(self.w) 77 | self.num_of_heads = None 78 | self.use_dsp = False 79 | 80 | def forward( 81 | self, 82 | input_ids=None, 83 | past_key_values=None, 84 | attention_mask=None, 85 | token_type_ids=None, 86 | position_ids=None, 87 | head_mask=None, 88 | inputs_embeds=None, 89 | encoder_hidden_states=None, 90 | 
encoder_attention_mask=None, 91 | labels=None, 92 | use_cache=None, 93 | output_attentions=None, 94 | output_hidden_states=None, 95 | return_dict=None, 96 | span1s=None, 97 | span2s=None, 98 | ): 99 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 100 | 101 | if self.use_dsp: 102 | head_mask = STEFunction.apply(self.w.view(-1), self.num_of_heads).view_as(self.w) 103 | self.apply_masks(head_mask) 104 | 105 | transformer_outputs = self.transformer( 106 | input_ids, 107 | past_key_values=past_key_values, 108 | attention_mask=attention_mask, 109 | token_type_ids=token_type_ids, 110 | position_ids=position_ids, 111 | head_mask=head_mask, 112 | inputs_embeds=inputs_embeds, 113 | encoder_hidden_states=encoder_hidden_states, 114 | encoder_attention_mask=encoder_attention_mask, 115 | use_cache=use_cache, 116 | output_attentions=output_attentions, 117 | output_hidden_states=True, 118 | return_dict=True, 119 | ) 120 | if not self.use_mlp: 121 | contextual_embeddings = transformer_outputs[0] 122 | else: 123 | all_hidden_states = transformer_outputs.hidden_states[1:] 124 | contextual_embeddings = self.scalar_mix(all_hidden_states) 125 | 126 | span_mask = span1s[:, :, 0] != -1 127 | 128 | se_proj1 = self.proj1(contextual_embeddings.transpose(1, 2)).transpose(2, 1).contiguous() 129 | span1_emb = self.span_extractor1(se_proj1, span1s, span_indices_mask=span_mask.long()) 130 | if not self.unary: 131 | se_proj2 = self.proj2(contextual_embeddings.transpose(1, 2)).transpose(2, 1).contiguous() 132 | span2_emb = self.span_extractor2(se_proj2, span2s, span_indices_mask=span_mask.long()) 133 | span_emb = torch.cat([span1_emb, span2_emb], dim=2) 134 | else: 135 | span_emb = span1_emb 136 | 137 | logits = self.classifier(span_emb) 138 | loss_fct = CrossEntropyLoss() 139 | loss = loss_fct(logits[span_mask], labels[span_mask]) 140 | 141 | corrections = logits[span_mask].argmax(-1) == labels[span_mask] 142 | correct_counts = corrections.sum() 143 | total_counts = len(corrections) 144 | accuracy = torch.tensor([[correct_counts, total_counts]], device=corrections.device) 145 | 146 | if not return_dict: 147 | output = (accuracy,) 148 | return ((loss,) + output) if loss is not None else output 149 | 150 | return DiagnosticProbingOutputs( 151 | loss=loss, 152 | logits=accuracy, 153 | ) 154 | 155 | def apply_masks(self, head_mask): 156 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 157 | self.transformer.apply_masks(head_mask) 158 | 159 | def get_masks(self): 160 | return torch.stack(self.transformer.get_masks()) 161 | 162 | def apply_dsp(self, num_of_heads): 163 | self.num_of_heads = num_of_heads 164 | self.use_dsp = True -------------------------------------------------------------------------------- /modeling_gpt2_pp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import GPT2PreTrainedModel 3 | from torch import nn 4 | 5 | 6 | class GPT2ForProbingViaPrompting(GPT2PreTrainedModel): 7 | """Classification Head for transformer encoders""" 8 | def __init__(self, config, gpt2): 9 | super().__init__(config) 10 | 11 | self.gpt2 = gpt2 12 | self.match_n_layer = config.n_layer 13 | self.match_n_head = config.n_head 14 | self.n_embd = config.n_embd 15 | self.match_n_embd = self.n_embd // self.match_n_head 16 | 17 | for param in self.gpt2.parameters(): 18 | param.requires_grad = False 19 | 20 | self.prefix_len = config.prefix_len 21 | self.prefix_dim = config.prefix_dim 22 | self.prefix_drop 
= config.prefix_drop 23 | print('PrefixTuning') 24 | print('prefix_len: {}, prefix_dim: {}, prefix_drop: {}'.format(self.prefix_len, self.prefix_dim, self.prefix_drop)) 25 | 26 | 27 | self.input_tokens = torch.arange(self.prefix_len).long() 28 | self.wte = nn.Embedding(self.prefix_len, self.n_embd) 29 | self.prefix_model = nn.Sequential( 30 | nn.Linear(self.n_embd, self.prefix_dim), 31 | nn.Tanh(), 32 | nn.Linear(self.prefix_dim, self.match_n_layer * 2 * self.n_embd)) 33 | 34 | self.dropout = nn.Dropout(self.prefix_drop) 35 | 36 | self.model_parallel = False 37 | self.device_map = None 38 | 39 | def get_prefix(self, bsz, device): 40 | input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(device) 41 | temp_control = self.wte(input_tokens) 42 | past_key_values = self.prefix_model(temp_control) #bsz, seqlen, layer*emb 43 | bsz, seqlen, _ = past_key_values.shape 44 | past_key_values = past_key_values.view(bsz, seqlen, self.match_n_layer * 2, self.match_n_head, 45 | self.match_n_embd) 46 | past_key_values = self.dropout(past_key_values) 47 | past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) # n_layer, 2, bsz, n_head, seqlen, n_embd 48 | return past_key_values 49 | 50 | def forward( 51 | self, 52 | input_ids=None, 53 | past_key_values=None, 54 | attention_mask=None, 55 | token_type_ids=None, 56 | position_ids=None, 57 | head_mask=None, 58 | inputs_embeds=None, 59 | encoder_hidden_states=None, 60 | encoder_attention_mask=None, 61 | labels=None, 62 | use_cache=None, 63 | output_attentions=None, 64 | output_hidden_states=None, 65 | return_dict=None, 66 | ): 67 | 68 | bsz = input_ids.shape[0] 69 | device = input_ids.device 70 | 71 | past_key_values = self.get_prefix(bsz, device) 72 | 73 | output = self.gpt2(input_ids=input_ids, 74 | past_key_values=past_key_values, attention_mask=attention_mask, 75 | token_type_ids=token_type_ids, position_ids=position_ids, 76 | head_mask=head_mask, inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, 77 | encoder_attention_mask=encoder_attention_mask, labels=labels, use_cache=use_cache, 78 | output_attentions=output_attentions, output_hidden_states=output_hidden_states, 79 | return_dict=return_dict) 80 | 81 | return output -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp==0.8.4 2 | overrides==3.1.0 3 | transformers==4.18.0 4 | datasets==2.1.0 5 | accelerate==0.6.2 -------------------------------------------------------------------------------- /run_clm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. 
18 | 19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 20 | https://huggingface.co/models?filter=causal-lm 21 | """ 22 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 23 | 24 | import logging 25 | import math 26 | import os 27 | import sys 28 | from dataclasses import dataclass, field 29 | from typing import Optional 30 | 31 | import torch 32 | 33 | import datasets 34 | from datasets import load_dataset 35 | 36 | import transformers 37 | from transformers import ( 38 | CONFIG_MAPPING, 39 | MODEL_FOR_CAUSAL_LM_MAPPING, 40 | AutoConfig, 41 | AutoModelForCausalLM, 42 | AutoTokenizer, 43 | HfArgumentParser, 44 | Trainer, 45 | TrainingArguments, 46 | default_data_collator, 47 | set_seed, 48 | ) 49 | from transformers.testing_utils import CaptureLogger 50 | from transformers.trainer_utils import get_last_checkpoint 51 | from transformers.utils import check_min_version 52 | from transformers.utils.versions import require_version 53 | 54 | from modeling_gated_gpt2 import GatedGPT2LMHeadModel 55 | from utils import convert_gate_to_mask 56 | 57 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 58 | check_min_version("4.13.0.dev0") 59 | 60 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") 61 | 62 | logger = logging.getLogger(__name__) 63 | 64 | 65 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 66 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 67 | 68 | 69 | @dataclass 70 | class ModelArguments: 71 | """ 72 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 73 | """ 74 | 75 | model_name_or_path: Optional[str] = field( 76 | default=None, 77 | metadata={ 78 | "help": "The model checkpoint for weights initialization." 79 | "Don't set if you want to train a model from scratch." 80 | }, 81 | ) 82 | model_type: Optional[str] = field( 83 | default=None, 84 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 85 | ) 86 | config_overrides: Optional[str] = field( 87 | default=None, 88 | metadata={ 89 | "help": "Override some existing default config settings when a model is trained from scratch. Example: " 90 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 91 | }, 92 | ) 93 | config_name: Optional[str] = field( 94 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 95 | ) 96 | tokenizer_name: Optional[str] = field( 97 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 98 | ) 99 | cache_dir: Optional[str] = field( 100 | default=None, 101 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 102 | ) 103 | use_fast_tokenizer: bool = field( 104 | default=True, 105 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 106 | ) 107 | model_revision: str = field( 108 | default="main", 109 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 110 | ) 111 | use_auth_token: bool = field( 112 | default=False, 113 | metadata={ 114 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 115 | "with private models)." 
116 | }, 117 | ) 118 | head_mask_path: Optional[str] = field( 119 | default='None', 120 | metadata={ 121 | "help": "Where head mask is stored" 122 | }, 123 | ) 124 | 125 | def __post_init__(self): 126 | if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): 127 | raise ValueError( 128 | "--config_overrides can't be used in combination with --config_name or --model_name_or_path" 129 | ) 130 | 131 | 132 | @dataclass 133 | class DataTrainingArguments: 134 | """ 135 | Arguments pertaining to what data we are going to input our model for training and eval. 136 | """ 137 | 138 | dataset_name: Optional[str] = field( 139 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 140 | ) 141 | dataset_config_name: Optional[str] = field( 142 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 143 | ) 144 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 145 | validation_file: Optional[str] = field( 146 | default=None, 147 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 148 | ) 149 | max_train_samples: Optional[int] = field( 150 | default=None, 151 | metadata={ 152 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 153 | "value if set." 154 | }, 155 | ) 156 | max_eval_samples: Optional[int] = field( 157 | default=None, 158 | metadata={ 159 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 160 | "value if set." 161 | }, 162 | ) 163 | 164 | block_size: Optional[int] = field( 165 | default=None, 166 | metadata={ 167 | "help": "Optional input sequence length after tokenization. " 168 | "The training dataset will be truncated in block of this size for training. " 169 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 170 | }, 171 | ) 172 | overwrite_cache: bool = field( 173 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 174 | ) 175 | validation_split_percentage: Optional[int] = field( 176 | default=5, 177 | metadata={ 178 | "help": "The percentage of the train set used as validation set in case there's no validation split" 179 | }, 180 | ) 181 | preprocessing_num_workers: Optional[int] = field( 182 | default=None, 183 | metadata={"help": "The number of processes to use for the preprocessing."}, 184 | ) 185 | keep_linebreaks: bool = field( 186 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} 187 | ) 188 | 189 | def __post_init__(self): 190 | assert self.dataset_name is not None 191 | 192 | def main(): 193 | # See all possible arguments in src/transformers/training_args.py 194 | # or by passing the --help flag to this script. 195 | # We now keep distinct sets of args, for a cleaner separation of concerns. 196 | 197 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 198 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 199 | # If we pass only one argument to the script and it's the path to a json file, 200 | # let's parse it to get our arguments. 
201 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 202 | else: 203 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 204 | 205 | # Setup logging 206 | logging.basicConfig( 207 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 208 | datefmt="%m/%d/%Y %H:%M:%S", 209 | handlers=[logging.StreamHandler(sys.stdout)], 210 | ) 211 | 212 | log_level = training_args.get_process_log_level() 213 | logger.setLevel(log_level) 214 | datasets.utils.logging.set_verbosity(log_level) 215 | transformers.utils.logging.set_verbosity(log_level) 216 | transformers.utils.logging.enable_default_handler() 217 | transformers.utils.logging.enable_explicit_format() 218 | 219 | # Log on each process the small summary: 220 | logger.warning( 221 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 222 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 223 | ) 224 | logger.info(f"Training/evaluation parameters {training_args}") 225 | 226 | # Set seed before initializing model. 227 | set_seed(training_args.seed) 228 | 229 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 230 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 231 | # (the dataset will be downloaded automatically from the datasets Hub). 232 | # 233 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 234 | # 'text' is found. You can easily tweak this behavior (see below). 235 | # 236 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 237 | # download the dataset. 238 | # Downloading and loading a dataset from the hub. 239 | raw_datasets = load_dataset( 240 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 241 | ) 242 | 243 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 244 | # https://huggingface.co/docs/datasets/loading_datasets.html. 245 | 246 | # Load pretrained model and tokenizer 247 | # 248 | # Distributed training: 249 | # The .from_pretrained methods guarantee that only one local process can concurrently 250 | # download model & vocab. 
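# `convert_gate_to_mask` (imported from utils.py, which is not reproduced in this
# listing) turns real-valued head gates into a binary (n_layer x n_head) mask whose
# sum equals the number of heads kept, which is how run_dp.py and run_pp.py report
# "Number of heads". A minimal sketch of that behaviour, assuming a plain top-k
# selection over the flattened gates (the actual utils.py implementation may differ
# in its details):

import torch

def _topk_gate_to_mask(gates: torch.Tensor, num_of_heads: int) -> torch.Tensor:
    # Keep the `num_of_heads` largest gate values and zero out the rest.
    flat = gates.view(-1)
    mask = torch.zeros_like(flat)
    mask[torch.topk(flat, num_of_heads).indices] = 1.0
    return mask.view_as(gates)

# When `--head_mask_path` points to a mask saved by run_dp.py or run_pp.py, that mask
# marks the heads a probe found important, so the code below flips it with
# `(head_mask == 0).float()` before `apply_masks`: exactly those heads are pruned and
# the remaining model is evaluated on the language-modeling objective (amnesic probing).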
251 | 252 | config_kwargs = { 253 | "cache_dir": model_args.cache_dir, 254 | "revision": model_args.model_revision, 255 | "use_auth_token": True if model_args.use_auth_token else None, 256 | } 257 | if model_args.config_name: 258 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 259 | elif model_args.model_name_or_path: 260 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 261 | else: 262 | config = CONFIG_MAPPING[model_args.model_type]() 263 | logger.warning("You are instantiating a new config instance from scratch.") 264 | if model_args.config_overrides is not None: 265 | logger.info(f"Overriding config: {model_args.config_overrides}") 266 | config.update_from_string(model_args.config_overrides) 267 | 268 | tokenizer_kwargs = { 269 | "cache_dir": model_args.cache_dir, 270 | "use_fast": model_args.use_fast_tokenizer, 271 | "revision": model_args.model_revision, 272 | "use_auth_token": True if model_args.use_auth_token else None, 273 | } 274 | if model_args.tokenizer_name: 275 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 276 | elif model_args.model_name_or_path: 277 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) 278 | else: 279 | raise ValueError( 280 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 281 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 282 | ) 283 | 284 | if model_args.model_name_or_path: 285 | model = GatedGPT2LMHeadModel.from_pretrained( 286 | model_args.model_name_or_path, 287 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 288 | config=config, 289 | cache_dir=model_args.cache_dir, 290 | revision=model_args.model_revision, 291 | use_auth_token=True if model_args.use_auth_token else None, 292 | ) 293 | else: 294 | model = AutoModelForCausalLM.from_config(config) 295 | n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) 296 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 297 | 298 | model.resize_token_embeddings(len(tokenizer)) 299 | 300 | if model_args.head_mask_path == 'random': 301 | head_mask = torch.rand(12, 12) 302 | head_mask = convert_gate_to_mask(head_mask, 48) 303 | model.apply_masks(head_mask) 304 | elif model_args.head_mask_path != 'None': 305 | head_mask = torch.load(model_args.head_mask_path) 306 | head_mask = (head_mask == 0).float() 307 | model.apply_masks(head_mask) 308 | 309 | # Preprocessing the datasets. 310 | # First we tokenize all the texts. 311 | column_names = raw_datasets["train"].column_names 312 | text_column_name = "text" if "text" in column_names else column_names[0] 313 | 314 | # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function 315 | tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") 316 | 317 | def tokenize_function(examples): 318 | with CaptureLogger(tok_logger) as cl: 319 | output = tokenizer(examples[text_column_name]) 320 | # clm input could be much much longer than block_size 321 | if "Token indices sequence length is longer than the" in cl.out: 322 | tok_logger.warning( 323 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model." 
324 | ) 325 | return output 326 | 327 | with training_args.main_process_first(desc="dataset map tokenization"): 328 | tokenized_datasets = raw_datasets.map( 329 | tokenize_function, 330 | batched=True, 331 | num_proc=data_args.preprocessing_num_workers, 332 | remove_columns=column_names, 333 | load_from_cache_file=not data_args.overwrite_cache, 334 | desc="Running tokenizer on dataset", 335 | ) 336 | 337 | if data_args.block_size is None: 338 | block_size = tokenizer.model_max_length 339 | if block_size > 1024: 340 | logger.warning( 341 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " 342 | "Picking 1024 instead. You can change that default value by passing --block_size xxx." 343 | ) 344 | block_size = 1024 345 | else: 346 | if data_args.block_size > tokenizer.model_max_length: 347 | logger.warning( 348 | f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" 349 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 350 | ) 351 | block_size = min(data_args.block_size, tokenizer.model_max_length) 352 | 353 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 354 | def group_texts(examples): 355 | # Concatenate all texts. 356 | concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} 357 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 358 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 359 | # customize this part to your needs. 360 | if total_length >= block_size: 361 | total_length = (total_length // block_size) * block_size 362 | # Split by chunks of max_len. 363 | result = { 364 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 365 | for k, t in concatenated_examples.items() 366 | } 367 | result["labels"] = result["input_ids"].copy() 368 | return result 369 | 370 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder 371 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower 372 | # to preprocess. 373 | # 374 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 375 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 376 | 377 | with training_args.main_process_first(desc="grouping texts together"): 378 | lm_datasets = tokenized_datasets.map( 379 | group_texts, 380 | batched=True, 381 | num_proc=data_args.preprocessing_num_workers, 382 | load_from_cache_file=not data_args.overwrite_cache, 383 | desc=f"Grouping texts in chunks of {block_size}", 384 | ) 385 | 386 | if training_args.do_eval: 387 | eval_dataset = lm_datasets["train"] 388 | if data_args.max_eval_samples is not None: 389 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 390 | 391 | # Initialize our Trainer 392 | trainer = Trainer( 393 | model=model, 394 | args=training_args, 395 | train_dataset=None, 396 | eval_dataset=eval_dataset if training_args.do_eval else None, 397 | tokenizer=tokenizer, 398 | # Data collator will default to DataCollatorWithPadding, so we change it. 
399 | data_collator=default_data_collator, 400 | ) 401 | 402 | # Evaluation 403 | if training_args.do_eval: 404 | logger.info("*** Evaluate ***") 405 | 406 | metrics = trainer.evaluate() 407 | 408 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 409 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 410 | try: 411 | perplexity = math.exp(metrics["eval_loss"]) 412 | except OverflowError: 413 | perplexity = float("inf") 414 | metrics["perplexity"] = perplexity 415 | 416 | trainer.log_metrics("eval", metrics) 417 | trainer.save_metrics("eval", metrics) 418 | 419 | if model_args.head_mask_path != 'None': 420 | logger.info(f"Number of heads: {head_mask.sum()}") 421 | 422 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} 423 | if data_args.dataset_name is not None: 424 | kwargs["dataset_tags"] = data_args.dataset_name 425 | if data_args.dataset_config_name is not None: 426 | kwargs["dataset_args"] = data_args.dataset_config_name 427 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 428 | else: 429 | kwargs["dataset"] = data_args.dataset_name 430 | 431 | if training_args.push_to_hub: 432 | trainer.push_to_hub(**kwargs) 433 | else: 434 | trainer.create_model_card(**kwargs) 435 | 436 | 437 | def _mp_fn(index): 438 | # For xla_spawn (TPUs) 439 | main() 440 | 441 | 442 | if __name__ == "__main__": 443 | main() 444 | -------------------------------------------------------------------------------- /run_dp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. 18 | 19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 20 | https://huggingface.co/models?filter=causal-lm 21 | """ 22 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
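# Despite the header above (inherited from the HuggingFace causal-LM example), this
# script does not fine-tune GPT-2: the transformer is frozen and only a diagnostic
# probe (GPT2ForDiagnosticProbing) is trained on top of it. The MLP probe mixes all
# layers with AllenNLP's ScalarMix, a learned softmax-weighted sum of the per-layer
# hidden states (the learned scalars are what this script logs as "Layer weights"
# before evaluation). A minimal sketch of that mixing, with layer norm disabled as in
# modeling_gpt2_dp.py:

import torch.nn.functional as F

def _scalar_mix_sketch(hidden_states, scalars, gamma=1.0):
    # hidden_states: list of L tensors of shape (batch, seq_len, n_embd)
    # scalars: tensor of L learnable layer weights
    weights = F.softmax(scalars, dim=0)
    return gamma * sum(w * h for w, h in zip(weights, hidden_states))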
23 | 24 | import logging 25 | import math 26 | import os 27 | import sys 28 | from dataclasses import dataclass, field 29 | from typing import Optional 30 | import random 31 | 32 | import torch 33 | 34 | import datasets 35 | from datasets import load_dataset 36 | 37 | import transformers 38 | from transformers import ( 39 | AdamW, 40 | CONFIG_MAPPING, 41 | MODEL_FOR_CAUSAL_LM_MAPPING, 42 | AutoConfig, 43 | AutoTokenizer, 44 | HfArgumentParser, 45 | TrainingArguments, 46 | default_data_collator, 47 | set_seed, 48 | Trainer, 49 | ) 50 | from transformers.trainer_utils import get_last_checkpoint 51 | from transformers.utils.versions import require_version 52 | from tokenizers.pre_tokenizers import WhitespaceSplit 53 | 54 | from modeling_gpt2_dp import GPT2ForDiagnosticProbing 55 | from modeling_gated_gpt2 import GPT2Model 56 | 57 | from utils import LABEL_DICT, convert_gate_to_mask 58 | 59 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 60 | # check_min_version("4.13.0.dev0") 61 | 62 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") 63 | 64 | logger = logging.getLogger(__name__) 65 | 66 | 67 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 68 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 69 | 70 | MAX_LENGTH = {'pos': 350, 'const': 350, 'ner': 350, 'coref': 280, 'srl': 350} 71 | MAX_TARGET = {'pos': 275, 'const': 175, 'ner': 71, 'coref': 300, 'srl': 11} 72 | IS_UNARY = {'pos': True, 'const': True, 'ner': True, 'coref': False, 'srl': False} 73 | 74 | @dataclass 75 | class ModelArguments: 76 | """ 77 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 78 | """ 79 | 80 | gpt2_name_or_path: Optional[str] = field( 81 | default=None, 82 | metadata={ 83 | "help": "The model checkpoint for weights initialization." 84 | "Don't set if you want to train a model from scratch." 85 | }, 86 | ) 87 | model_path: Optional[str] = field( 88 | default=None, 89 | metadata={ 90 | "help": "Path to trained model." 91 | "Don't set if you want to train a model from scratch." 92 | }, 93 | ) 94 | model_type: Optional[str] = field( 95 | default=None, 96 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 97 | ) 98 | config_overrides: Optional[str] = field( 99 | default=None, 100 | metadata={ 101 | "help": "Override some existing default config settings when a model is trained from scratch. 
Example: " 102 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 103 | }, 104 | ) 105 | config_name: Optional[str] = field( 106 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 107 | ) 108 | tokenizer_name: Optional[str] = field( 109 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 110 | ) 111 | cache_dir: Optional[str] = field( 112 | default='cache/', 113 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 114 | ) 115 | use_fast_tokenizer: bool = field( 116 | default=True, 117 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 118 | ) 119 | model_revision: str = field( 120 | default="main", 121 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 122 | ) 123 | use_auth_token: bool = field( 124 | default=False, 125 | metadata={ 126 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 127 | "with private models)." 128 | }, 129 | ) 130 | use_mlp: bool = field( 131 | default=True, 132 | metadata={ 133 | "help": "use mlp or linear regression" 134 | }, 135 | ) 136 | mlp_dropout: Optional[float] = field( 137 | default=0.2, 138 | metadata={"help": "Dropout in MLP model."}, 139 | ) 140 | mlp_dim: Optional[int] = field( 141 | default=512, 142 | metadata={"help": "Dimension of hidden states of MLP model."}, 143 | ) 144 | num_of_heads: Optional[int] = field( 145 | default=96, 146 | metadata={"help": "Number of heads left unpruned."}, 147 | ) 148 | pruning_lr: Optional[float] = field( 149 | default=0.1, 150 | metadata={"help": "Learning rate for head importance variables."}, 151 | ) 152 | do_prune: Optional[bool] = field( 153 | default=False, 154 | metadata={"help": "Whether heads are pruned."}, 155 | ) 156 | 157 | 158 | @dataclass 159 | class DataTrainingArguments: 160 | """ 161 | Arguments pertaining to what data we are going to input our model for training and eval. 162 | """ 163 | 164 | data_dir: Optional[str] = field( 165 | default=None, metadata={"help": "Where data is stored"} 166 | ) 167 | task: Optional[str] = field( 168 | default='ner', 169 | metadata={"help": "Tasks, one or more of {pos, const, coref, ner, srl}."}, 170 | ) 171 | max_train_samples: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 175 | "value if set." 176 | }, 177 | ) 178 | max_eval_samples: Optional[int] = field( 179 | default=None, 180 | metadata={ 181 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 182 | "value if set." 183 | }, 184 | ) 185 | overwrite_cache: bool = field( 186 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 187 | ) 188 | preprocessing_num_workers: Optional[int] = field( 189 | default=None, 190 | metadata={"help": "The number of processes to use for the preprocessing."}, 191 | ) 192 | 193 | def main(): 194 | # See all possible arguments in src/transformers/training_args.py 195 | # or by passing the --help flag to this script. 196 | # We now keep distinct sets of args, for a cleaner separation of concerns. 
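# Further down in this script, --do_prune turns on differentiable head selection: the
# probe learns a gate matrix `w` (one scalar per attention head) that is binarised on
# every forward pass by utils.STEFunction and applied as a head mask. utils.py is not
# reproduced in this listing; the sketch below only illustrates the standard
# straight-through-estimator pattern such a function presumably follows (hard top-k
# in the forward pass, identity gradient in the backward pass):

import torch

class _TopKSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gates, k):
        # Hard selection: the k largest gates become 1.0, the rest 0.0.
        mask = torch.zeros_like(gates)
        mask[torch.topk(gates, k).indices] = 1.0
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the gradient to the gates unchanged; k gets no gradient.
        return grad_output, None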
197 | 198 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 199 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 200 | # If we pass only one argument to the script and it's the path to a json file, 201 | # let's parse it to get our arguments. 202 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 203 | else: 204 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 205 | 206 | # Setup logging 207 | logging.basicConfig( 208 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 209 | datefmt="%m/%d/%Y %H:%M:%S", 210 | handlers=[logging.StreamHandler(sys.stdout)], 211 | ) 212 | 213 | log_level = training_args.get_process_log_level() 214 | logger.setLevel(log_level) 215 | datasets.utils.logging.set_verbosity(log_level) 216 | transformers.utils.logging.set_verbosity(log_level) 217 | transformers.utils.logging.enable_default_handler() 218 | transformers.utils.logging.enable_explicit_format() 219 | 220 | # Log on each process the small summary: 221 | logger.warning( 222 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 223 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 224 | ) 225 | logger.info(f"Training/evaluation parameters {training_args}") 226 | 227 | # Detecting last checkpoint. 228 | last_checkpoint = None 229 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 230 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 231 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 232 | raise ValueError( 233 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 234 | "Use --overwrite_output_dir to overcome." 235 | ) 236 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 237 | logger.info( 238 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 239 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 240 | ) 241 | 242 | # Set seed before initializing model. 243 | set_seed(training_args.seed) 244 | 245 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 246 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 247 | # (the dataset will be downloaded automatically from the datasets Hub). 248 | # 249 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 250 | # 'text' is found. You can easily tweak this behavior (see below). 251 | # 252 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 253 | # download the dataset. 
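# Each record in train.json / test.json under ontonotes/dp/<task>/ is one sentence
# with its probing targets, in the format `tokenize_function` below expects: a "text"
# field plus a list of "targets", each carrying a word-level, inclusive "span1" (and a
# "span2" for the binary tasks coref and srl) and a string "label". An illustrative
# NER-style record (the values are made up for illustration):

example_record = {
    "text": "John lives in New York",
    "targets": [
        {"span1": [0, 0], "label": "PERSON"},
        {"span1": [3, 4], "label": "GPE"},
    ],
}

# `convert_span` below maps these whitespace-token positions onto GPT-2 subword
# positions via `char_to_token`, and the spans/labels are padded with -1 up to
# MAX_TARGET so every example carries the same number of targets.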
254 | 255 | data_files = {} 256 | dataset_args = {} 257 | logger.info("Loading data for {}".format(data_args.task)) 258 | if training_args.do_train: 259 | data_files["train"] = os.path.join(data_args.data_dir, data_args.task, 'train.json') 260 | data_files["validation"] = os.path.join(data_args.data_dir, data_args.task, 'test.json') 261 | raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args) 262 | if "_control" in data_args.task: 263 | data_args.task = data_args.task.replace("_control", "") 264 | label2id = {label: i for i, label in enumerate(LABEL_DICT[data_args.task])} 265 | 266 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 267 | # https://huggingface.co/docs/datasets/loading_datasets.html. 268 | 269 | # Load pretrained model and tokenizer 270 | # 271 | # Distributed training: 272 | # The .from_pretrained methods guarantee that only one local process can concurrently 273 | # download model & vocab. 274 | 275 | config_kwargs = { 276 | "cache_dir": model_args.cache_dir, 277 | "revision": model_args.model_revision, 278 | "use_auth_token": True if model_args.use_auth_token else None, 279 | } 280 | if model_args.config_name: 281 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 282 | elif model_args.gpt2_name_or_path: 283 | config = AutoConfig.from_pretrained(model_args.gpt2_name_or_path, **config_kwargs) 284 | else: 285 | config = CONFIG_MAPPING[model_args.model_type]() 286 | logger.warning("You are instantiating a new config instance from scratch.") 287 | if model_args.config_overrides is not None: 288 | logger.info(f"Overriding config: {model_args.config_overrides}") 289 | config.update_from_string(model_args.config_overrides) 290 | 291 | tokenizer_kwargs = { 292 | "cache_dir": model_args.cache_dir, 293 | "use_fast": model_args.use_fast_tokenizer, 294 | "revision": model_args.model_revision, 295 | "use_auth_token": True if model_args.use_auth_token else None, 296 | } 297 | if model_args.tokenizer_name: 298 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 299 | elif model_args.gpt2_name_or_path: 300 | tokenizer = AutoTokenizer.from_pretrained(model_args.gpt2_name_or_path, **tokenizer_kwargs) 301 | else: 302 | raise ValueError( 303 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 304 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
305 | ) 306 | 307 | config.num_labels = len(label2id) 308 | if model_args.gpt2_name_or_path: 309 | gpt2 = GPT2Model.from_pretrained( 310 | model_args.gpt2_name_or_path, 311 | cache_dir=model_args.cache_dir, 312 | ) 313 | else: 314 | gpt2 = GPT2Model(config) 315 | n_params = sum(dict((p.data_ptr(), p.numel()) for p in gpt2.parameters()).values()) 316 | logger.info(f"Training new gpt2 from scratch - Total size={n_params/2**20:.2f}M params") 317 | 318 | gpt2.resize_token_embeddings(len(tokenizer)) 319 | 320 | if model_args.model_path: 321 | config = AutoConfig.from_pretrained(model_args.model_path, cache_dir=model_args.cache_dir) 322 | model = GPT2ForDiagnosticProbing.from_pretrained(model_args.model_path, config=config, gpt2=gpt2) 323 | else: 324 | config.mlp_dropout = model_args.mlp_dropout 325 | config.mlp_dim = model_args.mlp_dim 326 | config.unary = IS_UNARY[data_args.task] 327 | config.use_mlp = model_args.use_mlp 328 | model = GPT2ForDiagnosticProbing(config, gpt2) 329 | 330 | # Preprocessing the datasets. 331 | # First we tokenize all the texts. 332 | if training_args.do_train: 333 | column_names = raw_datasets["train"].column_names 334 | else: 335 | column_names = raw_datasets["validation"].column_names 336 | 337 | # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function 338 | tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") 339 | 340 | tokenizer.pad_token = tokenizer.eos_token 341 | pre_tokenizer = WhitespaceSplit() 342 | tokenizer.pre_tokenizer = pre_tokenizer 343 | 344 | def convert_span(result, pre_tokenized_str, span): 345 | char_start = pre_tokenized_str[span[0]][1][0] 346 | char_end = pre_tokenized_str[span[1]][1][1] - 1 347 | start = result.char_to_token(char_start) 348 | end = result.char_to_token(char_end) 349 | return [start, end] 350 | 351 | def tokenize_function(example): 352 | result = tokenizer(example['text'], padding="max_length", max_length=MAX_LENGTH[data_args.task]) 353 | pre_tokenized_str = pre_tokenizer.pre_tokenize_str(example['text']) 354 | 355 | num_targets = len(example['targets']) 356 | num_to_pad = MAX_TARGET[data_args.task] - num_targets 357 | pad_spans = [[-1, -1]] * num_to_pad 358 | pad_labels = [-1] * num_to_pad 359 | 360 | result['span1s'] = [convert_span(result, pre_tokenized_str, target['span1']) for target in example['targets']] 361 | result['span1s'].extend(pad_spans) 362 | result['labels'] = [label2id[target['label']] for target in example['targets']] 363 | result['labels'].extend(pad_labels) 364 | if not config.unary: 365 | result['span2s'] = [convert_span(result, pre_tokenized_str, target['span2']) for target in example['targets']] 366 | result['span2s'].extend(pad_spans) 367 | return result 368 | 369 | with training_args.main_process_first(desc="dataset map tokenization"): 370 | tokenized_datasets = raw_datasets.map( 371 | tokenize_function, 372 | batched=False, 373 | num_proc=data_args.preprocessing_num_workers, 374 | remove_columns=column_names, 375 | load_from_cache_file=not data_args.overwrite_cache, 376 | desc="Running tokenizer on dataset", 377 | ) 378 | 379 | if training_args.do_train: 380 | if "train" not in tokenized_datasets: 381 | raise ValueError("--do_train requires a train dataset") 382 | train_dataset = tokenized_datasets["train"] 383 | if data_args.max_train_samples is not None: 384 | train_dataset = train_dataset.select(random.sample(range(len(train_dataset)), data_args.max_train_samples)) 385 | total = 0 386 | for example in 
train_dataset: 387 | for label in example['labels']: 388 | if label != -1: 389 | total += 1 390 | logger.info("Total number of samples: {}".format(total)) 391 | 392 | if training_args.do_eval: 393 | if "validation" not in tokenized_datasets: 394 | raise ValueError("--do_eval requires a validation dataset") 395 | eval_dataset = tokenized_datasets["validation"] 396 | if data_args.max_eval_samples is not None: 397 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 398 | 399 | 400 | if training_args.do_train: 401 | # Optimizer 402 | no_decay = ["bias", "LayerNorm.weight"] 403 | optimizer_grouped_parameters = [ 404 | { 405 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 406 | "weight_decay": training_args.weight_decay, 407 | "lr": training_args.learning_rate 408 | }, 409 | { 410 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 411 | "weight_decay": 0.0, 412 | "lr": training_args.learning_rate 413 | }, 414 | ] 415 | 416 | if model_args.do_prune: 417 | model.apply_dsp(model_args.num_of_heads) 418 | for n, p in model.named_parameters(): 419 | if n == "gpt2.w": 420 | p.requires_grad = True 421 | optimizer_grouped_parameters.append( 422 | { 423 | "params": [p for n, p in model.named_parameters() if n == "gpt2.w"], 424 | "lr": model_args.pruning_lr, 425 | } 426 | ) 427 | 428 | optimizer = AdamW(optimizer_grouped_parameters) 429 | else: 430 | optimizer = None 431 | 432 | def compute_metrics(eval_pred): 433 | accuracy, _ = eval_pred 434 | accuracy = accuracy.sum(axis=0) 435 | accuracy = accuracy[0] / accuracy[1] 436 | return {"accuracy": accuracy} 437 | 438 | # Initialize our Trainer 439 | trainer = Trainer( 440 | model=model, 441 | args=training_args, 442 | train_dataset=train_dataset if training_args.do_train else None, 443 | eval_dataset=eval_dataset if training_args.do_eval else None, 444 | tokenizer=tokenizer, 445 | # Data collator will default to DataCollatorWithPadding, so we change it. 
446 | data_collator=default_data_collator, 447 | optimizers=(optimizer, None), 448 | compute_metrics=compute_metrics, 449 | ) 450 | 451 | # Training 452 | if training_args.do_train: 453 | checkpoint = None 454 | if training_args.resume_from_checkpoint is not None: 455 | checkpoint = training_args.resume_from_checkpoint 456 | elif last_checkpoint is not None: 457 | checkpoint = last_checkpoint 458 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 459 | # trainer.save_model() # Saves the tokenizer too for easy upload 460 | 461 | metrics = train_result.metrics 462 | 463 | max_train_samples = ( 464 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 465 | ) 466 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 467 | 468 | trainer.log_metrics("train", metrics) 469 | trainer.save_metrics("train", metrics) 470 | trainer.save_state() 471 | 472 | if model_args.do_prune: 473 | head_mask = convert_gate_to_mask(model.w, model_args.num_of_heads) 474 | model.apply_masks(head_mask) 475 | model.use_dsp = False 476 | logger.info("Number of heads: {}".format(head_mask.sum())) 477 | logger.info(f'Number of heads in each layer: {head_mask.sum(-1)}') 478 | if training_args.output_dir is not None: 479 | torch.save(head_mask, os.path.join(training_args.output_dir, "mask" + str(model_args.num_of_heads) + ".pt")) 480 | 481 | # Evaluation 482 | if training_args.do_eval: 483 | logger.info("*** Evaluate ***") 484 | logger.info(f'Layer weights: {torch.stack([p for n, p in model.scalar_mix.named_parameters() if "scalar" in n]).flatten()}') 485 | 486 | metrics = trainer.evaluate() 487 | 488 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 489 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 490 | 491 | trainer.log_metrics("eval", metrics) 492 | trainer.save_metrics("eval", metrics) 493 | 494 | 495 | def _mp_fn(index): 496 | # For xla_spawn (TPUs) 497 | main() 498 | 499 | 500 | if __name__ == "__main__": 501 | main() 502 | -------------------------------------------------------------------------------- /run_pp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. 18 | 19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 20 | https://huggingface.co/models?filter=causal-lm 21 | """ 22 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
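# As in run_dp.py, GPT-2 itself stays frozen here; what is trained is a continuous
# prompt (GPT2ForProbingViaPrompting): `prefix_len` virtual tokens are embedded,
# passed through a small MLP, and reshaped into per-layer key/value pairs that are
# handed to GPT-2 as `past_key_values`. A shape-only sketch of that reshaping for
# gpt2-small (n_layer=12, n_head=12, n_embd=768, hence head_dim=64), mirroring
# `get_prefix` in modeling_gpt2_pp.py:

import torch

bsz, prefix_len, n_layer, n_head, n_embd = 2, 200, 12, 12, 768
head_dim = n_embd // n_head
flat = torch.randn(bsz, prefix_len, n_layer * 2 * n_embd)        # prefix MLP output
stacked = flat.view(bsz, prefix_len, n_layer * 2, n_head, head_dim)
past = stacked.permute([2, 0, 3, 1, 4]).split(2)                  # one chunk per layer
assert len(past) == n_layer
assert past[0].shape == (2, bsz, n_head, prefix_len, head_dim)    # key and value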
23 | 24 | import logging 25 | import math 26 | import os 27 | import sys 28 | from dataclasses import dataclass, field 29 | from typing import Optional 30 | import random 31 | 32 | import datasets 33 | from datasets import load_dataset 34 | 35 | import torch 36 | import transformers 37 | from transformers import ( 38 | AdamW, 39 | CONFIG_MAPPING, 40 | MODEL_FOR_CAUSAL_LM_MAPPING, 41 | AutoConfig, 42 | AutoTokenizer, 43 | HfArgumentParser, 44 | TrainingArguments, 45 | default_data_collator, 46 | set_seed, 47 | ) 48 | from transformers.trainer_utils import get_last_checkpoint 49 | from transformers.utils.versions import require_version 50 | 51 | from modeling_gated_gpt2 import GatedGPT2LMHeadModel 52 | from modeling_gpt2_pp import GPT2ForProbingViaPrompting 53 | from trainer_pp import PPTrainer 54 | 55 | from utils import LABEL_DICT, convert_gate_to_mask 56 | 57 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 58 | # check_min_version("4.13.0.dev0") 59 | 60 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") 61 | 62 | logger = logging.getLogger(__name__) 63 | 64 | 65 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 66 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 67 | 68 | MAX_LENGTH = {'pos': 360, 'const': 700, 'ner': 360, 'coref': 340, 'srl': 685} 69 | 70 | @dataclass 71 | class ModelArguments: 72 | """ 73 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 74 | """ 75 | 76 | gpt2_name_or_path: Optional[str] = field( 77 | default=None, 78 | metadata={ 79 | "help": "The model checkpoint for weights initialization." 80 | "Don't set if you want to train a model from scratch." 81 | }, 82 | ) 83 | prefix_model_path: Optional[str] = field( 84 | default=None, 85 | metadata={ 86 | "help": "Path to trained prefix model." 87 | "Don't set if you want to train a model from scratch." 88 | }, 89 | ) 90 | model_type: Optional[str] = field( 91 | default=None, 92 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 93 | ) 94 | config_overrides: Optional[str] = field( 95 | default=None, 96 | metadata={ 97 | "help": "Override some existing default config settings when a model is trained from scratch. Example: " 98 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 99 | }, 100 | ) 101 | config_name: Optional[str] = field( 102 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 103 | ) 104 | tokenizer_name: Optional[str] = field( 105 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 106 | ) 107 | cache_dir: Optional[str] = field( 108 | default='cache/', 109 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 110 | ) 111 | use_fast_tokenizer: bool = field( 112 | default=True, 113 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 114 | ) 115 | model_revision: str = field( 116 | default="main", 117 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 118 | ) 119 | use_auth_token: bool = field( 120 | default=False, 121 | metadata={ 122 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 123 | "with private models)." 
124 | }, 125 | ) 126 | prefix_len: Optional[int] = field( 127 | default=200, 128 | metadata={"help": "Length of the prefix."}, 129 | ) 130 | prefix_dim: Optional[int] = field( 131 | default=512, 132 | metadata={"help": "Dimension of hidden states of prefix model."}, 133 | ) 134 | prefix_drop: Optional[float] = field( 135 | default=0.0, 136 | metadata={"help": "Droput rate for the prefix model."}, 137 | ) 138 | num_of_heads: Optional[int] = field( 139 | default=96, 140 | metadata={"help": "Number of heads left unpruned."}, 141 | ) 142 | pruning_lr: Optional[float] = field( 143 | default=0.1, 144 | metadata={"help": "Learning rate for head importance variables."}, 145 | ) 146 | do_prune: Optional[bool] = field( 147 | default=False, 148 | metadata={"help": "Whether heads are pruned."}, 149 | ) 150 | head_mask_path: Optional[str] = field( 151 | default='None', 152 | metadata={ 153 | "help": "Where head mask is stored" 154 | }, 155 | ) 156 | toggle_mask: Optional[bool] = field( 157 | default=True, 158 | metadata={"help": "Whether heads are pruned."}, 159 | ) 160 | 161 | 162 | @dataclass 163 | class DataTrainingArguments: 164 | """ 165 | Arguments pertaining to what data we are going to input our model for training and eval. 166 | """ 167 | 168 | data_dir: Optional[str] = field( 169 | default=None, metadata={"help": "Where data is stored"} 170 | ) 171 | task: Optional[str] = field( 172 | default='ner', 173 | metadata={"help": "Tasks, one or more of {pos, const, coref, ner, srl}."}, 174 | ) 175 | max_train_samples: Optional[int] = field( 176 | default=None, 177 | metadata={ 178 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 179 | "value if set." 180 | }, 181 | ) 182 | max_eval_samples: Optional[int] = field( 183 | default=None, 184 | metadata={ 185 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 186 | "value if set." 187 | }, 188 | ) 189 | overwrite_cache: bool = field( 190 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 191 | ) 192 | preprocessing_num_workers: Optional[int] = field( 193 | default=None, 194 | metadata={"help": "The number of processes to use for the preprocessing."}, 195 | ) 196 | 197 | def main(): 198 | # See all possible arguments in src/transformers/training_args.py 199 | # or by passing the --help flag to this script. 200 | # We now keep distinct sets of args, for a cleaner separation of concerns. 201 | 202 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 203 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 204 | # If we pass only one argument to the script and it's the path to a json file, 205 | # let's parse it to get our arguments. 
206 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 207 | else: 208 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 209 | 210 | # Setup logging 211 | logging.basicConfig( 212 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 213 | datefmt="%m/%d/%Y %H:%M:%S", 214 | handlers=[logging.StreamHandler(sys.stdout)], 215 | ) 216 | 217 | log_level = training_args.get_process_log_level() 218 | logger.setLevel(log_level) 219 | datasets.utils.logging.set_verbosity(log_level) 220 | transformers.utils.logging.set_verbosity(log_level) 221 | transformers.utils.logging.enable_default_handler() 222 | transformers.utils.logging.enable_explicit_format() 223 | 224 | # Log on each process the small summary: 225 | logger.warning( 226 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 227 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 228 | ) 229 | logger.info(f"Training/evaluation parameters {training_args}") 230 | 231 | # Detecting last checkpoint. 232 | last_checkpoint = None 233 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 234 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 235 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 236 | raise ValueError( 237 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 238 | "Use --overwrite_output_dir to overcome." 239 | ) 240 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 241 | logger.info( 242 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 243 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 244 | ) 245 | 246 | # Set seed before initializing model. 247 | set_seed(training_args.seed) 248 | 249 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 250 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 251 | # (the dataset will be downloaded automatically from the datasets Hub). 252 | # 253 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 254 | # 'text' is found. You can easily tweak this behavior (see below). 255 | # 256 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 257 | # download the dataset. 258 | 259 | data_files = {} 260 | dataset_args = {} 261 | logger.info("Loading data for {}".format(data_args.task)) 262 | if training_args.do_train: 263 | data_files["train"] = os.path.join(data_args.data_dir, data_args.task, 'train.json') 264 | data_files["validation"] = os.path.join(data_args.data_dir, data_args.task, 'test.json') 265 | raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args) 266 | if "_control" in data_args.task: 267 | data_args.task = data_args.task.replace("_control", "") 268 | 269 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 270 | # https://huggingface.co/docs/datasets/loading_datasets.html. 
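# The prompting probe turns classification into next-token prediction: each label
# string of the task is added below as an additional special token, so after
# `resize_token_embeddings` the label tokens sit at the last `config.num_labels` ids
# of the vocabulary. The accuracy computation in modeling_gated_gpt2.py depends on
# exactly that layout: it converts a gold label token id into a 0-based class index
# and takes the argmax over the last `num_labels` logits only. A small numeric check
# with illustrative sizes (18 labels, as for OntoNotes NER):

vocab_size = 50257 + 1 + 18                 # GPT-2 vocab + sep token + 18 label tokens
n_labels = 18
gold_token_id = vocab_size - 5              # some label token near the end of the vocab
gold_class = gold_token_id - vocab_size + n_labels   # -> 13, a valid 0-based class index
assert 0 <= gold_class < n_labels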
271 | 272 | # Load pretrained model and tokenizer 273 | # 274 | # Distributed training: 275 | # The .from_pretrained methods guarantee that only one local process can concurrently 276 | # download model & vocab. 277 | 278 | config_kwargs = { 279 | "cache_dir": model_args.cache_dir, 280 | "revision": model_args.model_revision, 281 | "use_auth_token": True if model_args.use_auth_token else None, 282 | } 283 | if model_args.config_name: 284 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 285 | elif model_args.gpt2_name_or_path: 286 | config = AutoConfig.from_pretrained(model_args.gpt2_name_or_path, **config_kwargs) 287 | else: 288 | config = CONFIG_MAPPING[model_args.model_type]() 289 | logger.warning("You are instantiating a new config instance from scratch.") 290 | if model_args.config_overrides is not None: 291 | logger.info(f"Overriding config: {model_args.config_overrides}") 292 | config.update_from_string(model_args.config_overrides) 293 | 294 | tokenizer_kwargs = { 295 | "cache_dir": model_args.cache_dir, 296 | "use_fast": model_args.use_fast_tokenizer, 297 | "revision": model_args.model_revision, 298 | "use_auth_token": True if model_args.use_auth_token else None, 299 | } 300 | if model_args.tokenizer_name: 301 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 302 | elif model_args.gpt2_name_or_path: 303 | tokenizer = AutoTokenizer.from_pretrained(model_args.gpt2_name_or_path, **tokenizer_kwargs) 304 | else: 305 | raise ValueError( 306 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 307 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 308 | ) 309 | 310 | special_tokens_dict = { 311 | 'sep_token': '', 312 | 'additional_special_tokens': list(LABEL_DICT[data_args.task].values()) 313 | } 314 | num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 315 | config.num_labels = num_added_toks - 1 316 | if model_args.gpt2_name_or_path: 317 | gpt2 = GatedGPT2LMHeadModel.from_pretrained( 318 | model_args.gpt2_name_or_path, 319 | from_tf=bool(".ckpt" in model_args.gpt2_name_or_path), 320 | config=config, 321 | cache_dir=model_args.cache_dir, 322 | revision=model_args.model_revision, 323 | use_auth_token=True if model_args.use_auth_token else None, 324 | ) 325 | else: 326 | gpt2 = GatedGPT2LMHeadModel(config) 327 | n_params = sum(dict((p.data_ptr(), p.numel()) for p in gpt2.parameters()).values()) 328 | logger.info(f"Training new gpt2 from scratch - Total size={n_params/2**20:.2f}M params") 329 | 330 | gpt2.resize_token_embeddings(len(tokenizer)) 331 | gpt2.eval_acc = True 332 | 333 | if model_args.prefix_model_path: 334 | config = AutoConfig.from_pretrained(model_args.prefix_model_path, cache_dir=model_args.cache_dir) 335 | model = GPT2ForProbingViaPrompting.from_pretrained(model_args.prefix_model_path, config=config, gpt2=gpt2) 336 | else: 337 | config.prefix_len = model_args.prefix_len 338 | config.prefix_dim = model_args.prefix_dim 339 | config.prefix_drop = model_args.prefix_drop 340 | model = GPT2ForProbingViaPrompting(config, gpt2) 341 | 342 | # Preprocessing the datasets. 343 | # First we tokenize all the texts. 
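# For probing via prompting, the loss should only cover the part of each sequence
# after the separator: `tokenize_function` below copies input_ids into labels, sets
# padded positions (attention_mask == 0) to -100, and then also overwrites everything
# up to and including the first eos token with -100. A toy illustration with made-up
# token ids (GPT-2's eos id 50256, which this script uses both as the padding token
# and as the boundary it searches for):

toy_ids = [11, 22, 33, 50256, 44, 55, 50256, 50256]   # text, sep, answer, padding
toy_mask = [1, 1, 1, 1, 1, 1, 0, 0]
toy_labels = [-100 if m == 0 else t for m, t in zip(toy_mask, toy_ids)]
sep_idx = toy_labels.index(50256) + 1                 # position right after the separator
toy_labels[:sep_idx] = [-100] * sep_idx
assert toy_labels == [-100, -100, -100, -100, 44, 55, -100, -100]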
344 | if training_args.do_train: 345 | column_names = raw_datasets["train"].column_names 346 | else: 347 | column_names = raw_datasets["validation"].column_names 348 | text_column_name = "text" if "text" in column_names else column_names[0] 349 | 350 | # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function 351 | tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") 352 | 353 | tokenizer.pad_token = tokenizer.eos_token 354 | 355 | def tokenize_function(examples): 356 | result = tokenizer(examples[text_column_name], padding="max_length", max_length=MAX_LENGTH[data_args.task]) 357 | examples = {} 358 | examples['input_ids'] = result['input_ids'] 359 | examples["labels"] = [ 360 | [-100 if mask == 0 else token for mask, token in mask_and_tokens] for mask_and_tokens in [zip(masks, labels) for masks, labels in zip(result["attention_mask"], result["input_ids"])] 361 | ] 362 | for i, elem in enumerate(examples['labels']): 363 | sep_idx = elem.index(tokenizer.eos_token_id) + 1 364 | examples['labels'][i][:sep_idx] = [-100] * sep_idx 365 | return examples 366 | 367 | with training_args.main_process_first(desc="dataset map tokenization"): 368 | tokenized_datasets = raw_datasets.map( 369 | tokenize_function, 370 | batched=True, 371 | num_proc=data_args.preprocessing_num_workers, 372 | remove_columns=column_names, 373 | load_from_cache_file=not data_args.overwrite_cache, 374 | desc="Running tokenizer on dataset", 375 | ) 376 | 377 | if training_args.do_train: 378 | if "train" not in tokenized_datasets: 379 | raise ValueError("--do_train requires a train dataset") 380 | train_dataset = tokenized_datasets["train"] 381 | if data_args.max_train_samples is not None: 382 | train_dataset = train_dataset.select(random.sample(range(len(train_dataset)), data_args.max_train_samples)) 383 | 384 | if training_args.do_eval: 385 | if "validation" not in tokenized_datasets: 386 | raise ValueError("--do_eval requires a validation dataset") 387 | eval_dataset = tokenized_datasets["validation"] 388 | if data_args.max_eval_samples is not None: 389 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 390 | if model_args.head_mask_path != 'None': 391 | head_mask = torch.load(model_args.head_mask_path) 392 | if model_args.toggle_mask: 393 | head_mask = (head_mask == 0).float() 394 | model.gpt2.apply_masks(head_mask) 395 | 396 | 397 | if training_args.do_train: 398 | # Optimizer 399 | no_decay = ["bias", "LayerNorm.weight"] 400 | optimizer_grouped_parameters = [ 401 | { 402 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 403 | "weight_decay": training_args.weight_decay, 404 | "lr": training_args.learning_rate 405 | }, 406 | { 407 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 408 | "weight_decay": 0.0, 409 | "lr": training_args.learning_rate 410 | }, 411 | ] 412 | 413 | if model_args.do_prune: 414 | model.gpt2.apply_dsp(model_args.num_of_heads) 415 | for n, p in model.named_parameters(): 416 | if n == "gpt2.w": 417 | p.requires_grad = True 418 | optimizer_grouped_parameters.append( 419 | { 420 | "params": [p for n, p in model.named_parameters() if n == "gpt2.w"], 421 | "lr": model_args.pruning_lr, 422 | } 423 | ) 424 | 425 | optimizer = AdamW(optimizer_grouped_parameters) 426 | else: 427 | optimizer = None 428 | 429 | # Initialize our Trainer 430 | trainer = PPTrainer( 431 | model=model, 
432 | args=training_args, 433 | train_dataset=train_dataset if training_args.do_train else None, 434 | eval_dataset=eval_dataset if training_args.do_eval else None, 435 | tokenizer=tokenizer, 436 | # Data collator will default to DataCollatorWithPadding, so we change it. 437 | data_collator=default_data_collator, 438 | optimizers=(optimizer, None), 439 | ) 440 | 441 | # Training 442 | if training_args.do_train: 443 | checkpoint = None 444 | if training_args.resume_from_checkpoint is not None: 445 | checkpoint = training_args.resume_from_checkpoint 446 | elif last_checkpoint is not None: 447 | checkpoint = last_checkpoint 448 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 449 | trainer.save_model() # Saves the tokenizer too for easy upload 450 | 451 | metrics = train_result.metrics 452 | 453 | max_train_samples = ( 454 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 455 | ) 456 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 457 | 458 | trainer.log_metrics("train", metrics) 459 | trainer.save_metrics("train", metrics) 460 | trainer.save_state() 461 | 462 | if model_args.do_prune: 463 | head_mask = convert_gate_to_mask(model.gpt2.w, model_args.num_of_heads) 464 | model.gpt2.apply_masks(head_mask) 465 | model.gpt2.use_dsp = False 466 | logger.info("Number of heads: {}".format(head_mask.sum())) 467 | logger.info(f'Number of heads in each layer: {head_mask.sum(-1)}') 468 | if training_args.output_dir is not None: 469 | torch.save(head_mask, os.path.join(training_args.output_dir, "mask" + str(model_args.num_of_heads) + ".pt")) 470 | 471 | # Evaluation 472 | if training_args.do_eval: 473 | logger.info("*** Evaluate ***") 474 | 475 | if model_args.head_mask_path != 'None': 476 | logger.info("Number of heads: {}".format(head_mask.sum())) 477 | 478 | metrics = trainer.evaluate() 479 | 480 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 481 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 482 | try: 483 | perplexity = math.exp(metrics["eval_loss"]) 484 | except OverflowError: 485 | perplexity = float("inf") 486 | metrics["perplexity"] = perplexity 487 | 488 | trainer.log_metrics("eval", metrics) 489 | trainer.save_metrics("eval", metrics) 490 | 491 | 492 | def _mp_fn(index): 493 | # For xla_spawn (TPUs) 494 | main() 495 | 496 | 497 | if __name__ == "__main__": 498 | main() 499 | -------------------------------------------------------------------------------- /trainer_pp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. 
17 | """ 18 | 19 | import collections 20 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 21 | 22 | 23 | 24 | import numpy as np 25 | import torch 26 | from packaging import version 27 | from torch import nn 28 | from torch.utils.data import DataLoader, Dataset, IterableDataset 29 | from transformers.data.data_collator import DataCollator 30 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 31 | from transformers.trainer_callback import TrainerCallback 32 | from transformers.training_args import TrainingArguments 33 | from transformers.utils import logging 34 | from transformers.modeling_utils import PreTrainedModel 35 | from transformers.trainer_pt_utils import ( 36 | IterableDatasetShard, 37 | find_batch_size, 38 | nested_concat, 39 | nested_detach, 40 | nested_numpify, 41 | nested_truncate, 42 | ) 43 | from transformers.trainer_utils import ( 44 | EvalLoopOutput, 45 | EvalPrediction, 46 | denumpify_detensorize, 47 | ) 48 | from transformers import Trainer 49 | 50 | if version.parse(torch.__version__) >= version.parse("1.6"): 51 | from torch.cuda.amp import autocast 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | 56 | class PPTrainer(Trainer): 57 | def __init__( 58 | self, 59 | model: Union[PreTrainedModel, nn.Module] = None, 60 | args: TrainingArguments = None, 61 | data_collator: Optional[DataCollator] = None, 62 | train_dataset: Optional[Dataset] = None, 63 | eval_dataset: Optional[Dataset] = None, 64 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 65 | model_init: Callable[[], PreTrainedModel] = None, 66 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 67 | callbacks: Optional[List[TrainerCallback]] = None, 68 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), 69 | ): 70 | super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers) 71 | 72 | def evaluation_loop( 73 | self, 74 | dataloader: DataLoader, 75 | description: str, 76 | prediction_loss_only: Optional[bool] = None, 77 | ignore_keys: Optional[List[str]] = None, 78 | metric_key_prefix: str = "eval", 79 | ) -> EvalLoopOutput: 80 | """ 81 | Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. 82 | 83 | Works both with or without labels. 84 | """ 85 | prediction_loss_only = ( 86 | prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only 87 | ) 88 | 89 | model = self._wrap_model(self.model, training=False) 90 | 91 | # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while 92 | # ``train`` is running, halve it first and then put on device 93 | if not self.is_in_train and self.args.fp16_full_eval: 94 | model = model.half().to(self.args.device) 95 | 96 | batch_size = dataloader.batch_size 97 | 98 | logger.info(f"***** Running {description} *****") 99 | if isinstance(dataloader.dataset, collections.abc.Sized): 100 | logger.info(f" Num examples = {self.num_examples(dataloader)}") 101 | else: 102 | logger.info(" Num examples: Unknown") 103 | logger.info(f" Batch size = {batch_size}") 104 | 105 | model.eval() 106 | 107 | self.callback_handler.eval_dataloader = dataloader 108 | # Do this before wrapping. 
109 | eval_dataset = dataloader.dataset 110 | 111 | if self.args.past_index >= 0: 112 | self._past = None 113 | 114 | # Initialize containers 115 | # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) 116 | losses_host = None 117 | accuracies_host = None 118 | preds_host = None 119 | labels_host = None 120 | # losses/preds/labels on CPU (final containers) 121 | all_losses = None 122 | all_accuracies = None 123 | all_preds = None 124 | all_labels = None 125 | # Will be useful when we have an iterable dataset so don't know its length. 126 | 127 | observed_num_examples = 0 128 | # Main evaluation loop 129 | for step, inputs in enumerate(dataloader): 130 | # Update the observed num examples 131 | observed_batch_size = find_batch_size(inputs) 132 | if observed_batch_size is not None: 133 | observed_num_examples += observed_batch_size 134 | # For batch samplers, batch_size is not known by the dataloader in advance. 135 | if batch_size is None: 136 | batch_size = observed_batch_size 137 | 138 | # Prediction step 139 | loss, accuracy, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) 140 | 141 | # Update containers on host 142 | if loss is not None: 143 | losses = self._nested_gather(loss.repeat(batch_size)) 144 | losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) 145 | if accuracy is not None: 146 | accuracies = self._nested_gather(accuracy) 147 | accuracies_host = accuracies if accuracies_host is None else torch.cat((accuracies_host, accuracies), dim=0) 148 | if logits is not None: 149 | logits = self._pad_across_processes(logits) 150 | logits = self._nested_gather(logits) 151 | preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) 152 | if labels is not None: 153 | labels = self._pad_across_processes(labels) 154 | labels = self._nested_gather(labels) 155 | labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) 156 | self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) 157 | 158 | # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 
159 | if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: 160 | if losses_host is not None: 161 | losses = nested_numpify(losses_host) 162 | all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) 163 | if accuracies_host is not None: 164 | accuracies = nested_numpify(accuracies_host) 165 | all_accuracies = accuracies if all_accuracies is None else np.concatenate((all_accuracies, accuracies), axis=0) 166 | if preds_host is not None: 167 | logits = nested_numpify(preds_host) 168 | all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) 169 | if labels_host is not None: 170 | labels = nested_numpify(labels_host) 171 | all_labels = ( 172 | labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) 173 | ) 174 | 175 | # Set back to None to begin a new accumulation 176 | losses_host, accuracies_host, preds_host, labels_host = None, None, None, None 177 | 178 | if self.args.past_index and hasattr(self, "_past"): 179 | # Clean the state at the end of the evaluation loop 180 | delattr(self, "_past") 181 | 182 | # Gather all remaining tensors and put them back on the CPU 183 | if losses_host is not None: 184 | losses = nested_numpify(losses_host) 185 | all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) 186 | if accuracies_host is not None: 187 | accuracies = nested_numpify(accuracies_host) 188 | all_accuracies = accuracies if all_accuracies is None else np.concatenate((all_accuracies, accuracies), axis=0) 189 | if preds_host is not None: 190 | logits = nested_numpify(preds_host) 191 | all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) 192 | if labels_host is not None: 193 | labels = nested_numpify(labels_host) 194 | all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) 195 | 196 | # Number of samples 197 | if not isinstance(eval_dataset, IterableDataset): 198 | num_samples = len(eval_dataset) 199 | # The instance check is weird and does not actually check for the type, but whether the dataset has the right 200 | # methods. Therefore we need to make sure it also has the attribute. 201 | elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"): 202 | num_samples = eval_dataset.num_examples 203 | else: 204 | num_samples = observed_num_examples 205 | 206 | # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of 207 | # samplers has been rounded to a multiple of batch_size, so we truncate. 208 | if all_losses is not None: 209 | all_losses = all_losses[:num_samples] 210 | if all_accuracies is not None: 211 | all_accuracies = all_accuracies[:num_samples] 212 | if all_preds is not None: 213 | all_preds = nested_truncate(all_preds, num_samples) 214 | if all_labels is not None: 215 | all_labels = nested_truncate(all_labels, num_samples) 216 | 217 | # Metrics! 
218 |         if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
219 |             metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
220 |         else:
221 |             metrics = {}
222 | 
223 |         # To be JSON-serializable, we need to remove numpy types or zero-d tensors
224 |         metrics = denumpify_detensorize(metrics)
225 | 
226 |         if all_losses is not None:
227 |             metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
228 |         if all_accuracies is not None:
229 |             metrics[f"{metric_key_prefix}_accuracy"] = all_accuracies.mean().item()
230 | 
231 |         # Prefix all keys with metric_key_prefix + '_'
232 |         for key in list(metrics.keys()):
233 |             if not key.startswith(f"{metric_key_prefix}_"):
234 |                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
235 | 
236 |         return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
237 | 
238 |     def prediction_step(
239 |         self,
240 |         model: nn.Module,
241 |         inputs: Dict[str, Union[torch.Tensor, Any]],
242 |         prediction_loss_only: bool,
243 |         ignore_keys: Optional[List[str]] = None,
244 |     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
245 |         """
246 |         Perform an evaluation step on :obj:`model` using :obj:`inputs`.
247 | 
248 |         Subclass and override to inject custom behavior.
249 | 
250 |         Args:
251 |             model (:obj:`nn.Module`):
252 |                 The model to evaluate.
253 |             inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
254 |                 The inputs and targets of the model.
255 | 
256 |                 The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
257 |                 argument :obj:`labels`. Check your model's documentation for all accepted arguments.
258 |             prediction_loss_only (:obj:`bool`):
259 |                 Whether or not to return the loss only.
260 |             ignore_keys (:obj:`List[str]`, `optional`):
261 |                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
262 |                 gathering predictions.
263 | 
264 |         Return:
265 |             Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
266 |                 A tuple with the loss, the accuracy, the logits and the labels (each being optional).
267 |         """
268 |         has_labels = all(inputs.get(k) is not None for k in self.label_names)
269 |         inputs = self._prepare_inputs(inputs)
270 |         if ignore_keys is None:
271 |             if hasattr(self.model, "config"):
272 |                 ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
273 |             else:
274 |                 ignore_keys = []
275 | 
276 |         # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
277 | if has_labels: 278 | labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) 279 | if len(labels) == 1: 280 | labels = labels[0] 281 | else: 282 | labels = None 283 | 284 | with torch.no_grad(): 285 | if has_labels: 286 | if self.use_amp: 287 | with autocast(): 288 | loss, outputs = self.compute_loss(model, inputs, return_outputs=True) 289 | else: 290 | loss, outputs = self.compute_loss(model, inputs, return_outputs=True) 291 | loss = loss.mean().detach() 292 | if isinstance(outputs, dict): 293 | accuracy = outputs["accuracy"] 294 | logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss", "accuracy"]) 295 | else: 296 | accuracy = outputs[1] 297 | logits = outputs[2:] 298 | else: 299 | loss = None 300 | accuracy = None 301 | if self.use_amp: 302 | with autocast(): 303 | outputs = model(**inputs) 304 | else: 305 | outputs = model(**inputs) 306 | if isinstance(outputs, dict): 307 | logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) 308 | else: 309 | logits = outputs 310 | # TODO: this needs to be fixed and made cleaner later. 311 | if self.args.past_index >= 0: 312 | self._past = outputs[self.args.past_index - 1] 313 | 314 | if prediction_loss_only: 315 | return (loss, accuracy, None, None) 316 | 317 | logits = nested_detach(logits) 318 | if len(logits) == 1: 319 | logits = logits[0] 320 | 321 | return (loss, accuracy, logits, labels) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | LABEL_DICT = {} 4 | LABEL_DICT['ner'] = ['CARDINAL', 'DATE', 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 5 | 'LAW', 'LOC', 'MONEY', 'NORP', 'ORDINAL', 'ORG', 'PERCENT', 'PERSON', 'PRODUCT', 6 | 'QUANTITY', 'TIME', 'WORK_OF_ART'] 7 | LABEL_DICT['pos'] = ['$', "''", ',', '-LRB-', '-RRB-', '.', ':', 'ADD', 'AFX', 8 | 'CC', 'CD', 'DT', 'EX', 'FW', 'HYPH', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 9 | 'NFP', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 10 | 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 11 | 'WDT', 'WP', 'WP$', 'WRB', '``'] 12 | LABEL_DICT['const'] = ['ADJP', 'ADVP', 'CONJP', 'EMBED', 'FRAG', 'INTJ', 'LST', 13 | 'META', 'NAC', 'NML', 'NP', 'NX', 'PP', 'PRN', 'PRT', 'QP', 'RRC', 'S', 'SBAR', 14 | 'SBARQ', 'SINV', 'SQ', 'TOP', 'UCP', 'VP', 'WHADJP', 'WHADVP', 'WHNP', 'WHPP', 15 | 'X'] 16 | LABEL_DICT['coref'] = ['False', 'True'] 17 | LABEL_DICT['srl'] = ['ARG0', 'ARG1', 'ARG2', 'ARG3', 'ARG4', 'ARG5', 'ARGA', 18 | 'ARGM-ADJ', 'ARGM-ADV', 'ARGM-CAU', 'ARGM-COM', 'ARGM-DIR', 'ARGM-DIS', 'ARGM-DSP', 19 | 'ARGM-EXT', 'ARGM-GOL', 'ARGM-LOC', 'ARGM-LVB', 'ARGM-MNR', 'ARGM-MOD', 'ARGM-NEG', 20 | 'ARGM-PNC', 'ARGM-PRD', 'ARGM-PRP', 'ARGM-PRR', 'ARGM-PRX', 'ARGM-REC', 'ARGM-TMP', 21 | 'C-ARG0', 'C-ARG1', 'C-ARG2', 'C-ARG3', 'C-ARG4', 'C-ARGM-ADJ', 'C-ARGM-ADV', 22 | 'C-ARGM-CAU', 'C-ARGM-COM', 'C-ARGM-DIR', 'C-ARGM-DIS', 'C-ARGM-DSP', 'C-ARGM-EXT', 23 | 'C-ARGM-LOC', 'C-ARGM-MNR', 'C-ARGM-MOD', 'C-ARGM-NEG', 'C-ARGM-PRP', 'C-ARGM-TMP', 24 | 'R-ARG0', 'R-ARG1', 'R-ARG2', 'R-ARG3', 'R-ARG4', 'R-ARG5', 'R-ARGM-ADV', 'R-ARGM-CAU', 25 | 'R-ARGM-COM', 'R-ARGM-DIR', 'R-ARGM-EXT', 'R-ARGM-GOL', 'R-ARGM-LOC', 'R-ARGM-MNR', 26 | 'R-ARGM-MOD', 'R-ARGM-PNC', 'R-ARGM-PRD', 'R-ARGM-PRP', 'R-ARGM-TMP'] 27 | for task in LABEL_DICT: 28 | LABEL_DICT[task] = {label: "label" + str(i) for i, label in enumerate(LABEL_DICT[task])} 29 | 30 | 31 | def convert_gate_to_mask(gates, 
num_of_heads=None):
32 |     if num_of_heads is not None:
33 |         head_mask = torch.zeros_like(gates)
34 |         current_heads_to_keep = gates.view(-1).sort(descending=True)[1]
35 |         current_heads_to_keep = current_heads_to_keep[:num_of_heads]
36 |         head_mask = head_mask.view(-1)
37 |         head_mask[current_heads_to_keep] = 1.0
38 |         head_mask = head_mask.view_as(gates)
39 |     else:
40 |         head_mask = (gates > 0.5).float()
41 |     return head_mask
42 | 
43 | class STEFunction(torch.autograd.Function):
44 |     @staticmethod
45 |     def forward(ctx, input, k):
46 |         threshold = input.sort(descending=True)[0][k]
47 |         return (input > threshold).float()
48 | 
49 |     @staticmethod
50 |     def backward(ctx, grad_output):
51 |         return grad_output, None
--------------------------------------------------------------------------------
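The two helpers above drive the head-pruning analysis (`--do_prune`): `convert_gate_to_mask` turns the learned continuous head gates into a hard 0/1 mask keeping the top `num_of_heads` heads, and `STEFunction` is a straight-through estimator that binarizes the gates in the forward pass while passing gradients through unchanged in the backward pass, so the gates stay trainable. The sketch below only illustrates their standalone behaviour; the 12×12 gate shape (GPT-2 small: 12 layers × 12 heads) and the flattened `STEFunction.apply` call are illustrative assumptions, not necessarily how `modeling_gated_gpt2.py` invokes them.
```
import torch

from utils import convert_gate_to_mask, STEFunction

# Hypothetical continuous gate values: one gate per head of GPT-2 small (12 layers x 12 heads).
gates = torch.rand(12, 12, requires_grad=True)

# Keep the 96 highest-valued gates, mirroring --num_of_heads 96 in the pruning runs.
head_mask = convert_gate_to_mask(gates, num_of_heads=96)
print(head_mask.sum())    # tensor(96.)
print(head_mask.sum(-1))  # heads kept per layer, as logged by run_pp.py

# Straight-through estimator: hard 0/1 values forward, identity gradient backward.
binary = STEFunction.apply(gates.view(-1), 96)
binary.sum().backward()
print(gates.grad.abs().sum() > 0)  # gradients reach the continuous gates
```
The thresholding in `STEFunction.forward` keeps only gates strictly above the (k+1)-th largest value, so ties can leave slightly fewer than `k` heads active during training; the mask saved after training comes from `convert_gate_to_mask`, which always keeps exactly `num_of_heads`.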