--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Source code for paper "Zeroth-Order Fine-Tuning of LLMs in Random Subspaces"
2 |
3 | This is the implementation for the paper [Zeroth-Order Fine-Tuning of LLMs in Random Subspaces](http://arxiv.org/abs/2410.08989).
4 |
5 | In this paper, we propose random Subspace Zeroth-order (SubZero) optimization to address the challenges posed by the high dimensionality of LLMs. We introduce a low-rank perturbation tailored to LLMs that significantly reduces memory consumption while improving training performance. We also apply SubZero to four popular fine-tuning schemes for LLMs: full-parameter tuning, LoRA, prefix tuning, and prompt tuning, demonstrating its compatibility and versatility across different tuning approaches.
6 |
7 | Furthermore, we prove that our gradient estimation closely approximates the backpropagation gradient, exhibits lower variance than traditional ZO methods, and ensures convergence when combined with SGD. Experimental results show that SubZero enhances fine-tuning performance and achieves faster convergence compared to standard ZO approaches like [MeZO](https://github.com/princeton-nlp/MeZO) across various language modeling tasks.
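
As a rough illustration of the core idea, the sketch below estimates the gradient of a single weight matrix with a two-sided ZO finite difference whose perturbation is restricted to a random low-rank subspace. This is a simplified, self-contained example written for this README, not the repository's actual trainer code; the function name and toy loss are illustrative, and the rank and subspace refresh period only loosely correspond to the `--gauss_rank` and `--update_interval` arguments used below.

```python
import torch

def lowrank_zo_grad(loss_fn, W, rank=8, eps=1e-2):
    """Two-sided ZO gradient estimate of loss_fn at W, restricted to a random
    low-rank subspace (illustrative sketch, not the repository's implementation)."""
    m, n = W.shape
    # Random column-orthogonal projections spanning the subspace. In SubZero these
    # are refreshed only periodically; here they are drawn once for simplicity.
    U, _ = torch.linalg.qr(torch.randn(m, rank))
    V, _ = torch.linalg.qr(torch.randn(n, rank))
    Z = torch.randn(rank, rank)      # fresh low-dimensional Gaussian perturbation
    dW = U @ Z @ V.T                 # full-size perturbation direction

    # Central (two-sided) finite difference of the loss along dW
    rho = (loss_fn(W + eps * dW) - loss_fn(W - eps * dW)) / (2 * eps)
    return rho * dW                  # gradient estimate usable by plain SGD

# Toy usage: one ZO-SGD step on a least-squares problem
W = torch.randn(64, 32)
X, Y = torch.randn(128, 32), torch.randn(128, 64)
loss_fn = lambda M: ((X @ M.T - Y) ** 2).mean()
W = W - 1e-3 * lowrank_zo_grad(loss_fn, W, rank=8, eps=1e-2)
```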
8 |
9 |
10 |
11 | ![SubZero visualization on OPT-1.3B](figure/subzero.png)
12 |
13 | Visualization of cosine similarity, relative variance, training loss, and GPU memory cost on OPT-1.3B under the prompt-tuning scheme. SubZero shows reduced angle error and variance in gradient estimation while accelerating convergence with minimal additional memory overhead.
14 |
15 |
16 |
17 | ## Getting Started
18 | - We use Python 3.10, PyTorch 2.1.0, Transformers 4.28.1, and CUDA 11.8.0.
19 | - Install the dependencies with `pip install -r requirements.txt`.
20 |
21 | ## Usage
22 |
23 | Use `run.py` for all functions (zero-shot/ICL/fine-tuning/MeZO/SubZero):
24 | ```bash
25 | python run.py {ARGUMENTS}
26 | ```
27 |
28 | Please read `run.py` for a complete list of arguments. We introduce some of the most important ones below.
29 | * `--num_train`: Number of training examples. For ICL, this is the number of demonstrations.
30 | * `--num_dev`: Number of validation examples.
31 | * `--num_eval`: Number of evaluation (test) examples.
32 | * `--model_name`: HuggingFace model name or path.
33 | * `--task_name`: Task name.
34 | * `--trainer`: can be `none` (zero-shot/ICL), `regular` (fine-tuning), `zo_sgd` (MeZO), or `subzero_sgd` (SubZero).
35 | * `--train_as_classification`: turn this on for classification tasks (cross-entropy over the likelihood of each class's label words); otherwise, LM-style teacher forcing is used.
36 | * `--zo_eps`: ZO hyperparameter epsilon (perturbation scale).
37 | * `--prefix_tuning`: use prefix-tuning.
38 | * `--lora`: use LoRA.
39 | * `--prompt_tuning`: use prompt-tuning.
40 |
41 | ## Reproducing Results
42 |
43 | We provide example commands for prompt tuning OPT-1.3B on the SST-2 dataset with MeZO-SGD, SubZero-SGD, and first-order SGD (FO-SGD).
44 |
45 | ### MeZO-SGD
46 | `CUDA_VISIBLE_DEVICES=0 python run.py --task_name=SST2 --model_name=facebook/opt-1.3b --output_dir=result/opt1.3b-SST2-prompt-mezo --num_train_epochs=5 --per_device_train_batch_size=16 --load_best_model_at_end --evaluation_strategy=steps --save_strategy=steps --save_total_limit=1 --eval_steps=1000 --max_steps=20000 --logging_steps=10 --num_eval=1000 --num_train=1000 --num_dev=500 --train_as_classification --perturbation_mode=two_side --trainer=zo_sgd --train_set_seed=0 --lr_scheduler_type=constant --eval_steps=500 --save_steps=500 --prompt_tuning --num_virtual_tokens=10 --prompt_init_by_real_tokens --learning_rate=1e-3 --zo_eps=1e-2 --weight_decay=0`
47 |
48 | ### SubZero-SGD
49 | `CUDA_VISIBLE_DEVICES=0 python run.py --task_name=SST2 --model_name=facebook/opt-1.3b --output_dir=result/opt1.3b-SST2-prompt-subzero --num_train_epochs=5 --per_device_train_batch_size=16 --load_best_model_at_end --evaluation_strategy=steps --save_strategy=steps --save_total_limit=1 --eval_steps=1000 --max_steps=20000 --logging_steps=10 --num_eval=1000 --num_train=1000 --num_dev=500 --train_as_classification --perturbation_mode=two_side --trainer=subzero_sgd --train_set_seed=0 --lr_scheduler_type=constant --eval_steps=500 --save_steps=500 --prompt_tuning --num_virtual_tokens=10 --prompt_init_by_real_tokens --learning_rate=1e-3 --zo_eps=1e-2 --weight_decay=0 --gauss_rank=24 --update_interval=1000`
50 |
51 | ### FO-SGD
52 | `CUDA_VISIBLE_DEVICES=0 python run.py --task_name=SST2 --model_name=facebook/opt-1.3b --output_dir=result/opt1.3b-SST2-prompt-sgd --num_train_epochs=5 --per_device_train_batch_size=16 --load_best_model_at_end --evaluation_strategy=steps --save_strategy=steps --save_total_limit=1 --eval_steps=1000 --max_steps=20000 --logging_steps=10 --num_eval=1000 --num_train=1000 --num_dev=500 --train_as_classification --perturbation_mode=two_side --trainer=sgd --optimizer=sgd --train_set_seed=0 --lr_scheduler_type=constant --eval_steps=500 --save_steps=500 --prompt_tuning --num_virtual_tokens=10 --prompt_init_by_real_tokens --learning_rate=1e-3 --zo_eps=1e-2 --weight_decay=0`
53 |
54 | ## Acknowledgment
55 |
56 | This project builds on [MeZO: Fine-Tuning Language Models with Just Forward Passes](https://github.com/princeton-nlp/MeZO) and [Revisiting Zeroth-Order Optimization for Memory-Efficient LLM Fine-Tuning: A Benchmark](https://github.com/ZO-Bench/ZO-LLM/tree/main). The original code from these projects is licensed under the [MIT License](https://github.com/princeton-nlp/MeZO/blob/main/LICENSE) and [this license](https://github.com/ZO-Bench/ZO-LLM/blob/main/LICENSE), respectively. We thank the authors for their great work and contributions.
57 |
--------------------------------------------------------------------------------
/figure/subzero.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zimingyy/SubZero/cf6effdd0aef2b1cea79b2a121dc2ad4e7037414/figure/subzero.png
--------------------------------------------------------------------------------
/large_models/lora.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
4 | logger = logging.getLogger(__name__)
5 | logger.setLevel(logging.INFO)
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 | import math
11 |
12 |
13 | def find_module(root_module: nn.Module, key: str):
14 | """
15 | Find a module with a specific name in a Transformer model
16 | From OpenDelta https://github.com/thunlp/OpenDelta
17 | """
18 | sub_keys = key.split(".")
19 | parent_module = root_module
20 | for sub_key in sub_keys[:-1]:
21 | parent_module = getattr(parent_module, sub_key)
22 | module = getattr(parent_module, sub_keys[-1])
23 | return parent_module, sub_keys[-1], module
24 |
25 |
26 | class LoRALinear(nn.Linear):
27 | """
28 | LoRA implemented in a dense layer
29 | From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
30 | """
31 |
32 | def __init__(
33 | self,
34 | in_features: int,
35 | out_features: int,
36 | r: int = 0,
37 | lora_alpha: int = 1,
38 | lora_dropout: float = 0.,
39 | fan_in_fan_out: bool = False,
40 | # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
41 | merge_weights: bool = False,
42 | # Not sure if this will affect saving/loading models so just set it to be False
43 | **kwargs
44 | ):
45 | nn.Linear.__init__(self, in_features, out_features, **kwargs)
46 |
47 | self.r = r
48 | self.lora_alpha = lora_alpha
49 | # Optional dropout
50 | if lora_dropout > 0.:
51 | self.lora_dropout = nn.Dropout(p=lora_dropout)
52 | else:
53 | self.lora_dropout = lambda x: x
54 | # Mark the weight as unmerged
55 | self.merged = False
56 | self.merge_weights = merge_weights
57 | self.fan_in_fan_out = fan_in_fan_out
58 | # Actual trainable parameters
59 | if r > 0:
60 | self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
61 | self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
62 | self.scaling = self.lora_alpha / self.r
63 | # Freezing the pre-trained weight matrix
64 | self.weight.requires_grad = False
65 | self.reset_parameters()
66 | if fan_in_fan_out:
67 | self.weight.data = self.weight.data.transpose(0, 1)
68 |
69 | def reset_parameters(self):
70 | nn.Linear.reset_parameters(self)
71 | if hasattr(self, 'lora_A'):
72 | # initialize A the same way as the default for nn.Linear and B to zero
73 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
74 | nn.init.zeros_(self.lora_B)
75 |
76 | def train(self, mode: bool = True):
77 | def T(w):
78 | return w.transpose(0, 1) if self.fan_in_fan_out else w
79 |
80 | nn.Linear.train(self, mode)
81 | if mode:
82 | if self.merge_weights and self.merged:
83 | # Make sure that the weights are not merged
84 | if self.r > 0:
85 | self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
86 | self.merged = False
87 | else:
88 | if self.merge_weights and not self.merged:
89 | # Merge the weights and mark it
90 | if self.r > 0:
91 | self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
92 | self.merged = True
93 |
94 | def forward(self, x: torch.Tensor):
95 | def T(w):
96 | return w.transpose(0, 1) if self.fan_in_fan_out else w
97 |
98 | if self.r > 0 and not self.merged:
99 | result = F.linear(x, T(self.weight), bias=self.bias)
100 | if self.r > 0:
101 |                 result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
102 |                            @ self.lora_B.transpose(0, 1)) * self.scaling
103 | return result
104 | else:
105 | return F.linear(x, T(self.weight), bias=self.bias)
106 |
107 |
108 | class LoRA:
109 |
110 | def __init__(self, model, r, alpha, float16):
111 | """
112 | Input:
113 | r, alpha: LoRA hyperparameters
114 | float16: Whether the model parameters are float16 or not
115 | """
116 |
117 | self.model = model
118 | self.hidden_dim = model.config.hidden_size
119 | self.float16 = float16
120 |
121 | if model.config.model_type == "opt":
122 | attention_name = "attn"
123 | elif model.config.model_type == "roberta":
124 | attention_name = "attention"
125 | elif model.config.model_type in ["llama", "mistral"]:
126 | attention_name = "self_attn"
127 | else:
128 | raise NotImplementedError
129 |
130 | # Insert LoRA
131 | for key, _ in model.named_modules():
132 | if key[-len(attention_name):] == attention_name:
133 | logger.info(f"Inject lora to: {key}")
134 | _, _, attn = find_module(model, key)
135 |
136 | if model.config.model_type == "opt":
137 | original_q_weight = attn.q_proj.weight.data
138 | original_q_bias = attn.q_proj.bias.data
139 | original_v_weight = attn.v_proj.weight.data
140 | original_v_bias = attn.v_proj.bias.data
141 | attn.q_proj = LoRALinear(model.config.hidden_size, model.config.hidden_size, r=r, lora_alpha=alpha,
142 | bias=model.config.enable_bias).to(original_q_weight.device)
143 | attn.v_proj = LoRALinear(model.config.hidden_size, model.config.hidden_size, r=r, lora_alpha=alpha,
144 | bias=model.config.enable_bias).to(original_v_weight.device)
145 | if float16:
146 | attn.q_proj.half()
147 | attn.v_proj.half()
148 | attn.q_proj.weight.data = original_q_weight
149 | attn.q_proj.bias.data = original_q_bias
150 | attn.v_proj.weight.data = original_v_weight
151 | attn.v_proj.bias.data = original_v_bias
152 | elif model.config.model_type == "llama":
153 | # in early version of transformers, llama attention bias is hard coded to False
154 | attention_bias = False if not hasattr(model.config, "attention_bias") else model.config.attention_bias
155 | original_q_weight = attn.q_proj.weight.data
156 | original_v_weight = attn.v_proj.weight.data
157 | original_q_bias = attn.q_proj.bias.data if attention_bias else None
158 | original_v_bias = attn.v_proj.bias.data if attention_bias else None
159 | attn.q_proj = LoRALinear(
160 | model.config.hidden_size,
161 | model.config.hidden_size,
162 | r=r, lora_alpha=alpha, bias=attention_bias
163 | ).to(original_q_weight.device)
164 | attn.v_proj = LoRALinear(
165 | model.config.hidden_size,
166 | model.config.hidden_size,
167 | r=r, lora_alpha=alpha, bias=attention_bias
168 | ).to(original_v_weight.device)
169 | if float16:
170 | attn.q_proj.half()
171 | attn.v_proj.half()
172 | attn.q_proj.weight.data = original_q_weight
173 | attn.v_proj.weight.data = original_v_weight
174 | if attention_bias:
175 | attn.q_proj.bias.data = original_q_bias
176 | attn.v_proj.bias.data = original_v_bias
177 | elif model.config.model_type == "mistral":
178 |                 # Mistral attention projections have no bias (hence bias=False below)
179 | config = model.config
180 | original_q_weight = attn.q_proj.weight.data
181 | original_v_weight = attn.v_proj.weight.data
182 | head_dim = config.hidden_size // config.num_attention_heads
183 | attn.q_proj = LoRALinear(
184 | config.hidden_size,
185 | config.hidden_size,
186 |                         r=r, lora_alpha=alpha, bias=False
187 | ).to(original_q_weight.device)
188 | attn.v_proj = LoRALinear(
189 | config.hidden_size,
190 | config.num_key_value_heads * head_dim,
191 |                         r=r, lora_alpha=alpha, bias=False
192 | ).to(original_v_weight.device)
193 | if float16:
194 | attn.q_proj.half()
195 | attn.v_proj.half()
196 | attn.q_proj.weight.data = original_q_weight
197 | attn.v_proj.weight.data = original_v_weight
198 | else:
199 | raise NotImplementedError
200 |
201 | # Freeze non-LoRA parameters
202 | for n, p in model.named_parameters():
203 | if "lora" not in n:
204 | p.requires_grad = False
205 |
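# --- Illustrative usage sketch (mirrors how run.py applies this class) ---
# Wrap a causal LM so that only the injected lora_A / lora_B matrices of
# q_proj and v_proj remain trainable; everything else is frozen.
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#   LoRA(model, r=8, alpha=16, float16=False)   # mirrors --lora_r / --lora_alpha
#   trainable = [n for n, p in model.named_parameters() if p.requires_grad]
#   # `trainable` now only lists the lora_A / lora_B parameters.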
--------------------------------------------------------------------------------
/large_models/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import collections
3 | import re
4 | import string
5 | from collections import Counter
6 |
7 | def normalize_answer(s):
8 | """Lower text and remove punctuation, articles and extra whitespace."""
9 |
10 | def remove_articles(text):
11 | return re.sub(r'\b(a|an|the)\b', ' ', text)
12 |
13 | def white_space_fix(text):
14 | return ' '.join(text.split())
15 |
16 | def remove_punc(text):
17 | exclude = set(string.punctuation)
18 | return ''.join(ch for ch in text if ch not in exclude)
19 |
20 | def lower(text):
21 | return text.lower()
22 |
23 | return white_space_fix(remove_articles(remove_punc(lower(s))))
24 |
25 |
26 | def calculate_metric(predictions, metric_name):
27 | if metric_name == "accuracy":
28 | if isinstance(predictions[0].correct_candidate, list):
29 | return np.mean([pred.predicted_candidate in pred.correct_candidate for pred in predictions])
30 | else:
31 | return np.mean([pred.correct_candidate == pred.predicted_candidate for pred in predictions])
32 | elif metric_name == "em":
33 | # For question answering
34 | return np.mean([any([normalize_answer(ans) == normalize_answer(pred.predicted_candidate) for ans in pred.correct_candidate]) for pred in predictions])
35 | elif metric_name == "f1":
36 | # For question answering
37 | f1 = []
38 | for pred in predictions:
39 | all_f1s = []
40 | if pred.correct_candidate[0] == "CANNOTANSWER" or pred.correct_candidate[0] == "no answer":
41 | f1.append(int(normalize_answer(pred.correct_candidate[0]) == normalize_answer(pred.predicted_candidate)))
42 | else:
43 | for ans in pred.correct_candidate:
44 | prediction_tokens = normalize_answer(pred.predicted_candidate).split()
45 | ground_truth_tokens = normalize_answer(ans).split()
46 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
47 | num_same = sum(common.values())
48 | if num_same == 0:
49 | all_f1s.append(0)
50 | else:
51 | precision = 1.0 * num_same / len(prediction_tokens)
52 | recall = 1.0 * num_same / len(ground_truth_tokens)
53 | all_f1s.append((2 * precision * recall) / (precision + recall))
54 | f1.append(max(all_f1s))
55 |
56 | return np.mean(f1)
57 |
58 |
59 | def f1(pred, gold):
60 | """
61 | This separate F1 function is used as non-differentiable metric for SQuAD
62 | """
63 | if gold[0] == "CANNOTANSWER" or gold[0] == "no answer":
64 | return int(normalize_answer(gold[0]) == normalize_answer(pred))
65 | else:
66 | all_f1s = []
67 | for ans in gold:
68 | prediction_tokens = normalize_answer(pred).split()
69 | ground_truth_tokens = normalize_answer(ans).split()
70 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
71 | num_same = sum(common.values())
72 | if num_same == 0:
73 | all_f1s.append(0)
74 | else:
75 | precision = 1.0 * num_same / len(prediction_tokens)
76 | recall = 1.0 * num_same / len(ground_truth_tokens)
77 | all_f1s.append((2 * precision * recall) / (precision + recall))
78 | return np.max(all_f1s)
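
# Illustrative examples (a sketch of the expected behaviour of the helpers above):
#   normalize_answer("The Cat!")       -> "cat"   (lowercased, articles/punctuation removed)
#   f1("a cat sat", ["the cat sat"])   -> 1.0
#   f1("cat", ["the cat sat"])         -> 2 * (1.0 * 0.5) / (1.0 + 0.5) = 0.667 (approx.)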
--------------------------------------------------------------------------------
/large_models/modeling_mistral/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_mistral import MistralConfig, MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP
2 | from .modeling_mistral import (
3 | MistralModel,
4 | MistralForCausalLM,
5 | MistralForSequenceClassification,
6 | MistralPreTrainedModel,
7 | MistralForCausalLMWithHeadTuning
8 | )
9 |
--------------------------------------------------------------------------------
/large_models/modeling_mistral/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zimingyy/SubZero/cf6effdd0aef2b1cea79b2a121dc2ad4e7037414/large_models/modeling_mistral/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/large_models/modeling_mistral/__pycache__/configuration_mistral.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zimingyy/SubZero/cf6effdd0aef2b1cea79b2a121dc2ad4e7037414/large_models/modeling_mistral/__pycache__/configuration_mistral.cpython-310.pyc
--------------------------------------------------------------------------------
/large_models/modeling_mistral/__pycache__/modeling_mistral.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zimingyy/SubZero/cf6effdd0aef2b1cea79b2a121dc2ad4e7037414/large_models/modeling_mistral/__pycache__/modeling_mistral.cpython-310.pyc
--------------------------------------------------------------------------------
/large_models/modeling_mistral/configuration_mistral.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Mistral model configuration"""
16 |
17 | from transformers.configuration_utils import PretrainedConfig
18 | from transformers.utils import logging
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
26 | }
27 |
28 |
29 | class MistralConfig(PretrainedConfig):
30 | r"""
31 | This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
34 |
35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
37 |
38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39 | documentation from [`PretrainedConfig`] for more information.
40 |
41 |
42 | Args:
43 | vocab_size (`int`, *optional*, defaults to 32000):
44 | Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
45 | `inputs_ids` passed when calling [`MistralModel`]
46 | hidden_size (`int`, *optional*, defaults to 4096):
47 | Dimension of the hidden representations.
48 | intermediate_size (`int`, *optional*, defaults to 14336):
49 | Dimension of the MLP representations.
50 | num_hidden_layers (`int`, *optional*, defaults to 32):
51 | Number of hidden layers in the Transformer encoder.
52 | num_attention_heads (`int`, *optional*, defaults to 32):
53 | Number of attention heads for each attention layer in the Transformer encoder.
54 | num_key_value_heads (`int`, *optional*, defaults to 8):
55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57 |             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59 | by meanpooling all the original heads within that group. For more details checkout [this
60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62 | The non-linear activation function (function or string) in the decoder.
63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
65 |             allows sequences of up to 4096*32 tokens.
66 | initializer_range (`float`, *optional*, defaults to 0.02):
67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69 | The epsilon used by the rms normalization layers.
70 | use_cache (`bool`, *optional*, defaults to `True`):
71 | Whether or not the model should return the last key/values attentions (not used by all models). Only
72 | relevant if `config.is_decoder=True`.
73 | pad_token_id (`int`, *optional*):
74 | The id of the padding token.
75 | bos_token_id (`int`, *optional*, defaults to 1):
76 | The id of the "beginning-of-sequence" token.
77 | eos_token_id (`int`, *optional*, defaults to 2):
78 | The id of the "end-of-sequence" token.
79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80 | Whether the model's input and output word embeddings should be tied.
81 | rope_theta (`float`, *optional*, defaults to 10000.0):
82 | The base period of the RoPE embeddings.
83 | sliding_window (`int`, *optional*, defaults to 4096):
84 | Sliding window attention window size. If not specified, will default to `4096`.
85 | attention_dropout (`float`, *optional*, defaults to 0.0):
86 | The dropout ratio for the attention probabilities.
87 |
88 | ```python
89 | >>> from transformers import MistralModel, MistralConfig
90 |
91 | >>> # Initializing a Mistral 7B style configuration
92 | >>> configuration = MistralConfig()
93 |
94 | >>> # Initializing a model from the Mistral 7B style configuration
95 | >>> model = MistralModel(configuration)
96 |
97 | >>> # Accessing the model configuration
98 | >>> configuration = model.config
99 | ```"""
100 |
101 | model_type = "mistral"
102 | keys_to_ignore_at_inference = ["past_key_values"]
103 |
104 | def __init__(
105 | self,
106 | vocab_size=32000,
107 | hidden_size=4096,
108 | intermediate_size=14336,
109 | num_hidden_layers=32,
110 | num_attention_heads=32,
111 | num_key_value_heads=8,
112 | hidden_act="silu",
113 | max_position_embeddings=4096 * 32,
114 | initializer_range=0.02,
115 | rms_norm_eps=1e-6,
116 | use_cache=True,
117 | pad_token_id=None,
118 | bos_token_id=1,
119 | eos_token_id=2,
120 | tie_word_embeddings=False,
121 | rope_theta=10000.0,
122 | sliding_window=4096,
123 | attention_dropout=0.0,
124 | **kwargs,
125 | ):
126 | self.vocab_size = vocab_size
127 | self.max_position_embeddings = max_position_embeddings
128 | self.hidden_size = hidden_size
129 | self.intermediate_size = intermediate_size
130 | self.num_hidden_layers = num_hidden_layers
131 | self.num_attention_heads = num_attention_heads
132 | self.sliding_window = sliding_window
133 |
134 | # for backward compatibility
135 | if num_key_value_heads is None:
136 | num_key_value_heads = num_attention_heads
137 |
138 | self.num_key_value_heads = num_key_value_heads
139 | self.hidden_act = hidden_act
140 | self.initializer_range = initializer_range
141 | self.rms_norm_eps = rms_norm_eps
142 | self.use_cache = use_cache
143 | self.rope_theta = rope_theta
144 | self.attention_dropout = attention_dropout
145 |
146 | super().__init__(
147 | pad_token_id=pad_token_id,
148 | bos_token_id=bos_token_id,
149 | eos_token_id=eos_token_id,
150 | tie_word_embeddings=tie_word_embeddings,
151 | **kwargs,
152 | )
--------------------------------------------------------------------------------
/large_models/prefix_tuning.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
4 | logger = logging.getLogger(__name__)
5 | logger.setLevel(logging.INFO)
6 |
7 | import torch
8 | from torch import nn
9 |
10 |
11 | def find_module(root_module: nn.Module, key: str):
12 | """
13 | Find a module with a specific name in a Transformer model
14 | From OpenDelta https://github.com/thunlp/OpenDelta
15 | """
16 | sub_keys = key.split(".")
17 | parent_module = root_module
18 | for sub_key in sub_keys[:-1]:
19 | parent_module = getattr(parent_module, sub_key)
20 | module = getattr(parent_module, sub_keys[-1])
21 | return parent_module, sub_keys[-1], module
22 |
23 |
24 | def attn_forward_hook(self, *args, **kwargs):
25 | """
26 | Replace the original attention forward with this to enable prefix
27 | """
28 |
29 | def _expand_bsz(x, bsz):
30 | x = x.reshape(x.size(0), self.num_heads, -1).transpose(0,
31 | 1) # (num_prefix, hidden) -> (num_head, num_prefix, hidden/num_head)
32 | x = x.unsqueeze(0).expand(bsz, *x.shape) # -> (bsz, num_head, num_prefix, hidden/num_head)
33 | return x
34 |
35 | if "hidden_states" in kwargs:
36 | hidden_states = kwargs["hidden_states"]
37 | else:
38 | hidden_states = args[0]
39 | bsz = hidden_states.size(0)
40 |
41 | if 'past_key_value' not in kwargs or kwargs['past_key_value'] is None:
42 | if self.reparam:
43 | prefix_keys = self.prefix_mlp_keys(self.prefix_input_embeds)
44 | prefix_values = self.prefix_mlp_values(self.prefix_input_embeds)
45 | else:
46 | prefix_keys, prefix_values = self.prefix_keys, self.prefix_values
47 | kwargs['past_key_value'] = (_expand_bsz(prefix_keys, bsz), _expand_bsz(prefix_values, bsz))
48 |
49 | if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
50 | am = kwargs['attention_mask']
51 | kwargs['attention_mask'] = torch.cat(
52 |             [torch.zeros((*am.shape[:-1], self.num_prefix), dtype=am.dtype, device=am.device), am], dim=-1)
53 | elif len(args) > 1: # attention mask is passed via positional argument
54 | am = args[1]
55 |         am = torch.cat([torch.zeros((*am.shape[:-1], self.num_prefix), dtype=am.dtype, device=am.device), am],
56 | dim=-1)
57 | args = (args[0], am) + args[2:]
58 |
59 | return self.original_forward(*args, **kwargs)
60 |
61 |
62 | def prepare_inputs_for_generation(
63 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
64 | """
65 | Replace the original "prepare_inputs_for_generation" with this to pass prefix correctly
66 | """
67 | original_input_len = input_ids.size(-1)
68 | if past_key_values:
69 | input_ids = input_ids[:, -1:]
70 |
71 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
72 | if inputs_embeds is not None and past_key_values is None:
73 | model_inputs = {"inputs_embeds": inputs_embeds}
74 | else:
75 | model_inputs = {"input_ids": input_ids}
76 |
77 | if past_key_values is not None:
78 | # Check if we should add extra to attention mask
79 | if past_key_values[0][0].size(2) != attention_mask.size(1) - 1:
80 | num_prefix = past_key_values[0][0].size(2) - (attention_mask.size(1) - 1)
81 | attention_mask = torch.cat([torch.ones((attention_mask.size(0), num_prefix), dtype=attention_mask.dtype,
82 | device=attention_mask.device), attention_mask], dim=-1)
83 |
84 | model_inputs.update(
85 | {
86 | "past_key_values": past_key_values,
87 | "use_cache": kwargs.get("use_cache"),
88 | "attention_mask": attention_mask,
89 | }
90 | )
91 | return model_inputs
92 |
93 |
94 | class PrefixTuning:
95 |
96 | def __init__(self, model, num_prefix, reparam=True, embed_dim=512, mid_dim=512, float16=False,
97 | init_by_real_act=False):
98 | """
99 | Inputs:
100 | num_prefix: number of prefix tokens
101 | reparam: use reparameterization trick (not used in MeZO)
102 | embed_dim, mid_dim: hyperparameters for reparameterization trick (not used in MeZO)
103 |             float16: whether the model parameters are float16
104 | init_by_real_act: init prefix tokens by real activations
105 | """
106 |
107 | self.model = model
108 | self.num_prefix = num_prefix
109 | self.hidden_dim = model.config.hidden_size
110 | self.float16 = float16
111 |
112 | # Reparameterization
113 | self.reparam = reparam
114 | self.embed_dim = embed_dim
115 | self.mid_dim = mid_dim
116 |
117 | input_embeds = None # For reparameterization
118 | if model.config.model_type == "opt":
119 | attention_name = "attn"
120 | first_layer_name = "layers.0"
121 | layer_name = "layers."
122 | elif model.config.model_type == "roberta":
123 | attention_name = "attention"
124 | first_layer_name = "layer.0"
125 | layer_name = "layer."
126 | elif model.config.model_type in ["llama", "mistral"]:
127 | attention_name = "self_attn"
128 | first_layer_name = "layers.0"
129 | layer_name = "layers."
130 | else:
131 | raise NotImplementedError
132 |
133 | if init_by_real_act:
134 | # Initialize prefix with real words' activations
135 | assert not reparam
136 |
137 | # Randomly sample input tokens
138 | input_tokens = torch.randint(low=0, high=model.config.vocab_size, size=(1, num_prefix),
139 | dtype=torch.long).cuda()
140 | if model.config.model_type in ["opt", "llama", "mistral"]:
141 | with torch.no_grad():
142 | # Get the real activations
143 | real_key_values = model(input_ids=input_tokens, use_cache=True).past_key_values
144 | else:
145 | raise NotImplementedError
146 |
147 | # Insert prefix
148 | for key, _ in model.named_modules():
149 | if key[-len(attention_name):] == attention_name:
150 | layer_id = int(key.split(layer_name)[1].split(".")[0])
151 | logger.info(f"Inject prefix to: {key}")
152 | _, _, attn = find_module(model, key)
153 |
154 | # Replace the old forward functions
155 | attn.original_forward = attn.forward
156 | attn.forward = attn_forward_hook.__get__(attn, type(attn))
157 | if not hasattr(attn, "num_heads"):
158 | attn.num_heads = model.config.num_attention_heads
159 | first = first_layer_name in key
160 | self.add_prefix(attn, first=first, input_embeds=input_embeds)
161 |
162 | if first and self.reparam:
163 | input_embeds = attn.prefix_input_embeds
164 | if init_by_real_act:
165 | logger.info(f"Reinitialize with actual activation: {key} (layer {layer_id})")
166 | keys = real_key_values[layer_id][0].squeeze(0).transpose(0, 1).reshape(num_prefix, -1)
167 | values = real_key_values[layer_id][1].squeeze(0).transpose(0, 1).reshape(num_prefix, -1)
168 | attn.prefix_keys.data = keys.to(attn.prefix_keys.data.device)
169 | attn.prefix_values.data = values.to(attn.prefix_values.data.device)
170 |
171 | # Freeze non-prefix parameters
172 | for n, p in model.named_parameters():
173 | if "prefix" not in n:
174 | p.requires_grad = False
175 |
176 | # Replace the old prepare_inputs_for_generation function
177 | model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(model, type(model))
178 |
179 | def add_prefix(self, module, first, input_embeds=None):
180 | device = module.k_proj.weight.data.device
181 | module.num_prefix = self.num_prefix
182 | module.reparam = self.reparam
183 | if self.reparam:
184 | if first:
185 | # For the first layer we inject the embeddings
186 | logger.info("For prefix+reparameterization, inject the embeddings in the first layer.")
187 | module.prefix_input_embeds = nn.Parameter(
188 | torch.randn(self.num_prefix, self.embed_dim, device=device, dtype=self.model.dtype),
189 | requires_grad=True)
190 | else:
191 | assert input_embeds is not None
192 | module.prefix_input_embeds = input_embeds
193 | module.prefix_mlp_keys = nn.Sequential(
194 | nn.Linear(self.embed_dim, self.mid_dim),
195 | nn.Tanh(),
196 | nn.Linear(self.mid_dim, self.hidden_dim)
197 | ).to(device)
198 | module.prefix_mlp_values = nn.Sequential(
199 | nn.Linear(self.embed_dim, self.mid_dim),
200 | nn.Tanh(),
201 | nn.Linear(self.mid_dim, self.hidden_dim)
202 | ).to(device)
203 | if self.float16:
204 | module.prefix_mlp_keys = module.prefix_mlp_keys.half()
205 | module.prefix_mlp_values = module.prefix_mlp_values.half()
206 | else:
207 | module.prefix_keys = nn.Parameter(
208 | torch.randn(self.num_prefix, self.hidden_dim, device=device, dtype=self.model.dtype),
209 | requires_grad=True)
210 | module.prefix_values = nn.Parameter(
211 | torch.randn(self.num_prefix, self.hidden_dim, device=device, dtype=self.model.dtype),
212 | requires_grad=True)
213 |
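# --- Illustrative usage sketch (mirrors how run.py applies this class) ---
# Inject trainable prefix key/value vectors into every attention layer and
# freeze all other parameters.
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").cuda()
#   PrefixTuning(model, num_prefix=5, reparam=False,
#                float16=False, init_by_real_act=True)   # run.py defaults
#   # Only parameters whose names contain "prefix" remain trainable.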
--------------------------------------------------------------------------------
/large_models/prompt_tuning.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import partial
3 | from typing import Optional, Callable
4 |
5 | import torch
6 | from torch import nn
7 | from transformers import PreTrainedModel
8 |
9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10 | logger = logging.getLogger(__name__)
11 | logger.setLevel(logging.INFO)
12 |
13 |
14 | class PromptEmbedding(nn.Module):
15 | def __init__(
16 | self,
17 | num_virtual_tokens: int,
18 | token_dim: int,
19 | init_by_real_text: bool,
20 | word_embeddings: Optional[nn.Module] = None,
21 | vocab_size: Optional[int] = None,
22 | ):
23 | super().__init__()
24 | self.num_virtual_tokens = num_virtual_tokens
25 |
26 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
27 | if init_by_real_text:
28 | init_token_ids = torch.randint(
29 | low=0, high=vocab_size,
30 | size=(num_virtual_tokens,), dtype=torch.long
31 | ).to(word_embeddings.weight.device)
32 |
33 | word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
34 | word_embedding_weights = word_embedding_weights.to(torch.float32)
35 | self.embedding.weight = nn.Parameter(word_embedding_weights)
36 |
37 | def forward(self, indices):
38 | # Just get embeddings
39 | prompt_embeddings = self.embedding(indices)
40 | return prompt_embeddings
41 |
42 |
43 | def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int:
44 | if (input_ids is None) and (inputs_embeds is None):
45 | raise ValueError("You have to provide either input_ids or inputs_embeds")
46 |
47 | if input_ids is not None:
48 | batch_size = input_ids.shape[0]
49 | else:
50 | batch_size = inputs_embeds.shape[0]
51 | return batch_size
52 |
53 |
54 | def _model_forward_hook(
55 | self,
56 | embedding_module: Callable,
57 | embedding_module_device_refer,
58 | hide_virtual_token_logits: bool,
59 | input_ids=None,
60 | attention_mask=None,
61 | inputs_embeds=None,
62 | labels=None,
63 | output_attentions=None,
64 | output_hidden_states=None,
65 | return_dict=None,
66 | **kwargs,
67 | ):
68 | batch_size = _get_batch_size(input_ids, inputs_embeds)
69 | num_virtual_tokens = self.prompt_encoder.num_virtual_tokens
70 | if attention_mask is not None:
71 | # concat prompt attention mask
72 | prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(attention_mask.device)
73 | attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
74 | if kwargs.get("position_ids", None) is not None:
75 |         logger.warning("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
76 | kwargs["position_ids"] = None
77 | kwargs.update(
78 | {
79 | "attention_mask": attention_mask,
80 | "output_attentions": output_attentions,
81 | "output_hidden_states": output_hidden_states,
82 | "return_dict": return_dict,
83 | }
84 | )
85 |
86 | if labels is not None:
87 | if len(labels.shape) == 1:
88 | # if sequence classification task, labels do not have to be padded
89 | kwargs["labels"] = labels
90 | elif len(labels.shape) == 2:
91 | # suppose to be language modeling task, labels have to be padded with -100
92 | kwargs["labels"] = torch.cat(
93 | (
94 | -100 * torch.ones(batch_size, num_virtual_tokens).to(labels.device).long(),
95 | labels,
96 | ),
97 | dim=1,
98 | )
99 | else:
100 | raise NotImplementedError("Not implemented for labels with shape {}".format(labels.shape))
101 |
102 | if kwargs.get("token_type_ids", None) is not None:
103 | kwargs["token_type_ids"] = torch.cat(
104 | (
105 | torch.zeros(batch_size, num_virtual_tokens).to(kwargs["token_type_ids"].device),
106 | kwargs["token_type_ids"],
107 | ),
108 | dim=1,
109 | ).long()
110 |
111 | if kwargs.get("mask_pos", None) is not None:
112 | kwargs["mask_pos"] = num_virtual_tokens + kwargs["mask_pos"]
113 |
114 | input_device = input_ids.device if input_ids is not None else inputs_embeds.device
115 | if inputs_embeds is None:
116 | inputs_embeds = embedding_module(input_ids.to(embedding_module_device_refer.device))
117 | inputs_embeds = inputs_embeds.to(input_device)
118 | prompts = torch.arange(num_virtual_tokens).unsqueeze(0).expand(batch_size, -1).to(
119 | self.prompt_encoder.embedding.weight.device)
120 | prompts = self.prompt_encoder(prompts).to(dtype=inputs_embeds.dtype, device=input_device)
121 |
122 | inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
123 |
124 | outputs = self.prompt_tuning_original_forward(inputs_embeds=inputs_embeds, **kwargs)
125 | if hide_virtual_token_logits and hasattr(outputs, "logits"):
126 | outputs.logits = outputs.logits[..., num_virtual_tokens:, :]
127 | return outputs
128 |
129 |
130 | class PromptTuning:
131 |
132 | def __init__(
133 | self,
134 | model: PreTrainedModel,
135 | num_virtual_tokens: int,
136 | init_by_real_tokens: Optional[bool] = False,
137 | hide_virtual_token_logits: Optional[bool] = True,
138 | ):
139 | """
140 | Prompt tuning model initializer.
141 |
142 | Parameters
143 | ----------
144 | model: PreTrainedModel, required
145 | The model to be tuned.
146 | num_virtual_tokens: int, required
147 | The number of virtual tokens to be added.
148 | init_by_real_tokens: bool, optional, default=False
149 | Whether to initialize the virtual tokens by real tokens.
150 | """
151 | hidden_dim = model.config.hidden_size
152 |
153 | if model.config.model_type == "opt":
154 | embedding_module = model.get_input_embeddings()
155 | embedding_module_device_refer = embedding_module.weight
156 | elif model.config.model_type == "roberta":
157 | if hasattr(model, "roberta"): # is RoBERTaForMaskedLM etc.
158 | embedding_module = partial(model.roberta.embeddings, past_key_values_length=num_virtual_tokens)
159 | embedding_module_device_refer = model.roberta.embeddings.word_embeddings.weight
160 | elif hasattr(model, "embeddings"): # is RoBERTa base model
161 | embedding_module = partial(model.embeddings, past_key_values_length=num_virtual_tokens)
162 | embedding_module_device_refer = model.embeddings.word_embeddings.weight
163 | else:
164 | raise ValueError(f"Cannot find embedding module in {model.__class__.__name__}")
165 | elif model.config.model_type in ["llama", "mistral"]:
166 | embedding_module = model.get_input_embeddings()
167 | embedding_module_device_refer = embedding_module.weight
168 | else:
169 | raise NotImplementedError
170 |
171 | model.prompt_encoder = PromptEmbedding(
172 | num_virtual_tokens, hidden_dim, init_by_real_tokens,
173 | model.get_input_embeddings(), model.config.vocab_size
174 | )
175 |
176 | model.prompt_tuning_original_forward = model.forward
177 |
178 | if not hasattr(embedding_module_device_refer, "device"):
179 | raise ValueError(f"Cannot find device attribute in {embedding_module_device_refer.__class__.__name__}")
180 |
181 | forward_hook_kwargs = {
182 | "embedding_module": embedding_module,
183 | "embedding_module_device_refer": embedding_module_device_refer,
184 | "hide_virtual_token_logits": hide_virtual_token_logits,
185 | }
186 | model.forward = partial(
187 | _model_forward_hook.__get__(model, type(model)),
188 | **forward_hook_kwargs
189 | )
190 |
191 | for n, p in model.named_parameters():
192 | if "prompt_encoder" not in n:
193 | p.requires_grad = False
194 |
195 |
196 | def test_roberta():
197 | from transformers import AutoTokenizer, RobertaModel
198 | model = RobertaModel.from_pretrained("roberta-base")
199 | tokenizer = AutoTokenizer.from_pretrained("roberta-base")
200 |
201 | PromptTuning(model, num_virtual_tokens=5, init_by_real_tokens=True)
202 |
203 | inputs = tokenizer("in heissem Liebesstreben", return_tensors="pt")
204 | outputs = model(**inputs)
205 |
206 |
207 | def test_opt():
208 | from transformers import AutoTokenizer, OPTModel
209 | model = OPTModel.from_pretrained("facebook/opt-125m")
210 | tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
211 |
212 | PromptTuning(model, num_virtual_tokens=5, init_by_real_tokens=True)
213 |
214 | inputs = tokenizer("werd ich entschweben", return_tensors="pt")
215 | outputs = model(**inputs)
216 |
217 |
218 | if __name__ == "__main__":
219 | test_roberta()
220 | test_opt()
221 |
--------------------------------------------------------------------------------
/large_models/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import random
5 |
6 | import wandb
7 | from torch.utils.tensorboard import SummaryWriter
8 | from datetime import datetime
9 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
10 | from torch.utils.data import Dataset
11 | from tqdm import tqdm
12 | from transformers import (
13 | AutoConfig,
14 | AutoTokenizer,
15 | AutoModelForCausalLM,
16 | HfArgumentParser,
17 | TrainingArguments,
18 | DataCollatorForTokenClassification
19 | )
20 |
21 | from metrics import calculate_metric
22 | from modeling_mistral import (
23 | MistralForCausalLM,
24 | MistralConfig
25 | )
26 | from tasks import get_task
27 | from trainer import OurTrainer
28 | from utils import *
29 |
30 | os.environ["TRANSFORMERS_CACHE"] = "./cache"
31 |
32 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
33 | logger = logging.getLogger(__name__)
34 | logger.setLevel(logging.INFO)
35 |
36 | AutoConfig.register("mistral", MistralConfig)
37 | AutoModelForCausalLM.register(MistralConfig, MistralForCausalLM)
38 |
39 |
40 | @dataclass
41 | class OurArguments(TrainingArguments):
42 | # dataset and sampling strategy
43 | task_name: str = "SST2" # task name should match the string before Dataset in the Dataset class name. We support the following task_name: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP
44 |
45 | # Number of examples
46 | num_train: int = 0 # ICL mode: number of demonstrations; training mode: number of training samples
47 | num_dev: int = None # (only enabled with training) number of development samples
48 | num_eval: int = None # number of evaluation samples
49 | num_train_sets: int = None # how many sets of training samples/demos to sample; if None and train_set_seed is None, then we will sample one set for each evaluation sample
50 | train_set_seed: int = 0 # designated seed to sample training samples/demos
51 | result_file: str = None # file name for saving performance; if None, then use the task name, model name, and config
52 |
53 | # Model loading
54 | model_name: str = "facebook/opt-125m" # HuggingFace model name
55 | load_float16: bool = False # load model parameters as float16
56 | load_bfloat16: bool = False # load model parameters as bfloat16
57 | load_int8: bool = False # load model parameters as int8
58 | max_length: int = 2048 # max length the model can take
59 | no_auto_device: bool = False # do not load model by auto device; should turn this on when using FSDP
60 |
61 | # Calibration
62 | sfc: bool = False # whether to use SFC calibration
63 | icl_sfc: bool = False # whether to use SFC calibration for ICL samples
64 |
65 | template_ver: int = 0 # template. For some tasks (SST2, RTE, Copa), we add template ver=1 as the empty template.
66 |
67 | # Training
68 | trainer: str = "subzero_sgd"
69 | ## options
70 | ## - none: no training -- for zero-shot or in-context learning (ICL)
71 | ## - regular: regular huggingface trainer -- for fine-tuning
72 | ## - zo_sgd: zeroth-order SGD (MeZO) training
73 | ## - zo_conserv: zeroth-order SGD conservative training
74 | ## - zo_adam: zeroth-order Adam training
75 | ## - zo_sign_opt: zeroth-order sign sgd training
76 | ## - forward_grad: forward gradient
77 |     ## - subzero_sgd: random subspace zeroth-order SGD (SubZero) training
78 |
79 | optimizer: str = "adamw"
80 | ## options
81 | ## - sgd
82 | ## - adam
83 | ## - adamw # this is huggingface default
84 | only_train_option: bool = True # whether to only train the option part of the input
85 | train_as_classification: bool = False # take the log likelihood of all options and train as classification
86 | momentum: float = 0.0 # only work for SGD optimizer
87 | lr_scheduler_type: str = "constant" # only work for SGD optimizer
88 |
89 | # MeZO and SubZero
90 | zo_eps: float = 1e-3 # eps in MeZO
91 | perturbation_mode: str = "two_side"
92 | q: int = 1 # number of Gaussian samples for zeroth-order trainers
93 |
94 |     update_interval: int = 2000  # SubZero: number of steps between refreshes of the random subspace projections
95 |     gauss_rank: int = 8  # SubZero: rank of the low-rank Gaussian perturbation
96 |
97 |
98 | # Prefix tuning
99 | prefix_tuning: bool = False # whether to use prefix tuning
100 | num_prefix: int = 5 # number of prefixes to use
101 | no_reparam: bool = True # do not use reparameterization trick
102 | prefix_init_by_real_act: bool = True # initialize prefix by real activations of random words
103 |
104 | # prompt tuning hyperparameters
105 | prompt_tuning: bool = False # whether to use prompt tuning
106 | num_virtual_tokens: int = 10 # number of prompt tokens to use
107 | prompt_init_by_real_tokens: bool = False # whether to sample random tokens from Embedding layer
108 |
109 | # LoRA
110 | lora: bool = False # whether to use LoRA
111 | lora_alpha: int = 16 # alpha in LoRA
112 | lora_r: int = 8 # r in LoRA
113 |
114 | # Generation
115 | sampling: bool = False # whether to use sampling
116 | temperature: float = 1.0 # temperature for generation
117 | num_beams: int = 1 # number of beams for generation
118 | top_k: int = None # top-k for generation
119 | top_p: float = 0.95 # top-p for generation
120 | max_new_tokens: int = 50 # max number of new tokens to generate
121 | eos_token: str = "\n" # end of sentence token
122 |
123 | # Saving
124 | save_model: bool = False # whether to save the model
125 | no_eval: bool = False # whether to skip evaluation
126 | tag: str = "" # saving tag
127 |
128 | # Linear probing
129 | linear_probing: bool = False # whether to do linear probing
130 | lp_early_stopping: bool = False # whether to do early stopping in linear probing
131 | head_tuning: bool = False # head tuning: only tune the LM head
132 |
133 | # Untie emb/lm_head weights
134 | untie_emb: bool = False # untie the embeddings and LM head
135 |
136 | # Display
137 | verbose: bool = False # verbose output
138 |
139 | # Non-diff objective
140 | non_diff: bool = False # use non-differentiable objective (only support F1 for SQuAD for now)
141 |
142 | # Auto saving when interrupted
143 | save_on_interrupt: bool = False # save model when interrupted (useful for long training)
144 |
145 |     clean_model_at_end: bool = True # remove everything at the end.
146 |
147 | def parse_args():
148 |     # HfArgumentParser parses CLI flags directly into the OurArguments dataclass
149 |     parser = HfArgumentParser(OurArguments)
150 | args = parser.parse_args_into_dataclasses()[0]
151 | print(args)
152 | return args
153 |
154 |
155 | def set_seed(seed: int):
156 | random.seed(seed)
157 | np.random.seed(seed)
158 | torch.manual_seed(seed)
159 | torch.cuda.manual_seed_all(seed)
160 |
161 |
162 | class Framework:
163 |
164 | def __init__(self, args, task):
165 | self.args = args
166 | self.task = task
167 | self.model, self.tokenizer = self.load_model()
168 |
169 | def load_model(self):
170 | """
171 | Load HuggingFace models
172 | """
173 | with count_time("Loading model with FP%d" % (16 if self.args.load_float16 else 32)):
174 | free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024 ** 3)
175 | print(free_in_GB)
176 | config = AutoConfig.from_pretrained(self.args.model_name)
177 | if self.args.untie_emb:
178 | # Untie embeddings/LM head
179 |                 logger.warning("Untie embeddings and LM head")
180 | config.tie_word_embeddings = False
181 | if self.args.head_tuning:
182 | torch_dtype = torch.float32
183 | if self.args.load_float16:
184 | torch_dtype = torch.float16
185 | elif self.args.load_bfloat16:
186 | torch_dtype = torch.bfloat16
187 | # Head tuning
188 | if "opt" in self.args.model_name.lower():
189 | from modeling_opt import OPTForCausalLM
190 | model = OPTForCausalLM.from_pretrained(
191 | self.args.model_name,
192 | config=config,
193 | device_map='auto',
194 | torch_dtype=torch_dtype,
195 | max_memory={i: f'{free_in_GB - 5}GB' for i in
196 | range(torch.cuda.device_count())},
197 | )
198 | elif "llama" in self.args.model_name.lower():
199 | from modeling_llama import LlamaForCausalLMWithHeadTuning
200 | model = LlamaForCausalLMWithHeadTuning.from_pretrained(
201 | self.args.model_name,
202 | config=config,
203 | device_map='auto',
204 | torch_dtype=torch_dtype,
205 | max_memory={i: f'{free_in_GB - 5}GB' for i in
206 | range(torch.cuda.device_count())},
207 | )
208 | elif "mistral" in self.args.model_name.lower():
209 | from modeling_mistral import MistralForCausalLMWithHeadTuning
210 | model = MistralForCausalLMWithHeadTuning.from_pretrained(
211 | self.args.model_name,
212 | config=config,
213 | device_map='auto',
214 | torch_dtype=torch_dtype,
215 | max_memory={i: f'{free_in_GB - 5}GB' for i in
216 | range(torch.cuda.device_count())},
217 | )
218 | else:
219 | raise NotImplementedError(f"Head tuning is not supported for {self.args.model_name}")
220 | elif self.args.no_auto_device:
221 | # No auto device (use for FSDP)
222 | model = AutoModelForCausalLM.from_pretrained(self.args.model_name, config=config, )
223 | else:
224 | # Auto device loading
225 | torch_dtype = torch.float32
226 | if self.args.load_float16:
227 | torch_dtype = torch.float16
228 | elif self.args.load_bfloat16:
229 | torch_dtype = torch.bfloat16
230 | model = AutoModelForCausalLM.from_pretrained(self.args.model_name, config=config, device_map='auto',
231 | torch_dtype=torch_dtype,
232 | max_memory={i: f'{free_in_GB - 0.5}GB' for i in
233 | range(torch.cuda.device_count())},
234 | load_in_8bit=self.args.load_int8, )
235 | model.eval()
236 |
237 | # Load tokenizer
238 |         # In MeZO, use_fast is set to False, but that triggers a TypeError when running SQuAD; setting it to True fixes the issue.
239 | tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, use_fast=True)
240 |
241 | # HF tokenizer bug fix
242 | if "opt" in self.args.model_name:
243 | tokenizer.bos_token_id = 0
244 |
245 | if ("llama" in self.args.model_name) or ("mistral" in self.args.model_name.lower()):
246 | # LLaMA padding token
247 | tokenizer.pad_token_id = 0 # technically
248 |
249 | # Prefix tuning/LoRA
250 | if self.args.prefix_tuning:
251 | from prefix_tuning import PrefixTuning
252 | PrefixTuning(model, num_prefix=self.args.num_prefix, reparam=not self.args.no_reparam,
253 | float16=self.args.load_float16, init_by_real_act=self.args.prefix_init_by_real_act)
254 | if self.args.lora:
255 | from lora import LoRA
256 | LoRA(model, r=self.args.lora_r, alpha=self.args.lora_alpha, float16=self.args.load_float16)
257 |
258 | if self.args.prompt_tuning:
259 | from prompt_tuning import PromptTuning
260 | print("Adding Prompt Tuning to model...")
261 | PromptTuning(
262 | model,
263 | num_virtual_tokens=self.args.num_virtual_tokens,
264 | init_by_real_tokens=self.args.prompt_init_by_real_tokens,
265 | hide_virtual_token_logits=True, # a workaround for the other loss/prediction functions
266 | )
267 |
268 | # for name, param in model.named_parameters():
269 | # if name == 'prompt_encoder.embedding.weight':
270 | # print(param.shape, end="\n")
271 |
272 |
273 | print("Total/Trainable number of parameters: {}/{}".format(
274 | sum(p.numel() for p in model.parameters()),
275 | sum(p.numel() for p in model.parameters() if p.requires_grad),
276 | ))
277 |
278 | if self.args.head_tuning:
279 | if model.config.model_type in ["opt", "llama", "mistral"]:
280 | head_name = "lm_head" if self.args.untie_emb else "embed_tokens"
281 | else:
282 | raise NotImplementedError
283 | for n, p in model.named_parameters():
284 | if head_name not in n:
285 | p.requires_grad = False
286 | else:
287 | logger.info(f"Only tuning {n}")
288 |
289 | return model, tokenizer
290 |
291 | def forward(self, input_ids, option_len=None, generation=False):
292 | """
293 | Given input_ids and the length of the option, return the log-likelihood of each token in the option.
294 | For generation tasks, return the generated text.
295 | This function is only for inference
296 | """
297 | input_ids = torch.tensor([input_ids]).to(self.model.device)
298 |
299 | if generation:
300 | args = self.args
301 | # Autoregressive generation
302 | outputs = self.model.generate(input_ids, do_sample=args.sampling, temperature=args.temperature,
303 | num_beams=args.num_beams, top_p=args.top_p, top_k=args.top_k,
304 | max_new_tokens=min(args.max_new_tokens, args.max_length - input_ids.size(1)),
305 | num_return_sequences=1,
306 | eos_token_id=[
307 | self.tokenizer.encode(args.eos_token, add_special_tokens=False)[-1],
308 | self.tokenizer.eos_token_id], )
309 | # For generation, directly return the text output
310 | output_text = self.tokenizer.decode(outputs[0][input_ids.size(1):], skip_special_tokens=True).strip()
311 | return output_text
312 | else:
313 | with torch.inference_mode():
314 | self.model.eval()
315 | logits = self.model(input_ids=input_ids).logits
316 | labels = input_ids[0, 1:]
317 | logits = logits[0, :-1]
318 | log_probs = F.log_softmax(logits, dim=-1)
319 |
320 | selected_log_probs = log_probs[torch.arange(len(labels)).to(labels.device), labels]
321 | selected_log_probs = selected_log_probs.cpu().detach()
322 | # Only return the option (candidate) part
323 | return selected_log_probs[-option_len:]
324 |
325 | def one_step_pred(self, train_samples, eval_sample, verbose=False):
326 | """
327 | Return the prediction on the eval sample. In ICL, use train_samples as demonstrations
328 | """
329 | verbose = verbose or self.args.verbose
330 | # if verbose:
331 | # logger.info("========= Example =========")
332 | # logger.info(f"Candidate: {eval_sample.candidates}")
333 | # logger.info(f"Correct candidate: {eval_sample.correct_candidate}")
334 |
335 | # Encode (add prompt and tokenize) the sample; if multiple-choice/classification, encode all candidates (options)
336 | encoded_candidates, option_lens = encode_prompt(self.task,
337 | self.task.get_template(template_version=self.args.template_ver),
338 | train_samples, eval_sample,
339 | self.tokenizer, max_length=self.args.max_length,
340 | generation=self.task.generation,
341 | max_new_tokens=self.args.max_new_tokens)
342 |
343 | # Calibration
344 | if self.args.sfc or self.args.icl_sfc:
345 | sfc_encoded_candidates, sfc_option_lens = encode_prompt(self.task, self.task.get_template(
346 | template_version=self.args.template_ver), train_samples,
347 | eval_sample, self.tokenizer,
348 | max_length=self.args.max_length, sfc=self.args.sfc,
349 | icl_sfc=self.args.icl_sfc,
350 | generation=self.task.generation,
351 | max_new_tokens=self.args.max_new_tokens)
352 |
353 | outputs = []
354 | if self.task.generation:
355 | # For generation tasks, return the autoregressively-generated text
356 | output_text = self.forward(encoded_candidates[0], generation=True)
357 | # if verbose:
358 | # logger.info("=== Prompt ===")
359 | # logger.info(self.tokenizer.decode(encoded_candidates[0]))
360 | # logger.info(f"Output: {output_text}")
361 | return Prediction(correct_candidate=eval_sample.correct_candidate, predicted_candidate=output_text)
362 | else:
363 | # For classification/multiple-choice, calculate the probabilities of all candidates
364 | for candidate_id, encoded_candidate in enumerate(encoded_candidates):
365 | selected_log_probs = self.forward(encoded_candidate, option_len=option_lens[candidate_id])
366 | if verbose:
367 | # if candidate_id == 0:
368 | # logger.info("=== Candidate %d ===" % candidate_id)
369 | # logger.info(self.tokenizer.decode(encoded_candidate))
370 | # else:
371 | # logger.info("=== Candidate %d (without context)===" % candidate_id)
372 | # logger.info(self.tokenizer.decode(encoded_candidate).split(self.task.train_sep)[-1])
373 | logger.info(f"Log probabilities of the option tokens: {selected_log_probs}")
374 |
375 | if self.args.sfc or self.args.icl_sfc:
376 | sfc_selected_log_probs = self.forward(sfc_encoded_candidates[candidate_id],
377 |                                                                option_len=sfc_option_lens[candidate_id])
378 |                         # if verbose: the decoded SFC candidate and its option-token log probabilities could be logged here
379 |
380 | outputs.append({"log_probs": selected_log_probs,
381 | "sfc_log_probs": sfc_selected_log_probs if self.args.sfc or self.args.icl_sfc else None})
382 |
383 | if self.args.sfc or self.args.icl_sfc:
384 | # Calibrated probabilities (surface form competition; https://arxiv.org/pdf/2104.08315.pdf)
385 | # log p(candidate | input) = log p_lm(candidate | input) - log p_lm(candidate | sfc prompt)
386 | scores = [x['log_probs'].sum().item() - x['sfc_log_probs'].sum().item() for x in outputs]
387 | else:
388 | # (Default) length-normalized log probabilities
389 | # log p(candidate | input) = log p_lm(candidate | input) / |candidate #tokens|
390 | scores = [x['log_probs'].mean().item() for x in outputs]
391 |
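            # Worked example with hypothetical numbers: if a candidate's option tokens get
            # log-probs [-1.2, -0.8] under the full prompt and [-2.0, -1.5] under the SFC
            # prompt, its calibrated score is (-1.2 - 0.8) - (-2.0 - 1.5) = 1.5; without
            # SFC, the length-normalized score is (-1.2 - 0.8) / 2 = -1.0.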
392 | if verbose:
393 | logger.info(f"Prediction scores: {scores}")
394 |
395 | if isinstance(eval_sample.correct_candidate, list):
396 | # For some datasets there are multiple correct answers
397 | correct_candidate_id = [eval_sample.candidates.index(c) for c in eval_sample.correct_candidate]
398 | else:
399 | correct_candidate_id = eval_sample.candidates.index(eval_sample.correct_candidate)
400 |
401 | return Prediction(correct_candidate=correct_candidate_id, predicted_candidate=int(np.argmax(scores)))
402 |
403 | def evaluate(self, train_samples, eval_samples, one_train_set_per_eval_sample=False, description=None):
404 | """
405 | Evaluate function.
406 | Here, train_samples are used for demonstrations for ICL.
407 | If one_train_set_per_eval_sample is True, then each eval sample has its own training (demonstration) set.
408 | Otherwise, the same training set is used for all eval samples.
409 | """
410 | if one_train_set_per_eval_sample:
411 | logger.info(f"There are {len(eval_samples)} validation samples and one train set per eval sample")
412 | else:
413 | logger.info(f"There are {len(train_samples)} training samples and {len(eval_samples)} validation samples")
414 |
415 | # Prediction loop
416 | predictions = []
417 | for eval_id, eval_sample in enumerate(tqdm(eval_samples, desc=description)):
418 | predictions.append(
419 | self.one_step_pred(train_samples[eval_id] if one_train_set_per_eval_sample else train_samples,
420 | eval_sample, verbose=False))
421 |
422 | # Calculate metrics
423 | metric_name = getattr(self.task, "metric_name", "accuracy")
424 | metrics = {metric_name: calculate_metric(predictions, metric_name)}
425 | return metrics
426 |
427 | def train(self, train_samples, dev_samples, eval_samples, writer):
428 | """
429 | Training function
430 |         If args.num_dev is None, the eval samples are reused as dev_samples
431 | """
432 | logger.info(f"Eval sample length is {len(eval_samples)}")
433 | # Set tokenizer to left padding (so that all the options are right aligned)
434 | self.tokenizer.padding_side = "left"
435 |
436 | class HFDataset(Dataset):
437 |
438 | def __init__(self, data):
439 | self.data = data
440 |
441 | def __len__(self):
442 | return len(self.data)
443 |
444 | def __getitem__(self, idx):
445 | return self.data[idx]
446 |
447 | def _convert(samples):
448 | """
449 | Convert samples to HF-compatible dataset
450 | """
451 | data = []
452 | for sample in samples:
453 | encoded_candidates, option_lens = encode_prompt(self.task, self.task.get_template(
454 | template_version=self.args.template_ver), [], sample,
455 | self.tokenizer, max_length=self.args.max_length,
456 | generation=self.task.generation,
457 | generation_with_gold=True,
458 | max_new_tokens=self.args.max_new_tokens)
459 | if self.task.generation:
460 | correct_candidate_id = 0
461 | elif isinstance(sample.correct_candidate, list):
462 | correct_candidate_id = sample.candidates.index(sample.correct_candidate[0])
463 | else:
464 | correct_candidate_id = sample.candidates.index(sample.correct_candidate)
465 |
466 | if self.args.non_diff:
467 |                     # For a non-differentiable objective there is no teacher forcing,
468 |                     # so the answer part is removed from the input
469 | encoded_candidates[correct_candidate_id] = encoded_candidates[correct_candidate_id][
470 | :-option_lens[correct_candidate_id]]
471 |
472 | if self.args.train_as_classification:
473 | # For classification, we provide the label as the correct candidate id
474 | data.append([{"input_ids": encoded_candidates[_i], "labels": correct_candidate_id,
475 | "option_len": option_lens[_i], "num_options": len(sample.candidates)} for _i in
476 | range(len(encoded_candidates))])
477 | elif self.args.only_train_option:
478 | # Otherwise, it is just LM-style teacher forcing
479 | if self.args.non_diff:
480 | # For non-differentiable objective, we need to provide the gold answer to calculate F1/acc
481 | data.append({"input_ids": encoded_candidates[correct_candidate_id],
482 | "labels": encoded_candidates[correct_candidate_id],
483 | "option_len": option_lens[correct_candidate_id], "gold": sample.correct_candidate})
484 | else:
485 | data.append({"input_ids": encoded_candidates[correct_candidate_id],
486 | "labels": encoded_candidates[correct_candidate_id],
487 | "option_len": option_lens[correct_candidate_id]})
488 | else:
489 | data.append({"input_ids": encoded_candidates[correct_candidate_id],
490 | "labels": encoded_candidates[correct_candidate_id]})
491 | return data
492 |
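        # Example of one converted --train_as_classification entry (illustrative values only,
        # using SST-2's "terrible"/"great" verbalizers):
        #   [{"input_ids": [...prompt + "terrible"...], "labels": 0, "option_len": 1, "num_options": 2},
        #    {"input_ids": [...prompt + "great"...],    "labels": 0, "option_len": 1, "num_options": 2}]
        # i.e. one list of per-candidate features per sample, later flattened by the collator.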
493 | with count_time("Tokenizing training samples"):
494 | train_dataset = HFDataset(_convert(train_samples))
495 | eval_dataset = HFDataset(_convert(eval_samples))
496 | dev_dataset = HFDataset(_convert(dev_samples))
497 |
498 | if self.args.only_train_option and not self.args.non_diff:
499 |             # With --only_train_option and a differentiable objective, wrap the forward function so the loss is computed only on the option tokens
500 | self.model.original_forward = self.model.forward
501 | self.model.forward = forward_wrap_with_option_len.__get__(self.model, type(self.model))
502 |
503 | if self.args.non_diff:
504 | collator = NondiffCollator
505 | else:
506 | collator = DataCollatorForTokenClassification
507 |
508 | trainer = OurTrainer(model=self.model,
509 | args=self.args,
510 | train_dataset=train_dataset,
511 | eval_dataset=eval_dataset,
512 | tokenizer=self.tokenizer,
513 | data_collator=DataCollatorWithPaddingAndNesting(self.tokenizer,
514 | pad_to_multiple_of=8) if self.args.train_as_classification else collator(
515 | self.tokenizer, pad_to_multiple_of=8),
516 | eval_samples=eval_samples,
517 | dev_samples=dev_samples,
518 | evaluate_func=self.evaluate,
519 | writer=writer
520 | )
521 |
522 | if self.args.save_on_interrupt:
523 | trainer.add_callback(SIGUSR1Callback())
524 |
525 |         # Resume training from the last checkpoint if one exists
526 | last_checkpoint = None
527 | from transformers.trainer_utils import get_last_checkpoint
528 | if os.path.isdir(self.args.output_dir) and not self.args.overwrite_output_dir:
529 | last_checkpoint = get_last_checkpoint(self.args.output_dir)
530 | if last_checkpoint is not None and self.args.resume_from_checkpoint is None:
531 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
532 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch.")
533 | if self.args.resume_from_checkpoint is not None:
534 | last_checkpoint = self.args.resume_from_checkpoint
535 |
536 | # This calls the trainer._inner_training_loop()
537 | trainer.train(resume_from_checkpoint=last_checkpoint)
538 |
539 | # Explicitly save the model
540 | if self.args.save_model:
541 | logger.info("Save model..")
542 | trainer.save_model()
543 |
544 | # FSDP compatibility
545 | self.model = trainer.model
546 |
547 | # Reset the forward function for evaluation
548 | if self.args.only_train_option and not self.args.non_diff:
549 | if type(self.model) == FSDP:
550 | logger.info("This is an FSDP model now. Be careful when assigning back the original forward function")
551 | self.model._fsdp_wrapped_module.forward = self.model._fsdp_wrapped_module.original_forward
552 | else:
553 | self.model.forward = self.model.original_forward
554 |
555 | def delete_checkpoints(self):
556 | import shutil
557 | print(f"\nWARNING: Removing everything at end: {self.args.output_dir}")
558 | deleted_folders = [folder for folder in os.listdir(self.args.output_dir)
559 | if os.path.isdir(os.path.join(self.args.output_dir, folder))
560 | and folder.startswith("checkpoint-")]
561 | for f in deleted_folders:
562 | shutil.rmtree(os.path.join(self.args.output_dir, f))
563 |         print("Deleted folders:", deleted_folders)
564 |
565 |
566 | def result_file_tag(args):
567 | """
568 | Get the result file tag
569 | """
570 | save_model_name = args.model_name.split("/")[-1]
571 | sfc_tag = "-sfc" if args.sfc else ""
572 | icl_sfc_tag = "-icl_sfc" if args.icl_sfc else ""
573 | sample_eval_tag = "-sampleeval%d" % args.num_eval if args.num_eval is not None else ""
574 | sample_train_tag = "-ntrain%d" % args.num_train if args.num_train > 0 else ""
575 | sample_dev_tag = "-ndev%d" % args.num_dev if args.num_dev is not None else ""
576 | customized_tag = f"-{args.tag}" if len(args.tag) > 0 else ""
577 | return f"{args.task_name}-{save_model_name}" + sfc_tag + icl_sfc_tag + sample_eval_tag + sample_train_tag + sample_dev_tag + customized_tag
578 |
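# Example (hypothetical arguments): for facebook/opt-1.3b on SST2 with --num_eval 1000,
# --num_train 1000, --num_dev 500 and an empty tag, this returns
# "SST2-opt-1.3b-sampleeval1000-ntrain1000-ndev500"; a non-empty args.tag is appended as "-<tag>".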
579 |
580 | def main():
581 | args = parse_args()
582 | if args.prefix_tuning:
583 | args.mode = "prefix"
584 | elif args.lora:
585 | args.mode = "lora"
586 | elif args.prompt_tuning:
587 | args.mode = "prompt"
588 | else:
589 | args.mode = "ft"
590 | args.tag = f"{args.trainer}-{args.task_name}-{args.template_ver}-{args.model_name.split('/')[-1]}-OPTIM_{args.mode}-STEP{args.max_steps}-{args.optimizer}-momen{args.momentum}-LR{args.learning_rate}-{args.lr_scheduler_type}-ZOEPS{args.zo_eps}-T{args.update_interval}-gauss_rank{args.gauss_rank}-Q{args.q}-bs{args.per_device_train_batch_size}-gradAccumulation{args.gradient_accumulation_steps}"
591 | args.run_name = args.tag
592 | args.output_dir = f"result/{args.task_name}/{args.model_name.split('/')[-1]}/{args.mode}/{args.trainer}/{args.tag}"
593 | args.result_file = f"result/{args.task_name}/{args.model_name.split('/')[-1]}/{args.mode}/{args.trainer}/{args.tag}/results.json"
594 | os.makedirs(args.output_dir, exist_ok=True)
595 |
596 | current_date = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
597 | # wandb.init(project='zo-bench', name=args.tag, config=args)
598 | tensorboard_log_dir = f"result/{args.task_name}/{args.model_name.split('/')[-1]}/{args.mode}/{args.trainer}/{args.tag}/{current_date}"
599 | args.logging_dir = os.path.join(tensorboard_log_dir, "logs")
600 | os.makedirs(args.logging_dir, exist_ok=True)
601 |
602 | writer = SummaryWriter(tensorboard_log_dir)
603 | set_seed(args.seed)
604 | task = get_task(args.task_name)
605 |
606 | # This function samples both training and validation samples. The validation (dev) samples are also stored in "train_sets"
607 | # Later the train_samples and dev_samples are separated
608 | train_sets = task.sample_train_sets(num_train=args.num_train, num_dev=args.num_dev, num_eval=args.num_eval,
609 | num_train_sets=args.num_train_sets, seed=args.train_set_seed)
610 |
611 | # Initialize trainer and load model
612 | framework = Framework(args, task)
613 |
614 | # ZO-Bench Added
615 | # We add these parameters to evaluate the model during the training.
616 | # These two parameters will be used in the training loop
617 | # args.task = task
618 | # args.framework = framework
619 |
620 | if args.train_set_seed is not None or args.num_train_sets is not None:
621 |
622 |         # Training follows this branch
623 |
624 | # Eval samples share one (or multiple) training set(s)
625 | for train_set_id, train_samples in enumerate(train_sets):
626 | train_set_seed = train_set_id if args.train_set_seed is None else args.train_set_seed
627 |
628 | # Sample eval samples
629 | if args.num_eval is not None:
630 | eval_samples = task.sample_subset(data_split="valid", seed=train_set_seed, num=args.num_eval)
631 | else:
632 | eval_samples = task.valid_samples
633 |
634 | if args.trainer != "none":
635 |                 # Here the training and dev samples are separated
636 | if args.num_dev is not None:
637 | # Dev samples
638 | # assert args.num_dev + args.num_train <= len(train_samples), f"num_dev({args.num_dev})+num_train({args.num_train}) is more than actual num of training samples ({len(train_samples)})."
639 | dev_samples = train_samples[-args.num_dev:]
640 | train_samples = train_samples[:-args.num_dev]
641 | logger.info("Dev samples: %d" % len(dev_samples))
642 | logger.info("Train samples: %d" % len(train_samples))
643 | else:
644 | dev_samples = None
645 | logger.info("Train samples: %d" % len(train_samples))
646 | logger.info("No dev samples")
647 |
648 | args.dev_samples = dev_samples
649 | args.eval_samples = eval_samples
650 |
651 | # Training
652 | framework.train(train_samples, dev_samples if dev_samples is not None else eval_samples, eval_samples, writer)
653 |
654 |                 if not args.no_eval:  # True unless --no_eval is set
655 | metrics = framework.evaluate([], eval_samples, description="Evaluating on the Test Set")
656 | _keys = list(metrics.keys())
657 | for m in _keys:
658 | metrics["test_" + m] = metrics[m]
659 | if dev_samples is not None:
660 | dev_metrics = framework.evaluate(
661 | [], dev_samples, description="Evaluating on the Validation Set"
662 | )
663 | _keys = list(dev_metrics.keys())
664 | for m in _keys:
665 | metrics["val_" + m] = dev_metrics[m]
666 | else:
667 | assert args.num_dev is None
668 | # Zero-shot / in-context learning
669 | metrics = framework.evaluate(train_samples, eval_samples)
670 | logger.info(metrics)
671 | print('metrics: \n\n\n', metrics)
672 | # wandb.log(metrics)
673 |
674 | # for key, value in metrics.items():
675 | # writer.add_scalar(key, value, global_step)
676 |
677 | if not args.no_eval:
678 | logger.info("===== Train set %d =====" % train_set_seed)
679 | logger.info(metrics)
680 |                     print('metrics: \n\n\n', metrics)
681 | # wandb.log(metrics)
682 | if args.local_rank <= 0:
683 | write_metrics_to_file(metrics, "result/" + result_file_tag(
684 | args) + f"-trainset{train_set_id}.json" if args.result_file is None else args.result_file)
685 | if args.trainer != "none" and args.clean_model_at_end:
686 | framework.delete_checkpoints()
687 |
688 | else:
689 |         # Each eval sample has its own training (demonstration) set; no training is performed
690 | # This is for in-context learning (ICL)
691 | assert args.trainer == "none"
692 | if args.num_eval is not None:
693 | eval_samples = task.sample_subset(data_split="valid", seed=0, num=args.num_eval)
694 | else:
695 | eval_samples = task.valid_samples
696 | metrics = framework.evaluate(train_sets, eval_samples, one_train_set_per_eval_sample=True)
697 | logger.info(metrics)
698 | # wandb.log(metrics)
699 | if args.local_rank <= 0:
700 | write_metrics_to_file(metrics, "result/" + result_file_tag(
701 | args) + "-onetrainpereval.json" if args.result_file is None else args.result_file)
702 |
703 | writer.close()
704 |
705 | if __name__ == "__main__":
706 | main()
707 |
--------------------------------------------------------------------------------
/large_models/tasks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | from dataclasses import dataclass
4 | from typing import List, Union
5 |
6 | import numpy as np
7 | from datasets import load_dataset
8 |
9 | from templates import *
10 | from utils import temp_seed
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.INFO)
14 |
15 |
16 | def get_task(task_name):
17 | aa = task_name.split("__")
18 | if len(aa) == 2:
19 | task_group, subtask = aa
20 | else:
21 | task_group = aa[0]
22 | subtask = None
23 | class_ = getattr(sys.modules[__name__], f"{task_group}Dataset")
24 | instance = class_(subtask)
25 | return instance
26 |
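# Example: get_task("SST2") resolves to SST2Dataset with subtask=None; a (hypothetical) name
# of the form "Group__sub" would be split into task group "Group" and subtask "sub".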
27 |
28 | @dataclass
29 | class Sample:
30 | id: int = None
31 | data: dict = None
32 | correct_candidate: Union[str, List[str]] = None
33 | candidates: List[str] = None
34 |
35 |
36 | class Dataset:
37 | mixed_set = False
38 | train_sep = "\n\n"
39 | generation = False # whether this is a generation task
40 |
41 | def __init__(self, subtask=None, **kwargs) -> None:
42 | self.samples = None
43 | self.subtask = subtask
44 |
45 | def get_task_name(self):
46 | return self.subtask
47 |
48 | def load_dataset(self, path, **kwargs):
49 | raise NotImplementedError
50 |
51 | def get_template(self, template_version=0):
52 | templates = {0: Template}
53 | return templates[template_version]
54 |
55 | def build_sample(self, example):
56 | return
57 |
58 | def sample_train_sets(self, num_train=32, num_dev=None, num_eval=None, num_train_sets=None, seed=None):
59 | if seed is not None:
60 | # one train/demo set using the designated seed
61 | seeds = [seed]
62 | elif num_train_sets is not None:
63 | # num_train_sets train/demo sets
64 | seeds = list(range(num_train_sets))
65 | else:
66 | # one train/demo set per evaluation sample
67 | assert num_dev is None # not supported
68 | len_valid_samples = len(self.samples["valid"]) if num_eval is None else num_eval
69 | with temp_seed(0):
70 | seeds = np.random.randint(0, 10000, len_valid_samples)
71 |
72 | train_samples = []
73 | for i, set_seed in enumerate(seeds):
74 | if self.mixed_set: # This is always False for now
75 | raise NotImplementedError
76 | train_samples.append(self.sample_subset(data_split="valid", seed=set_seed, num=num_train, exclude=i))
77 | else:
78 | if num_dev is not None:
79 | train_samples.append(self.sample_subset(data_split="train", seed=set_seed,
80 | num=num_train + num_dev)) # dev set is included at the end of train set
81 | if num_train + num_dev > len(self.samples["train"]):
82 |                         logger.warning("num_train + num_dev > available training examples")
83 | else:
84 | train_samples.append(self.sample_subset(data_split="train", seed=set_seed, num=num_train))
85 | if num_dev is not None:
86 | logger.info(f"Sample train set {len(train_samples[-1])}/{len(self.samples['train'])}")
87 | logger.info(f"... including dev set {num_dev} samples")
88 | return train_samples
89 |
90 | def sample_subset(self, data_split="train", seed=0, num=100, exclude=None):
91 | with temp_seed(seed):
92 | samples = self.samples[data_split]
93 | lens = len(samples)
94 | index = np.random.permutation(lens).tolist()[:num if exclude is None else num + 1]
95 | if exclude is not None and exclude in index:
96 | index.remove(exclude)
97 | else:
98 | index = index[:num]
99 | return [samples[i] for i in index]
100 |
101 | @property
102 | def valid_samples(self):
103 | return self.samples["valid"]
104 |
105 |
106 | class SST2Dataset(Dataset):
107 | train_sep = "\n\n"
108 |
109 | def __init__(self, subtask=None, **kwargs) -> None:
110 | self.load_dataset(subtask, **kwargs)
111 |
112 | def load_dataset(self, path, **kwargs):
113 | d = load_dataset('glue', 'sst2')
114 | train_d = d["train"]
115 | validation_d = d["validation"]
116 |
117 | train_samples = [self.build_sample(example) for example in train_d]
118 | valid_samples = [self.build_sample(example) for example in validation_d]
119 |
120 | self.samples = {"train": train_samples, "valid": valid_samples}
121 |
122 | # for generative tasks, candidates are []
123 | def build_sample(self, example):
124 | label = int(example["label"])
125 | # print('example', example)
126 | return Sample(id=example["idx"], data=example, correct_candidate=label, candidates=[0, 1])
127 |
128 | def get_template(self, template_version=0):
129 | return {0: SST2Template, 1: SST2TemplateEmpty}[template_version]()
130 |
131 | class SST5Dataset(Dataset):
132 | train_sep = "\n\n"
133 |
134 | def __init__(self, subtask=None, **kwargs) -> None:
135 | self.load_dataset(subtask, **kwargs)
136 |
137 | def load_dataset(self, path, **kwargs):
138 | d = load_dataset("SetFit/sst5")
139 | # print(d)
140 | train_d = d["train"]
141 | validation_d = d["validation"]
142 |
143 | train_samples = [self.build_sample(example) for example in train_d]
144 | valid_samples = [self.build_sample(example) for example in validation_d]
145 |
146 | self.samples = {"train": train_samples, "valid": valid_samples}
147 |
148 | # for generative tasks, candidates are []
149 | def build_sample(self, example):
150 | label = int(example["label"])
151 | # print('example', example)
152 | return Sample(data=example, correct_candidate=label, candidates=[0, 1, 2, 3, 4])
153 |
154 | def get_template(self, template_version=0):
155 | return {0: SST5Template, 1: SST5TemplateEmpty}[template_version]()
156 |
157 | class CopaDataset(Dataset):
158 | train_sep = "\n\n"
159 | mixed_set = False
160 |
161 | def __init__(self, subtask=None, **kwargs) -> None:
162 | self.load_dataset(subtask, **kwargs)
163 |
164 | def load_dataset(self, path, **kwargs):
165 | train_examples = load_dataset('super_glue', "copa")["train"]
166 | valid_examples = load_dataset('super_glue', "copa")["validation"]
167 |
168 | train_samples = [self.build_sample(example) for example in train_examples]
169 | valid_samples = [self.build_sample(example) for example in valid_examples]
170 | self.samples = {"train": train_samples, "valid": valid_samples}
171 |
172 | # for generative tasks, candidates are []
173 | def build_sample(self, example):
174 | sample = \
175 | Sample(
176 | id=example["idx"],
177 | data=example,
178 | candidates=[example["choice1"], example["choice2"]],
179 | correct_candidate=example[f"choice{example['label'] + 1}"],
180 | )
181 |
182 | return sample
183 |
184 | def get_template(self, template_version=0):
185 | return {0: CopaTemplate, 1: CopaTemplateEmpty}[template_version]()
186 |
187 |
188 | class BoolQDataset(Dataset):
189 | def __init__(self, subtask=None, **kwargs) -> None:
190 | self.load_dataset(subtask, **kwargs)
191 |
192 | def load_dataset(self, path, **kwargs):
193 | d = load_dataset("boolq")
194 | train_set = d["train"]
195 | valid_set = d["validation"]
196 |
197 | train_samples = [self.build_sample(example) for example in train_set]
198 | valid_samples = [self.build_sample(example) for example in valid_set]
199 | self.samples = {"train": train_samples, "valid": valid_samples}
200 |
201 | def build_sample(self, example):
202 | # print('example', example)
203 | sample = \
204 | Sample(
205 | data=example,
206 | candidates=["Yes", "No"],
207 | correct_candidate="Yes" if example["answer"] else "No",
208 | )
209 |
210 | return sample
211 |
212 | def get_template(self, template_version=2):
213 | return {0: BoolQTemplate, 1: BoolQTemplateV2, 2: BoolQTemplateV3}[template_version]()
214 |
215 |
216 | class MultiRCDataset(Dataset):
217 |
218 | def __init__(self, subtask=None, **kwargs) -> None:
219 | self.load_dataset(subtask, **kwargs)
220 |
221 | def load_dataset(self, path, **kwargs):
222 | d = load_dataset("super_glue", "multirc")
223 | train_set = d["train"]
224 | valid_set = d["validation"]
225 |
226 | train_samples = [self.build_sample(example) for example in train_set]
227 | valid_samples = [self.build_sample(example) for example in valid_set]
228 | self.samples = {"train": train_samples, "valid": valid_samples}
229 |
230 | def build_sample(self, example):
231 | sample = \
232 | Sample(
233 | data=example,
234 | candidates=[0, 1],
235 | correct_candidate=example['label']
236 | )
237 |
238 | return sample
239 |
240 | def get_template(self, template_version=0):
241 | return {0: MultiRCTemplate}[template_version]()
242 |
243 |
244 | class CBDataset(Dataset):
245 |
246 | def __init__(self, subtask=None, **kwargs) -> None:
247 | self.load_dataset(subtask, **kwargs)
248 |
249 | def load_dataset(self, path, **kwargs):
250 | d = load_dataset("super_glue", "cb")
251 | train_set = d["train"]
252 | valid_set = d["validation"]
253 |
254 | train_samples = [self.build_sample(example) for example in train_set]
255 | valid_samples = [self.build_sample(example) for example in valid_set]
256 | self.samples = {"train": train_samples, "valid": valid_samples}
257 |
258 | def build_sample(self, example):
259 | sample = \
260 | Sample(
261 | data=example,
262 | candidates=[0, 1, 2],
263 | correct_candidate=example['label']
264 | )
265 |
266 | return sample
267 |
268 | def get_template(self, template_version=0):
269 | return {0: CBTemplate}[template_version]()
270 |
271 |
272 | class WICDataset(Dataset):
273 |
274 | def __init__(self, subtask=None, **kwargs) -> None:
275 | self.load_dataset(subtask, **kwargs)
276 |
277 | def load_dataset(self, path, **kwargs):
278 | d = load_dataset("super_glue", "wic")
279 | train_set = d["train"]
280 | valid_set = d["validation"]
281 |
282 | train_samples = [self.build_sample(example) for example in train_set]
283 | valid_samples = [self.build_sample(example) for example in valid_set]
284 | self.samples = {"train": train_samples, "valid": valid_samples}
285 |
286 | def build_sample(self, example):
287 | sample = \
288 | Sample(
289 | data=example,
290 | candidates=[0, 1],
291 | correct_candidate=example['label']
292 | )
293 |
294 | return sample
295 |
296 | def get_template(self, template_version=0):
297 | return {0: WICTemplate}[template_version]()
298 |
299 |
300 | class WSCDataset(Dataset):
301 |
302 | def __init__(self, subtask=None, **kwargs) -> None:
303 | self.load_dataset(subtask, **kwargs)
304 |
305 | def load_dataset(self, path, **kwargs):
306 | d = load_dataset("super_glue", "wsc.fixed")
307 | train_set = d["train"]
308 | valid_set = d["validation"]
309 |
310 | train_samples = [self.build_sample(example) for example in train_set]
311 | valid_samples = [self.build_sample(example) for example in valid_set]
312 | self.samples = {"train": train_samples, "valid": valid_samples}
313 |
314 | def build_sample(self, example):
315 | sample = \
316 | Sample(
317 | data=example,
318 | candidates=[0, 1],
319 | correct_candidate=example['label']
320 | )
321 |
322 | return sample
323 |
324 | def get_template(self, template_version=0):
325 | return {0: WSCTemplate}[template_version]()
326 |
327 |
328 | class ReCoRDDataset(Dataset):
329 |
330 | def __init__(self, subtask=None, **kwargs) -> None:
331 | self.load_dataset(subtask, **kwargs)
332 |
333 | def load_dataset(self, path, **kwargs):
334 | d = load_dataset("super_glue", "record")
335 | train_set = d["train"]
336 | valid_set = d["validation"]
337 |
338 | train_samples = [self.build_sample(example) for example in train_set]
339 | valid_samples = [self.build_sample(example) for example in valid_set]
340 | self.samples = {"train": train_samples, "valid": valid_samples}
341 |
342 | def build_sample(self, example):
343 | sample = \
344 | Sample(
345 | data=example,
346 | candidates=example['entities'],
347 | correct_candidate=example['answers']
348 | )
349 |
350 | return sample
351 |
352 | def get_template(self, template_version=0):
353 | return {0: ReCoRDTemplateGPT3}[template_version]()
354 |
355 |
356 | class RTEDataset(Dataset):
357 |
358 | def __init__(self, subtask=None, **kwargs) -> None:
359 | self.load_dataset(subtask, **kwargs)
360 |
361 | def load_dataset(self, path, **kwargs):
362 | d = load_dataset("super_glue", "rte")
363 | train_set = d["train"]
364 | valid_set = d["validation"]
365 |
366 | train_samples = [self.build_sample(example) for example in train_set]
367 | valid_samples = [self.build_sample(example) for example in valid_set]
368 | self.samples = {"train": train_samples, "valid": valid_samples}
369 |
370 | def build_sample(self, example):
371 | sample = \
372 | Sample(
373 | data=example,
374 | candidates=[0, 1],
375 | correct_candidate=example['label']
376 | )
377 |
378 | return sample
379 |
380 | def get_template(self, template_version=0):
381 | return {0: RTETemplate, 1: RTETemplateEmpty}[template_version]()
382 |
383 |
384 | class SQuADDataset(Dataset):
385 | metric_name = "f1"
386 | generation = True
387 |
388 | def __init__(self, subtask=None, **kwargs) -> None:
389 | self.load_dataset()
390 |
391 | def load_dataset(self):
392 | dataset = load_dataset("squad")
393 | train_examples = dataset["train"]
394 | valid_examples = dataset["validation"]
395 |
396 | train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]
397 | valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]
398 | self.samples = {"train": train_samples, "valid": valid_samples}
399 |
400 | # for generative tasks, candidates are []
401 | def build_sample(self, example, idx):
402 | answers = example['answers']['text']
403 | assert len(answers) > 0
404 | return Sample(
405 | id=idx,
406 | data={
407 | "title": example['title'],
408 | "context": example['context'],
409 | "question": example['question'],
410 | "answers": answers
411 | },
412 | candidates=None,
413 | correct_candidate=answers
414 | )
415 |
416 | def get_template(self, template_version=0):
417 | return {0: SQuADv2Template}[template_version]()
418 |
419 |
420 | class DROPDataset(Dataset):
421 | metric_name = "f1"
422 | generation = True
423 |
424 | def __init__(self, subtask=None, **kwargs) -> None:
425 | self.load_dataset()
426 |
427 | def load_dataset(self):
428 | dataset = load_dataset("drop")
429 | train_examples = dataset["train"]
430 | valid_examples = dataset["validation"]
431 |
432 | train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]
433 | valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]
434 | self.samples = {"train": train_samples, "valid": valid_samples}
435 |
436 | # for generative tasks, candidates are []
437 | def build_sample(self, example, idx):
438 | answers = example['answers_spans']['spans']
439 | assert len(answers) > 0
440 | return Sample(
441 | id=idx,
442 | data={
443 | "context": example['passage'],
444 | "question": example['question'],
445 | "answers": answers
446 | },
447 | candidates=None,
448 | correct_candidate=answers
449 | )
450 |
451 | def get_template(self, template_version=0):
452 | return {0: DROPTemplate}[template_version]()
453 |
454 |
455 | class WinoGrandeDataset(Dataset):
456 | def __init__(self, subtask=None, **kwargs) -> None:
457 | super().__init__(subtask, **kwargs)
458 | self.load_dataset(subtask, **kwargs)
459 |
460 | def load_dataset(self, path, **kwargs):
461 | train_set = load_dataset('winogrande', 'winogrande_m', split='train')
462 | valid_set = load_dataset('winogrande', 'winogrande_m', split='validation')
463 |
464 | train_samples = [self.build_sample(example) for example in train_set]
465 | valid_samples = [self.build_sample(example) for example in valid_set]
466 | self.samples = {"train": train_samples, "valid": valid_samples}
467 |
468 | def build_sample(self, example):
469 | """
470 | Prompt adapted from https://arxiv.org/pdf/2110.08207.pdf
471 | """
472 | sentence = example["sentence"]
473 | context, target = sentence.split("_")
474 | sample = Sample(
475 | data=example,
476 | candidates=[example['option1'] + target, example['option2'] + target],
477 | correct_candidate=example[f'option{example["answer"]}'] + target,
478 | )
479 | return sample
480 |
481 | def get_template(self, template_version=0):
482 | if template_version == 0:
483 | return WinoGrandeTemplate()
484 | else:
485 | raise NotImplementedError(f"Template version {template_version} not implemented for WinoGrande")
486 |
--------------------------------------------------------------------------------
/large_models/templates.py:
--------------------------------------------------------------------------------
1 | class Template:
2 | def encode(self, sample):
3 | """
4 | Return prompted version of the example (without the answer/candidate)
5 | """
6 | raise NotImplementedError
7 |
8 | def verbalize(self, sample, candidate):
9 | """
10 | Return the prompted version of the example (with the answer/candidate)
11 | """
12 | return candidate
13 |
14 | def encode_sfc(self, sample):
15 | """
16 | Same as encode, but for SFC (calibration) -- this usually means the input is not included
17 | """
18 | return ""
19 |
20 | def verbalize_sfc(self, sample, candidate):
21 | """
22 | Same as verbalize, but for SFC (calibration) -- this usually means the input is not included
23 | """
24 | return candidate
25 |
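# A minimal sketch of how subclasses fill in this interface (for a hypothetical yes/no task):
#   encode(sample)             -> "<input> Answer:"        # prompt without the answer
#   verbalize(sample, cand)    -> "<input> Answer: Yes"    # prompt with one candidate
#   encode_sfc / verbalize_sfc -> the same strings with the input dropped, for calibration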
26 |
27 | class SST2Template(Template):
28 | verbalizer = {0: "terrible", 1: "great"}
29 |
30 | def encode(self, sample):
31 | text = sample.data["sentence"].strip()
32 | return f"{text} It was"
33 |
34 | def verbalize(self, sample, candidate):
35 | text = sample.data["sentence"].strip()
36 | return f"{text} It was {self.verbalizer[candidate]}"
37 |
38 | def encode_sfc(self, sample):
39 | return f" It was"
40 |
41 | def verbalize_sfc(self, sample, candidate):
42 | return f" It was {self.verbalizer[candidate]}"
43 |
44 | class SST2TemplateEmpty(Template):
45 | verbalizer = {0: "terrible", 1: "great"}
46 |
47 | def encode(self, sample):
48 | text = sample.data["sentence"].strip()
49 | return f"{text} "
50 |
51 | def verbalize(self, sample, candidate):
52 | text = sample.data["sentence"].strip()
53 | return f"{text} {self.verbalizer[candidate]}"
54 |
55 | def encode_sfc(self, sample):
56 | return f" "
57 |
58 | def verbalize_sfc(self, sample, candidate):
59 | return f" {self.verbalizer[candidate]}"
60 |
61 |
62 | class CopaTemplate(Template):
63 | capitalization: str = "correct"
64 | effect_conj: str = " so "
65 | cause_conj: str = " because "
66 |
67 | def get_conjucture(self, sample):
68 | if sample.data["question"] == "effect":
69 | conjunction = self.effect_conj
70 | elif sample.data["question"] == "cause":
71 | conjunction = self.cause_conj
72 | else:
73 | raise NotImplementedError
74 | return conjunction
75 |
76 | def get_prompt(self, sample):
77 | premise = sample.data["premise"].rstrip()
78 | if premise.endswith("."): # TODO Add other scripts with different punctuation
79 | premise = premise[:-1]
80 | conjunction = self.get_conjucture(sample)
81 | prompt = premise + conjunction
82 | if self.capitalization == "upper":
83 | prompt = prompt.upper()
84 | elif self.capitalization == "lower":
85 | prompt = prompt.lower()
86 | return prompt
87 |
88 | def encode(self, sample):
89 | prompt = self.get_prompt(sample)
90 | return prompt
91 |
92 | def capitalize(self, c):
93 | if self.capitalization == "correct":
94 | words = c.split(" ")
95 | if words[0] != "I":
96 | words[0] = words[0].lower()
97 | return " ".join(words)
98 | elif self.capitalization == "bug":
99 | return c
100 | elif self.capitalization == "upper":
101 | return c.upper()
102 | elif self.capitalization == "lower":
103 | return c.lower()
104 | else:
105 | raise NotImplementedError
106 |
107 | def verbalize(self, sample, candidate):
108 | prompt = self.get_prompt(sample)
109 | return prompt + self.capitalize(candidate)
110 |
111 | def encode_sfc(self, sample):
112 | conjunction = self.get_conjucture(sample)
113 | return conjunction.strip()
114 |
115 | def verbalize_sfc(self, sample, candidate):
116 | conjunction = self.get_conjucture(sample)
117 | sfc_prompt = conjunction.strip() + " " + self.capitalize(candidate)
118 | return sfc_prompt
119 |
120 |
121 | class CopaTemplateEmpty(Template):
122 | capitalization: str = "correct"
123 | effect_conj: str = " "
124 | cause_conj: str = " "
125 |
126 | def get_conjucture(self, sample):
127 | if sample.data["question"] == "effect":
128 | conjunction = self.effect_conj
129 | elif sample.data["question"] == "cause":
130 | conjunction = self.cause_conj
131 | else:
132 | raise NotImplementedError
133 | return conjunction
134 |
135 | def get_prompt(self, sample):
136 | premise = sample.data["premise"].rstrip()
137 | if premise.endswith("."): # TODO Add other scripts with different punctuation
138 | premise = premise[:-1]
139 | conjunction = self.get_conjucture(sample)
140 | prompt = premise + conjunction
141 | if self.capitalization == "upper":
142 | prompt = prompt.upper()
143 | elif self.capitalization == "lower":
144 | prompt = prompt.lower()
145 | return prompt
146 |
147 | def encode(self, sample):
148 | prompt = self.get_prompt(sample)
149 | return prompt
150 |
151 | def capitalize(self, c):
152 | if self.capitalization == "correct":
153 | words = c.split(" ")
154 | if words[0] != "I":
155 | words[0] = words[0].lower()
156 | return " ".join(words)
157 | elif self.capitalization == "bug":
158 | return c
159 | elif self.capitalization == "upper":
160 | return c.upper()
161 | elif self.capitalization == "lower":
162 | return c.lower()
163 | else:
164 | raise NotImplementedError
165 |
166 | def verbalize(self, sample, candidate):
167 | prompt = self.get_prompt(sample)
168 | return prompt + self.capitalize(candidate)
169 |
170 | def encode_sfc(self, sample):
171 | conjunction = self.get_conjucture(sample)
172 | return conjunction.strip()
173 |
174 | def verbalize_sfc(self, sample, candidate):
175 | conjunction = self.get_conjucture(sample)
176 | sfc_prompt = conjunction.strip() + " " + self.capitalize(candidate)
177 | return sfc_prompt
178 |
179 |
180 | class BoolQTemplate(Template):
181 | def encode(self, sample):
182 | passage = sample.data["passage"]
183 | question = sample.data["question"]
184 | if not question.endswith("?"):
185 | question = question + "?"
186 | question = question[0].upper() + question[1:]
187 | return f"{passage} {question}"
188 |
189 | def verbalize(self, sample, candidate):
190 | passage = sample.data["passage"]
191 | question = sample.data["question"]
192 | if not question.endswith("?"):
193 | question = question + "?"
194 | question = question[0].upper() + question[1:]
195 | return f"{passage} {question} {candidate}"
196 |
197 | def encode_sfc(self, sample):
198 | return ""
199 |
200 | def verbalize_sfc(self, sample, candidate):
201 | return candidate
202 |
203 |
204 | class BoolQTemplateV2(Template):
205 | def encode(self, sample):
206 | passage = sample.data["passage"]
207 | question = sample.data["question"]
208 | if not question.endswith("?"):
209 | question = question + "?"
210 | question = question[0].upper() + question[1:]
211 | return f"{passage} {question}\\n\\n"
212 |
213 | def verbalize(self, sample, candidate):
214 | passage = sample.data["passage"]
215 | question = sample.data["question"]
216 | if not question.endswith("?"):
217 | question = question + "?"
218 | question = question[0].upper() + question[1:]
219 | return f"{passage} {question}\\n\\n{candidate}"
220 |
221 | def encode_sfc(self, sample):
222 | return ""
223 |
224 | def verbalize_sfc(self, sample, candidate):
225 | return candidate
226 |
227 |
228 | class BoolQTemplateV3(Template):
229 | def encode(self, sample):
230 | passage = sample.data["passage"]
231 | question = sample.data["question"]
232 | if not question.endswith("?"):
233 | question = question + "?"
234 | question = question[0].upper() + question[1:]
235 | return f"{passage} {question}\n"
236 |
237 | def verbalize(self, sample, candidate):
238 | passage = sample.data["passage"]
239 | question = sample.data["question"]
240 | if not question.endswith("?"):
241 | question = question + "?"
242 | question = question[0].upper() + question[1:]
243 | return f"{passage} {question}\n{candidate}"
244 |
245 | def encode_sfc(self, sample):
246 | return ""
247 |
248 | def verbalize_sfc(self, sample, candidate):
249 | return candidate
250 |
251 |
252 | class MultiRCTemplate(Template):
253 | # From PromptSource 1
254 | verbalizer = {0: "No", 1: "Yes"}
255 |
256 | def encode(self, sample):
257 | paragraph = sample.data["paragraph"]
258 | question = sample.data["question"]
259 | answer = sample.data["answer"]
260 | return f"{paragraph}\nQuestion: {question}\nI found this answer \"{answer}\". Is that correct? Yes or No?\n"
261 |
262 | def verbalize(self, sample, candidate):
263 | paragraph = sample.data["paragraph"]
264 | question = sample.data["question"]
265 | answer = sample.data["answer"]
266 | return f"{paragraph}\nQuestion: {question}\nI found this answer \"{answer}\". Is that correct? Yes or No?\n{self.verbalizer[candidate]}"
267 |
268 | def encode_sfc(self, sample):
269 | return f""
270 |
271 | def verbalize_sfc(self, sample, candidate):
272 | return f"{self.verbalizer[candidate]}"
273 |
274 |
275 | class CBTemplate(Template):
276 | # From PromptSource 1
277 | verbalizer = {0: "Yes", 1: "No", 2: "Maybe"}
278 |
279 | def encode(self, sample):
280 | premise = sample.data["premise"]
281 | hypothesis = sample.data["hypothesis"]
282 | return f"Suppose {premise} Can we infer that \"{hypothesis}\"? Yes, No, or Maybe?\n"
283 |
284 | def verbalize(self, sample, candidate):
285 | premise = sample.data["premise"]
286 | hypothesis = sample.data["hypothesis"]
287 | return f"Suppose {premise} Can we infer that \"{hypothesis}\"? Yes, No, or Maybe?\n{self.verbalizer[candidate]}"
288 |
289 | def encode_sfc(self, sample):
290 | return f""
291 |
292 | def verbalize_sfc(self, sample, candidate):
293 | return f"{self.verbalizer[candidate]}"
294 |
295 |
296 | class WICTemplate(Template):
297 | # From PromptSource 1
298 | verbalizer = {0: "No", 1: "Yes"}
299 |
300 | def encode(self, sample):
301 | sent1 = sample.data["sentence1"]
302 | sent2 = sample.data["sentence2"]
303 | word = sample.data["word"]
304 | return f"Does the word \"{word}\" have the same meaning in these two sentences? Yes, No?\n{sent1}\n{sent2}\n"
305 |
306 | def verbalize(self, sample, candidate):
307 | sent1 = sample.data["sentence1"]
308 | sent2 = sample.data["sentence2"]
309 | word = sample.data["word"]
310 | return f"Does the word \"{word}\" have the same meaning in these two sentences? Yes, No?\n{sent1}\n{sent2}\n{self.verbalizer[candidate]}"
311 |
312 | def encode_sfc(self, sample):
313 | return f""
314 |
315 | def verbalize_sfc(self, sample, candidate):
316 | return f"{self.verbalizer[candidate]}"
317 |
318 |
319 | class WSCTemplate(Template):
320 | # From PromptSource 1
321 | verbalizer = {0: "No", 1: "Yes"}
322 |
323 | def encode(self, sample):
324 | text = sample.data['text']
325 | span1 = sample.data['span1_text']
326 | span2 = sample.data['span2_text']
327 | return f"{text}\nIn the previous sentence, does the pronoun \"{span2.lower()}\" refer to {span1}? Yes or No?\n"
328 |
329 | def verbalize(self, sample, candidate):
330 | text = sample.data['text']
331 | span1 = sample.data['span1_text']
332 | span2 = sample.data['span2_text']
333 | return f"{text}\nIn the previous sentence, does the pronoun \"{span2.lower()}\" refer to {span1}? Yes or No?\n{self.verbalizer[candidate]}"
334 |
335 | def encode_sfc(self, sample):
336 | return f""
337 |
338 | def verbalize_sfc(self, sample, candidate):
339 | return f"{self.verbalizer[candidate]}"
340 |
341 |
342 | class ReCoRDTemplate(Template):
343 | # From PromptSource 1 but modified
344 |
345 | def encode(self, sample):
346 | passage = sample.data['passage']
347 | query = sample.data['query']
348 | return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer:"
349 |
350 | def verbalize(self, sample, candidate):
351 | passage = sample.data['passage']
352 | query = sample.data['query']
353 | return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer: {candidate}"
354 |
355 | def encode_sfc(self, sample):
356 | return f"Answer:"
357 |
358 | def verbalize_sfc(self, sample, candidate):
359 | return f"Answer: {candidate}"
360 |
361 |
362 | class ReCoRDTemplateGPT3(Template):
363 | # From PromptSource 1 but modified
364 |
365 | def encode(self, sample):
366 | passage = sample.data['passage'].replace("@highlight\n", "- ")
367 | return f"{passage}\n-"
368 |
369 | def verbalize(self, sample, candidate):
370 | passage = sample.data['passage'].replace("@highlight\n", "- ")
371 | query = sample.data['query'].replace("@placeholder", candidate[0] if isinstance(candidate, list) else candidate)
372 | return f"{passage}\n- {query}"
373 |
374 | # passage = sample.data['passage']
375 | # query = sample.data['query']
376 | # return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer: {candidate}"
377 |
378 | def encode_sfc(self, sample):
379 | return f"-"
380 |
381 | def verbalize_sfc(self, sample, candidate):
382 | query = sample.data['query'].replace("@placeholder", candidate[0] if isinstance(candidate, list) else candidate)
383 | return f"- {query}"
384 |
385 |
386 | class RTETemplate(Template):
387 | # From PromptSource 1
388 | verbalizer = {0: "Yes", 1: "No"}
389 |
390 | def encode(self, sample):
391 | premise = sample.data['premise']
392 | hypothesis = sample.data['hypothesis']
393 | return f"{premise}\nDoes this mean that \"{hypothesis}\" is true? Yes or No?\n"
394 |
395 | def verbalize(self, sample, candidate):
396 | premise = sample.data['premise']
397 | hypothesis = sample.data['hypothesis']
398 | return f"{premise}\nDoes this mean that \"{hypothesis}\" is true? Yes or No?\n{self.verbalizer[candidate]}"
399 |
400 | def encode_sfc(self, sample):
401 | return f""
402 |
403 | def verbalize_sfc(self, sample, candidate):
404 | return f"{self.verbalizer[candidate]}"
405 |
406 | class RTETemplateEmpty(Template):
407 | # From PromptSource 1
408 | verbalizer = {0: "Yes", 1: "No"}
409 |
410 | def encode(self, sample):
411 | premise = sample.data['premise']
412 | hypothesis = sample.data['hypothesis']
413 | return f"{premise}\n\"{hypothesis}\"\n"
414 |
415 | def verbalize(self, sample, candidate):
416 | premise = sample.data['premise']
417 | hypothesis = sample.data['hypothesis']
418 | return f"{premise}\n\"{hypothesis}\"\n{self.verbalizer[candidate]}"
419 |
420 | def encode_sfc(self, sample):
421 | return f""
422 |
423 | def verbalize_sfc(self, sample, candidate):
424 | return f"{self.verbalizer[candidate]}"
425 |
426 |
427 | class SQuADv2Template(Template):
428 |
429 | def encode(self, sample):
430 | question = sample.data['question'].strip()
431 | title = sample.data['title']
432 | context = sample.data['context']
433 | answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one
434 |
435 | return f"Title: {title}\nContext: {context}\nQuestion: {question}\nAnswer:"
436 |
437 | def verbalize(self, sample, candidate):
438 | question = sample.data['question'].strip()
439 | title = sample.data['title']
440 | context = sample.data['context']
441 | answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one
442 |
443 | return f"Title: {title}\nContext: {context}\nQuestion: {question}\nAnswer: {answer}\n"
444 |
445 | def encode_sfc(self, sample):
446 | raise NotImplementedError
447 |
448 | def verbalize_sfc(self, sample, candidate):
449 | raise NotImplementedError
450 |
451 |
452 | class DROPTemplate(Template):
453 |
454 | def encode(self, sample):
455 | question = sample.data['question'].strip()
456 | # title = sample.data['title']
457 | context = sample.data['context']
458 | answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one
459 |
460 | return f"Passage: {context}\nQuestion: {question}\nAnswer:"
461 |
462 | def verbalize(self, sample, candidate):
463 | question = sample.data['question'].strip()
464 | # title = sample.data['title']
465 | context = sample.data['context']
466 | answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one
467 |
468 | return f"Passage: {context}\nQuestion: {question}\nAnswer: {answer}\n"
469 |
470 | def encode_sfc(self, sample):
471 | raise NotImplementedError
472 |
473 | def verbalize_sfc(self, sample, candidate):
474 | raise NotImplementedError
475 |
476 |
477 | class WinoGrandeTemplate(Template):
478 | @staticmethod
479 | def get_prompt(sample):
480 | """
481 | Prompt adapted from https://arxiv.org/pdf/2110.08207.pdf
482 | """
483 | sentence = sample.data["sentence"]
484 | context, target = sentence.split("_")
485 | return context
486 |
487 | def encode(self, sample):
488 | prompt = self.get_prompt(sample)
489 | return prompt
490 |
491 | def verbalize(self, sample, candidate):
492 | prompt = self.get_prompt(sample)
493 | return prompt + candidate
494 |
495 | def encode_sfc(self, sample):
496 | return ""
497 |
498 | def verbalize_sfc(self, sample, candidate):
499 | return candidate
500 |
--------------------------------------------------------------------------------
/large_models/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import json
3 | import logging
4 | import signal
5 | import time
6 | from collections.abc import Mapping
7 | from dataclasses import is_dataclass, asdict
8 | from typing import Any, Dict, List, NewType, Optional, Union
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | import transformers
14 | from torch.nn import CrossEntropyLoss
15 | from transformers.data.data_collator import DataCollatorMixin
16 | from transformers.modeling_outputs import CausalLMOutputWithPast
17 | from transformers.utils import PaddingStrategy
18 |
19 | InputDataClass = NewType("InputDataClass", Any)
20 | from dataclasses import dataclass
21 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | def forward_wrap_with_option_len(
26 | self,
27 | input_ids=None,
28 | labels=None,
29 | option_len=None,
30 | num_options=None,
31 | return_dict=None,
32 | **kwargs
33 | ):
34 | """
35 | This is to replace the original forward function of Transformer models to enable:
36 | (1) Partial target sequence: loss will only be calculated on part of the sequence
37 | (2) Classification-style training: a classification loss (CE) will be calculated over several options
38 | Input:
39 | - input_ids, labels: same as the original forward function
40 | - option_len: a list of int indicating the option lengths, and loss will be calculated only on the
41 | last option_len tokens
42 | - num_options: a list of int indicating the number of options for each example (this will be #label
43 | words for classification tasks and #choices for multiple choice tasks), and a classification loss
44 | will be calculated.
45 | """
46 | outputs = self.original_forward(input_ids=input_ids, **kwargs)
47 |
48 | if labels is None:
49 | return outputs
50 |
51 | # in prompt tuning, we need to remove the virtual tokens from the logits to match the input ids
52 | logits = outputs.logits
53 |
54 | loss = None
55 | # Shift so that tokens < n predict n
56 | shift_logits = logits[..., :-1, :].contiguous()
57 |     # Use input_ids here (they should always equal labels) because in classification mode labels hold the correct candidate IDs
58 | shift_labels = torch.clone(input_ids)[..., 1:].contiguous()
59 | shift_labels[shift_labels == self.config.pad_token_id] = -100
60 |
61 | # Apply option len (do not calculate loss on the non-option part)
62 | # for _i, _len in enumerate(option_len):
63 | # shift_labels[_i, :-_len] = -100
64 |     # vectorized equivalent of the commented loop above (avoids a Python for loop)
65 | non_option_len = shift_labels.shape[1] - option_len
66 | mask = torch.arange(
67 | shift_labels.shape[1], device=shift_labels.device
68 | ).expand(shift_labels.shape[0], -1) < non_option_len.unsqueeze(-1)
69 | shift_labels[mask] = -100
70 |
71 | # Calculate the loss
72 | loss_fct = CrossEntropyLoss(ignore_index=-100)
73 |
74 | if num_options is not None:
75 |         # Train as a classification task
76 | log_probs = F.log_softmax(shift_logits, dim=-1)
77 | mask = shift_labels != -100 # Option part
78 | shift_labels[~mask] = 0 # So that it doesn't mess up with indexing
79 |
80 | selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(
81 | -1) # (bsz x num_options, len)
82 | selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options)
83 |
84 | if any([x != num_options[0] for x in num_options]):
85 |             # Multiple-choice tasks with different numbers of options
86 | loss = 0
87 | start_id = 0
88 | count = 0
89 | while start_id < len(num_options):
90 | end_id = start_id + num_options[start_id]
91 | _logits = selected_log_probs[start_id:end_id].unsqueeze(0) # (1, num_options)
92 | _labels = labels[start_id:end_id][0].unsqueeze(0) # (1)
93 | loss = loss_fct(_logits, _labels) + loss
94 | count += 1
95 | start_id = end_id
96 | loss = loss / count
97 | else:
98 | num_options = num_options[0]
99 | selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options)
100 | labels = labels.view(-1, num_options)[:, 0] # Labels repeat so we only take the first one
101 | # print('selected_log_probs', selected_log_probs.shape, selected_log_probs.softmax(dim=1).argmax(dim=1))
102 | # print('log', selected_log_probs.argmax(dim=1))
103 | loss = loss_fct(selected_log_probs, labels)
104 |
105 | else:
106 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
107 |
108 | if not return_dict:
109 | output = (logits,) + outputs[1:]
110 | return (loss,) + output if loss is not None else output
111 |
112 | return CausalLMOutputWithPast(
113 | loss=loss,
114 | logits=logits,
115 | past_key_values=outputs.past_key_values,
116 | hidden_states=outputs.hidden_states,
117 | attentions=outputs.attentions,
118 | )
119 |
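# Hedged illustration of the option-length masking above (toy shapes, not executed): with a
# batch of two length-5 sequences and option_len = [2, 3], the shifted labels have length 4
# and non_option_len = [2, 1], so the first 2 (resp. 1) positions are set to -100 and the
# loss is computed only on the trailing 2 (resp. 3) option tokens.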
120 |
121 | def encode_prompt(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False,
122 | generation=False, generation_with_gold=False, max_new_tokens=None):
123 | """
124 | Encode prompts for eval_sample
125 | Input:
126 | - task, template: task and template class
127 | - train_samples, eval_sample: demonstrations and the actual sample
128 | - tokenizer, max_length: tokenizer and max length
129 | - sfc: generate prompts for calibration (surface form competition; https://arxiv.org/abs/2104.08315)
130 | - icl_sfc: generate prompts for ICL version calibration
131 |     - generation: whether it is a generation task
132 | - generation_with_gold: whether to include the generation-task gold answers (for training)
133 | - max_new_tokens: max number of new tokens to generate so that we can save enough space
134 | (only for generation tasks)
135 | Output:
136 | - encodings: a list of N lists of tokens. N is the number of options for classification/multiple-choice.
137 | - option_lens: a list of N integers indicating the number of option tokens.
138 | """
139 |
140 | # Demonstrations for ICL
141 | train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples]
142 | train_prompts = task.train_sep.join(train_prompts).strip()
143 |
144 | # sfc or icl_sfc indicates that this example is used for calibration
145 | if sfc or icl_sfc:
146 | encode_fn = template.encode_sfc
147 | verbalize_fn = template.verbalize_sfc
148 | else:
149 | encode_fn = template.encode
150 | verbalize_fn = template.verbalize
151 |
152 | unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ')
153 | if not generation:
154 | # We generate one prompt for each candidate (different classes in classification)
155 | # or different choices in multiple-choice tasks
156 | verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates]
157 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
158 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for
159 | verbalized_eval_prompt in verbalized_eval_prompts]
160 |
161 | if sfc:
162 | # Without demonstrations
163 | final_prompts = verbalized_eval_prompts
164 | else:
165 | # With demonstrations
166 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in
167 | verbalized_eval_prompts]
168 | else:
169 | assert not sfc and not icl_sfc, "Generation tasks do not support SFC"
170 | if generation_with_gold:
171 | verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)]
172 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
173 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for
174 | verbalized_eval_prompt in verbalized_eval_prompts]
175 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in
176 | verbalized_eval_prompts]
177 | else:
178 | option_lens = [0]
179 | final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')]
180 |
181 | # Tokenize
182 | encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts]
183 |
184 | # Truncate (left truncate as demonstrations are less important)
185 | if generation and max_new_tokens is not None:
186 | max_length = max_length - max_new_tokens
187 |
188 | if any([len(encoding) > max_length for encoding in encodings]):
190 | logger.warn("Prompt exceeds max length; truncating from the left")
190 | if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
191 | encodings = [encoding[0:1] + encoding[1:][-(max_length - 1):] for encoding in encodings]
192 | else:
193 | encodings = [encoding[-max_length:] for encoding in encodings]
194 |
195 | return encodings, option_lens
196 |
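# Illustrative usage sketch (hypothetical helper, not part of the training pipeline): encode a
# classification example into one token sequence per candidate label. `task`, `template`,
# `train_samples`, and `eval_sample` are assumed to come from this repo's task/template classes.
def _example_encode_prompt(task, template, tokenizer, train_samples, eval_sample):
    encodings, option_lens = encode_prompt(
        task, template, train_samples, eval_sample, tokenizer, max_length=2048,
    )
    # One encoding per candidate; option_lens gives the token count of each verbalized option.
    assert len(encodings) == len(eval_sample.candidates)
    return encodings, option_lens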
197 |
198 | @dataclass
199 | class ICLCollator:
200 | """
201 | Collator for ICL
202 | """
203 | tokenizer: PreTrainedTokenizerBase
204 |
205 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
206 | if not isinstance(features[0], Mapping):
207 | features = [vars(f) for f in features]
208 | first = features[0]
209 | batch = {}
210 |
211 | pad_id = self.tokenizer.pad_token_id
212 |
213 | pad_ids = {"input_ids": pad_id, "attention_mask": 0, "sfc_input_ids": pad_id, "sfc_attention_mask": 0,
214 | "labels": pad_id}
215 | for key in first:
216 | pp = pad_ids[key]
217 | lens = [len(f[key]) for f in features]
218 | max_len = max(lens)
219 | feature = np.stack([np.pad(f[key], (0, max_len - lens[i]), "constant", constant_values=(0, pp)) for i, f in
220 | enumerate(features)])
221 | padded_feature = torch.from_numpy(feature).long()
222 | batch[key] = padded_feature
223 |
224 | return batch
225 |
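# Illustrative usage sketch (hypothetical values): ICLCollator right-pads every key to the
# longest sequence in the batch, using the tokenizer's pad token id for token ids and 0 for
# attention masks, and returns LongTensors.
def _example_icl_collator(tokenizer):
    collator = ICLCollator(tokenizer)
    batch = collator([
        {"input_ids": [11, 12, 13], "attention_mask": [1, 1, 1]},
        {"input_ids": [11, 12], "attention_mask": [1, 1]},
    ])
    # batch["input_ids"] has shape (2, 3); the second row is padded with tokenizer.pad_token_id.
    return batch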
226 |
227 | @dataclass
228 | class DataCollatorWithPaddingAndNesting:
229 | """
230 | Collator for training
231 | """
232 |
233 | tokenizer: PreTrainedTokenizerBase
234 | padding: Union[bool, str, PaddingStrategy] = True
235 | max_length: Optional[int] = None
236 | pad_to_multiple_of: Optional[int] = None
237 | return_tensors: str = "pt"
238 |
239 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
240 | features = [ff for f in features for ff in f]  # flatten nested per-example feature lists into one batch
241 | batch = self.tokenizer.pad(
242 | features,
243 | padding=self.padding,
244 | max_length=self.max_length,
245 | pad_to_multiple_of=self.pad_to_multiple_of,
246 | return_tensors=self.return_tensors,
247 | )
248 | if "label" in batch:
249 | batch["labels"] = batch["label"]
250 | del batch["label"]
251 | if "label_ids" in batch:
252 | batch["labels"] = batch["label_ids"]
253 | del batch["label_ids"]
254 | return batch
255 |
256 |
257 | @dataclass
258 | class NondiffCollator(DataCollatorMixin):
259 | """
260 | Collator for non-differentiable objectives
261 | """
262 | tokenizer: PreTrainedTokenizerBase
263 | padding: Union[bool, str, PaddingStrategy] = True
264 | max_length: Optional[int] = None
265 | pad_to_multiple_of: Optional[int] = None
266 | label_pad_token_id: int = -100
267 | return_tensors: str = "pt"
268 |
269 | def torch_call(self, features):
270 | import torch
271 |
272 | label_name = "label" if "label" in features[0].keys() else "labels"
273 | labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
274 |
275 | no_labels_features = [{k: v for k, v in feature.items() if k != label_name and k != "gold"} for feature in
276 | features]
277 |
278 | batch = self.tokenizer.pad(
279 | no_labels_features,
280 | padding=self.padding,
281 | max_length=self.max_length,
282 | pad_to_multiple_of=self.pad_to_multiple_of,
283 | return_tensors="pt",
284 | )
285 |
286 | if labels is None:
287 | return batch
288 |
289 | sequence_length = batch["input_ids"].shape[1]
290 | padding_side = self.tokenizer.padding_side
291 |
292 | def to_list(tensor_or_iterable):
293 | if isinstance(tensor_or_iterable, torch.Tensor):
294 | return tensor_or_iterable.tolist()
295 | return list(tensor_or_iterable)
296 |
297 | if padding_side == "right":
298 | batch[label_name] = [
299 | to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
300 | ]
301 | else:
302 | batch[label_name] = [
303 | [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
304 | ]
305 |
306 | batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
307 | if "gold" in features[0]:
308 | batch["gold"] = [feature["gold"] for feature in features]
309 |
310 | return batch
311 |
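# Illustrative usage sketch (hypothetical values): NondiffCollator pads the regular inputs via
# the tokenizer and pads "labels" to the same length with label_pad_token_id (-100), so padded
# positions can be ignored downstream; an optional "gold" field is kept as a plain Python list.
def _example_nondiff_collator(tokenizer):
    collator = NondiffCollator(tokenizer=tokenizer)
    batch = collator([
        {"input_ids": [11, 12, 13], "attention_mask": [1, 1, 1], "labels": [11, 12, 13]},
        {"input_ids": [11, 12], "attention_mask": [1, 1], "labels": [11, 12]},
    ])
    # With right padding, batch["labels"][1] ends with -100 at the padded position.
    return batch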
312 |
313 | class SIGUSR1Callback(transformers.TrainerCallback):
314 | """
315 | This callback saves the model when a SIGUSR1 signal (the SLURM stop signal) or a
316 | SIGINT signal (keyboard interruption) is received, and then stops training.
317 | """
318 |
319 | def __init__(self) -> None:
320 | super().__init__()
321 | self.signal_received = False
322 | signal.signal(signal.SIGUSR1, self.handle_signal)
323 | signal.signal(signal.SIGINT, self.handle_signal)
324 | logger.warn("Handler registered")
325 |
326 | def handle_signal(self, signum, frame):
327 | self.signal_received = True
328 | logger.warn("Signal received")
329 |
330 | def on_step_end(self, args, state, control, **kwargs):
331 | if self.signal_received:
332 | control.should_save = True
333 | control.should_training_stop = True
334 |
335 | def on_train_end(self, args, state, control, **kwargs):
336 | if self.signal_received:
337 | exit(0)
338 |
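# Illustrative usage sketch (hypothetical argument names): passing the callback to a Trainer
# so that a SLURM preemption (SIGUSR1) or Ctrl-C (SIGINT) triggers a final checkpoint and a
# clean stop instead of losing the run.
def _example_trainer_with_signal_handling(model, training_args, train_dataset):
    return transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[SIGUSR1Callback()],
    )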
339 |
340 | @dataclass
341 | class Prediction:
342 | correct_candidate: Union[int, str]
343 | predicted_candidate: Union[int, str]
344 |
345 |
346 | @contextlib.contextmanager
347 | def count_time(name):
348 | logger.info("%s..." % name)
349 | start_time = time.time()
350 | try:
351 | yield
352 | finally:
353 | logger.info("Done with %.2fs" % (time.time() - start_time))
354 |
355 |
356 | @contextlib.contextmanager
357 | def temp_seed(seed):
358 | state = np.random.get_state()
359 | np.random.seed(seed)
360 | try:
361 | yield
362 | finally:
363 | np.random.set_state(state)
364 |
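# Illustrative usage sketch (hypothetical values): temp_seed fixes NumPy's global RNG inside
# the `with` block and restores the previous RNG state afterwards, so code outside the block
# is unaffected.
def _example_temp_seed():
    with temp_seed(42):
        fixed = np.random.randn(4)   # deterministic for a given seed
    free = np.random.randn(4)        # resumes from the saved RNG state
    return fixed, free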
365 |
366 | class EnhancedJSONEncoder(json.JSONEncoder):
367 | def default(self, o):
368 | if is_dataclass(o):
369 | return asdict(o)
370 | return super().default(o)
371 |
372 |
373 | def write_predictions_to_file(final_preds, output):
374 | with open(output, "w") as f:
375 | for pred in final_preds:
376 | f.write(json.dumps(pred, cls=EnhancedJSONEncoder) + "\n")
377 |
378 |
379 | def write_metrics_to_file(metrics, output):
380 | json.dump(metrics, open(output, "w"), cls=EnhancedJSONEncoder, indent=4)
381 |
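# Illustrative usage sketch (hypothetical file name): Prediction is a dataclass, so
# EnhancedJSONEncoder serializes it via asdict(); write_predictions_to_file writes one JSON
# object per line (JSONL).
def _example_write_predictions(path="predictions.jsonl"):
    preds = [
        Prediction(correct_candidate=1, predicted_candidate=1),
        Prediction(correct_candidate="yes", predicted_candidate="no"),
    ]
    write_predictions_to_file(preds, path)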
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.25.0
2 | aiohttp==3.9.1
3 | aiosignal==1.3.1
4 | appdirs==1.4.4
5 | async-timeout==4.0.3
6 | attrs==23.1.0
7 | certifi==2023.11.17
8 | charset-normalizer==3.3.2
9 | click==8.1.7
10 | datasets==2.16.0
11 | dill==0.3.7
12 | docker-pycreds==0.4.0
13 | filelock==3.13.1
14 | fsspec==2023.10.0
15 | gitdb==4.0.11
16 | GitPython==3.1.40
17 | gmpy2==2.1.2
18 | huggingface-hub==0.20.1
19 | idna==3.6
20 | Jinja2==3.1.2
21 | joblib==1.3.2
22 | llvmlite==0.41.1
23 | MarkupSafe==2.1.3
24 | mpmath==1.3.0
25 | multidict==6.0.4
26 | multiprocess==0.70.15
27 | networkx
28 | numba==0.58.1
29 | numpy
30 | packaging==23.2
31 | pandas
32 | pip==23.3.2
33 | protobuf==4.25.1
34 | psutil==5.9.7
35 | pyarrow==14.0.2
36 | pyarrow-hotfix==0.6
37 | python-dateutil==2.8.2
38 | pytz==2023.3.post1
39 | PyYAML==6.0.1
40 | regex==2023.12.25
41 | requests==2.31.0
42 | safetensors==0.4.1
43 | scikit-learn==1.3.2
44 | sentry-sdk==1.39.1
45 | setproctitle==1.3.3
46 | setuptools==68.2.2
47 | six==1.16.0
48 | smmap==5.0.1
49 | sympy==1.12
50 | threadpoolctl==3.2.0
51 | tokenizers==0.13.3
52 | torch==2.1.0
53 | tqdm==4.66.1
54 | transformers==4.28.1
55 | typing_extensions==4.9.0
56 | tzdata==2023.3
57 | urllib3==2.1.0
58 | wandb==0.16.1
59 | wheel==0.42.0
60 | xxhash==3.4.1
61 | yarl==1.9.4
62 |
--------------------------------------------------------------------------------