├── README.md ├── eval_helpers.py ├── quiet-star-train.py ├── configuration_mistral.py ├── LICENSE ├── zero-shotcot-eval.py └── modeling_mistral.py /README.md: -------------------------------------------------------------------------------- 1 | # Quiet-STaR 2 | 3 | Code for [Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking](https://arxiv.org/abs/2403.09629). 4 | 5 | This project is implemented by patching the base Mistral implementation in Huggingface `transformers` with a new `modeling_mistral.py` and a new `configuration_mistral.py`, and otherwise relies on standard `transformers` features (e.g., the default `Trainer`). Our patches were applied to Huggingface's `transformers` version `4.37.0.dev0` under `src/transformers/models/mistral/` -- we cannot guarantee that other changes to their implementation will not affect ours, so for reproducibility we encourage using the same version. 6 | 7 | One pitfall to be wary of: the model is never trained to avoid generating its start and end thought tokens, so when performing actual inference it is necessary to mask these tokens out of the sampling distribution. 8 | 9 | We make a model trained with 8 thought tokens ahead (including the start and end thought tokens) [available via Huggingface](https://huggingface.co/ezelikman/quietstar-8-ahead). 10 |
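For reference, here is a minimal sketch of that masking step, mirroring what `zero-shotcot-eval.py` in this repository does during generation: every token id at or above `tokenizer.vocab_size` (i.e. the added `<|startthought|>` / `<|endthought|>` tokens) has its logit set to `-inf` before sampling. The `mask_thought_tokens` helper below is illustrative rather than part of the codebase:

```python
import torch

def mask_thought_tokens(logits: torch.Tensor, base_vocab_size: int) -> torch.Tensor:
    # logits: [..., base_vocab_size + num_added_thought_tokens]
    # Any id >= base_vocab_size is an added thought-delimiter token; forbid sampling it.
    masked = logits.clone()
    masked[..., base_vocab_size:] = -float("inf")
    return masked

# Usage sketch (assumes the thought tokens were appended via tokenizer.add_special_tokens,
# so tokenizer.vocab_size still reports the base vocabulary size):
# next_token_logits = mask_thought_tokens(model(input_ids).logits, tokenizer.vocab_size)
```

 -------------------------------------------------------------------------------- /eval_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from transformers import AutoTokenizer 4 | 5 | initial_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") 6 | initial_tokenizer.padding_side = "right" 7 | initial_tokenizer.pad_token_id = initial_tokenizer.eos_token_id 8 | eval_answer_marker="\nA:" 9 | 10 | def preprocess_function(examples): 11 | dataset_transform = lambda xs: xs["text"] 12 | all_tokenized = [initial_tokenizer.encode(t, return_tensors="pt") for t in dataset_transform(examples)] 13 | new_tokenized = [{"input_ids": t} for t in all_tokenized] 14 | for i, t in enumerate(new_tokenized): 15 | new_tokenized[i]["input_ids"] = truncate_or_pad(t['input_ids'], initial_tokenizer.pad_token_id) 16 | new_input_ids = torch.cat([t["input_ids"] for t in new_tokenized], dim=0) 17 | new_attention_mask = (new_input_ids != initial_tokenizer.pad_token_id).long() 18 | tokenized = {"input_ids": new_input_ids, "attention_mask": new_attention_mask} 19 | tokenized["labels"] = tokenized["input_ids"].clone() 20 | return tokenized 21 | 22 | def preprocess_eval_function_gsm(examples, use_few_shot=False, max_length=256): 23 | to_answer = lambda q, a: "Q: " + q + eval_answer_marker + a.split("####")[-1] + "\n" 24 | all_prompts = [to_answer(q, a) for q, a in zip(examples['question'], examples['answer'])] 25 | all_tokenized = [initial_tokenizer.encode(p, return_tensors="pt") for p in all_prompts] 26 | new_tokenized = [{"input_ids": t} for t in all_tokenized] 27 | for i, t in enumerate(new_tokenized): 28 | new_tokenized[i]["input_ids"] = truncate_or_pad(t['input_ids'], initial_tokenizer.pad_token_id, max_length) 29 | new_input_ids = torch.cat([t["input_ids"] for t in new_tokenized], dim=0) 30 | new_attention_mask = (new_input_ids != initial_tokenizer.pad_token_id).long() 31 | tokenized = {"input_ids": new_input_ids, "attention_mask": new_attention_mask} 32 | tokenized["labels"] = tokenized["input_ids"].clone() 33 | return tokenized 34 | 35 | def preprocess_eval_function_csqa(examples, max_length=256): 36 | def construct_question(q, choices): 37 | choice_list = 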
"\n".join([f"({label}) {choice}" for label, choice in zip(choices["label"], choices["text"])]) 38 | return f"Q: {q}" + "\n" + choice_list 39 | to_answer = lambda q, c, a: construct_question(q, c) + eval_answer_marker + " " + a + "\n" 40 | all_prompts = [to_answer(q, c, a) for q, c, a in zip(examples['question'], examples['choices'], examples['answerKey'])] 41 | all_tokenized = [initial_tokenizer.encode(p, return_tensors="pt") for p in all_prompts] 42 | new_tokenized = [{"input_ids": t} for t in all_tokenized] 43 | for i, t in enumerate(new_tokenized): 44 | new_tokenized[i]["input_ids"] = truncate_or_pad(t['input_ids'], initial_tokenizer.pad_token_id, max_length) 45 | new_input_ids = torch.cat([t["input_ids"] for t in new_tokenized], dim=0) 46 | new_attention_mask = (new_input_ids != initial_tokenizer.pad_token_id).long() 47 | tokenized = {"input_ids": new_input_ids, "attention_mask": new_attention_mask} 48 | tokenized["labels"] = tokenized["input_ids"].clone() 49 | return tokenized 50 | 51 | def compute_metrics(eval_pred, filter_numbers=True): 52 | logits, labels, _ = eval_pred 53 | accuracy = 0 54 | valid_number_tokens = [28740, 28750, 28770, 28781, 28782, 28784, 28787, 28783, 28774, 28734, 13] # numbers 55 | valid_letter_tokens = [330, 365, 334, 384, 413, 13] # answer tokens 56 | for question, logits_guess in zip(labels, logits): 57 | # find which token corresponds to eval_answer_marker 58 | # chop off tokens from the end until the number of eval_answer_marker goes down 59 | detokenized_question = initial_tokenizer.decode(question) 60 | is_numeric = detokenized_question.split(eval_answer_marker)[-1][1].isdigit() 61 | valid_tokens = valid_number_tokens if is_numeric else valid_letter_tokens 62 | answer_count = detokenized_question.count(eval_answer_marker) 63 | for i in range(len(question) - 1, 0, -1): 64 | tokenized_subquestion = question[:i] 65 | if tokenized_subquestion[-1] == initial_tokenizer.pad_token_id: 66 | continue 67 | detokenized_subquestion = initial_tokenizer.decode(question[:i]) 68 | if detokenized_subquestion.count(eval_answer_marker) < answer_count: 69 | break 70 | correct_answer_prob = 1 71 | # if is_numeric, then the first token just indicates that it's a number 72 | question_offset = 1 if is_numeric else 0 73 | for j in range(i + question_offset, len(question) - 1): 74 | if question[j + 1] == initial_tokenizer.pad_token_id: 75 | break 76 | true_token = question[j + 1] 77 | guess = torch.nn.functional.softmax(torch.tensor(logits_guess), dim=-1) 78 | # we only care about the logits assigned to the correct token 79 | if filter_numbers: 80 | if true_token not in valid_tokens: 81 | continue 82 | guess_filtered = torch.zeros_like(guess) 83 | guess_filtered[:, valid_tokens] = guess[:, valid_tokens] 84 | guess_filtered = guess_filtered / guess_filtered.sum(dim=-1, keepdim=True) 85 | token_prob = guess_filtered[j, true_token] 86 | else: 87 | token_prob = guess[j, true_token] 88 | correct_answer_prob *= token_prob 89 | accuracy += correct_answer_prob / len(labels) 90 | return {"accuracy": accuracy} 91 | 92 | def truncate_or_pad(t, padding_idx=0, max_length=256): 93 | if t.shape[1] > max_length: 94 | start = random.randint(0, t.shape[1] - max_length) 95 | t = t[:, start:start + max_length] 96 | else: 97 | padding = torch.zeros(t.shape[0], max_length - t.shape[1], dtype=t.dtype, device=t.device) 98 | t = torch.cat([t, padding + padding_idx], dim=1) 99 | return t 100 | -------------------------------------------------------------------------------- /quiet-star-train.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline, AutoConfig 5 | from accelerate import infer_auto_device_map, init_empty_weights, dispatch_model 6 | from datasets import load_dataset 7 | from torch.nn import CrossEntropyLoss 8 | from transformers import TrainingArguments, Trainer 9 | import os 10 | import time 11 | import wandb 12 | # from huggingface_custom_callback import EarlyStoppingCallback  # this module is not included in the repository and the callback is never used below 13 | from eval_helpers import preprocess_eval_function_gsm, preprocess_eval_function_csqa, preprocess_function, compute_metrics, truncate_or_pad 14 | random_seed = 42 15 | torch.manual_seed(random_seed) 16 | random.seed(random_seed) 17 | 18 | # MAIN SETUP 19 | root_prefix = "YOUR_CACHE_PATH_HERE" 20 | wandb_cache_dir = root_prefix + "cache/quietstar/wandb_cache" 21 | dataset_name = 'open-web-math/open-web-math' 22 | # dataset_name = 'c4' 23 | project_name = "quiet-star" 24 | os.environ["WANDB_PROJECT"] = project_name + "-" + dataset_name.split("/")[-1] 25 | os.environ["WANDB_CACHE_DIR"] = wandb_cache_dir 26 | n_ahead_talk_global = 4 27 | n_passes_global = 2 28 | n_ahead_global = 12 29 | n_examples = 1_000 30 | full_batch_size = 8 31 | eval_and_logging_steps = 10 32 | save_steps = 100 33 | 34 | def model_init(params): 35 | original = False 36 | if params is None: 37 | params = {} 38 | else: 39 | params = params.params 40 | # save params to file 41 | n_ahead = params.get("n_ahead", n_ahead_global if not original else 1) 42 | n_ahead_talk = params.get("n_ahead_talk", n_ahead_talk_global if not original else 1) 43 | n_passes = params.get("n_passes", n_passes_global if not original else 1) 44 | gumbel_temperature = params.get("gumbel_temperature", 1) 45 | use_start_thought_token = params.get("use_start_thought_token", True) 46 | use_end_thought_token = params.get("use_end_thought_token", True) 47 | include_policy_loss = params.get("include_policy_loss", True) 48 | gumbel_detach = params.get("gumbel_detach", True) 49 | merged_talk_heads = params.get("merged_talk_heads", True) 50 | gradient_accumulation_steps = params.get("gradient_accumulation_steps", global_gradient_accumulation_steps) 51 | residual_think_head = params.get("residual_think_head", False) 52 | optimize_lm_head_only_at_start = params.get("optimize_lm_head_only_at_start", False) 53 | 54 | model_name = "mistralai/Mistral-7B-v0.1" 55 | print("Loading model") 56 | model = AutoModelForCausalLM.from_pretrained( 57 | model_name, 58 | torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, 59 | device_map='auto', 60 | cache_dir=root_prefix + "cache", 61 | max_thoughts=n_ahead + n_ahead_talk + 1, 62 | merged_talk_heads=merged_talk_heads, 63 | merged_lm_and_talk_heads=False, 64 | merged_lm_and_think_heads=True, 65 | use_concat_talk_head=True, 66 | use_shallow_think=True, 67 | use_shallow_talk=False, 68 | use_complex_think_head=False, 69 | use_complex_talk_head=True, 70 | use_weighted_talk_head=True, 71 | ) 72 | print("Loaded model") 73 | tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") 74 | tokenizer.padding_side = "right" 75 | tokenizer.pad_token_id = tokenizer.eos_token_id 76 | 77 | special_tokens_to_add = [] 78 | if model.use_start_thought_token: 79 | special_tokens_to_add.append("<|startthought|>") 80 | if model.use_end_thought_token: 81 | special_tokens_to_add.append("<|endthought|>") 82 | if 
special_tokens_to_add: 83 | tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add}) 84 | model.resize_token_embeddings(len(tokenizer)) 85 | model.tokenizer = tokenizer 86 | model.gumbel_detach = gumbel_detach 87 | model.include_policy_loss = include_policy_loss 88 | model.use_end_thought_token = use_end_thought_token 89 | model.use_start_thought_token = use_start_thought_token 90 | model.n_ahead = n_ahead 91 | model.n_ahead_talk = n_ahead_talk 92 | model.n_passes = n_passes 93 | model.n_tokens_print = gradient_accumulation_steps 94 | model.gradient_accumulation_steps = gradient_accumulation_steps 95 | model.residual_think_head = residual_think_head 96 | model.optimize_lm_head_only_at_start = optimize_lm_head_only_at_start 97 | model.gumbel_temperature = gumbel_temperature 98 | model.wandb_enabled = True 99 | model.original_mode = original 100 | model.config_params = params 101 | model.run_start = int(time.time()) 102 | model.kill_after = 100 103 | model.train() 104 | return model 105 | 106 | # Load dataset 107 | dataset = load_dataset( 108 | dataset_name, 109 | "en" if "c4" in dataset_name else "default", 110 | split=f"train[:{n_examples}]", 111 | ignore_verifications=True, 112 | num_proc=16, 113 | cache_dir=root_prefix + "cache/datasets/", 114 | ) 115 | 116 | train_dataset = dataset.shuffle(seed=random_seed).map(preprocess_function, batched=True, writer_batch_size=200) 117 | eval_dataset_gsm = load_dataset("gsm8k", "main", split="test", ignore_verifications=True).map(preprocess_eval_function_gsm, batched=True, writer_batch_size=200) 118 | eval_dataset_csqa = load_dataset("tau/commonsense_qa", "default", split="validation", ignore_verifications=True).map(preprocess_eval_function_csqa, batched=True, writer_batch_size=200) 119 | 120 | eval_datasets = { 121 | "gsm8k": eval_dataset_gsm, 122 | "csqa": eval_dataset_csqa, 123 | } 124 | 125 | batch_size = full_batch_size // n_passes_global 126 | global_gradient_accumulation_steps = full_batch_size // batch_size 127 | run_id = int(time.time()) 128 | training_args = TrainingArguments( 129 | output_dir=root_prefix + f"cache/quietstar/{run_id}", 130 | learning_rate=1e-6, 131 | optim="adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch", 132 | per_device_train_batch_size=batch_size, 133 | per_device_eval_batch_size=batch_size, 134 | gradient_accumulation_steps=global_gradient_accumulation_steps, 135 | max_grad_norm=1.0, 136 | max_steps=100000, 137 | warmup_steps=20, 138 | auto_find_batch_size=True, 139 | weight_decay=0.001, 140 | label_names=["labels"], 141 | include_inputs_for_metrics=True, 142 | logging_steps=eval_and_logging_steps, 143 | eval_steps=eval_and_logging_steps, 144 | evaluation_strategy="steps", 145 | save_steps=save_steps, 146 | run_name=f"n={n_ahead_global}_nt={n_ahead_talk_global}_np={n_passes_global}", 147 | ) 148 | 149 | trainer = Trainer( 150 | args=training_args, 151 | train_dataset=train_dataset, 152 | eval_dataset=eval_datasets, 153 | compute_metrics=compute_metrics, 154 | model_init=model_init, 155 | ) 156 | 157 | trainer.train() 158 | -------------------------------------------------------------------------------- /configuration_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Mistral model configuration""" 16 | 17 | from ...configuration_utils import PretrainedConfig 18 | from ...utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", 25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", 26 | } 27 | 28 | 29 | class MistralConfig(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an 32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration 33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. 34 | 35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) 36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) 37 | 38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 39 | documentation from [`PretrainedConfig`] for more information. 40 | 41 | 42 | Args: 43 | vocab_size (`int`, *optional*, defaults to 32000): 44 | Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the 45 | `inputs_ids` passed when calling [`MistralModel`] 46 | hidden_size (`int`, *optional*, defaults to 4096): 47 | Dimension of the hidden representations. 48 | intermediate_size (`int`, *optional*, defaults to 14336): 49 | Dimension of the MLP representations. 50 | num_hidden_layers (`int`, *optional*, defaults to 32): 51 | Number of hidden layers in the Transformer encoder. 52 | num_attention_heads (`int`, *optional*, defaults to 32): 53 | Number of attention heads for each attention layer in the Transformer encoder. 54 | num_key_value_heads (`int`, *optional*, defaults to 8): 55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 57 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 59 | by meanpooling all the original heads within that group. For more details checkout [this 60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention 65 | allows sequence of up to 4096*32 tokens. 
66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 69 | The epsilon used by the rms normalization layers. 70 | use_cache (`bool`, *optional*, defaults to `True`): 71 | Whether or not the model should return the last key/values attentions (not used by all models). Only 72 | relevant if `config.is_decoder=True`. 73 | pad_token_id (`int`, *optional*): 74 | The id of the padding token. 75 | bos_token_id (`int`, *optional*, defaults to 1): 76 | The id of the "beginning-of-sequence" token. 77 | eos_token_id (`int`, *optional*, defaults to 2): 78 | The id of the "end-of-sequence" token. 79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 80 | Whether the model's input and output word embeddings should be tied. 81 | rope_theta (`float`, *optional*, defaults to 10000.0): 82 | The base period of the RoPE embeddings. 83 | sliding_window (`int`, *optional*, defaults to 4096): 84 | Sliding window attention window size. If not specified, will default to `4096`. 85 | attention_dropout (`float`, *optional*, defaults to 0.0): 86 | The dropout ratio for the attention probabilities. 87 | 88 | ```python 89 | >>> from transformers import MistralModel, MistralConfig 90 | 91 | >>> # Initializing a Mistral 7B style configuration 92 | >>> configuration = MistralConfig() 93 | 94 | >>> # Initializing a model from the Mistral 7B style configuration 95 | >>> model = MistralModel(configuration) 96 | 97 | >>> # Accessing the model configuration 98 | >>> configuration = model.config 99 | ```""" 100 | 101 | model_type = "mistral" 102 | keys_to_ignore_at_inference = ["past_key_values"] 103 | 104 | def __init__( 105 | self, 106 | vocab_size=32000, 107 | hidden_size=4096, 108 | intermediate_size=14336, 109 | num_hidden_layers=32, 110 | num_attention_heads=32, 111 | num_key_value_heads=8, 112 | hidden_act="silu", 113 | max_position_embeddings=4096 * 32, 114 | initializer_range=0.02, 115 | rms_norm_eps=1e-6, 116 | use_cache=True, 117 | pad_token_id=None, 118 | bos_token_id=1, 119 | eos_token_id=2, 120 | tie_word_embeddings=False, 121 | rope_theta=10000.0, 122 | sliding_window=4096, 123 | attention_dropout=0.0, 124 | max_thoughts=16, 125 | merged_talk_heads=True, 126 | merged_lm_and_talk_heads=False, 127 | merged_lm_and_think_heads=True, 128 | use_concat_talk_head=True, 129 | use_shallow_think=True, 130 | use_shallow_talk=False, 131 | use_complex_think_head=False, 132 | use_complex_talk_head=True, 133 | use_weighted_talk_head=True, 134 | **kwargs, 135 | ): 136 | self.vocab_size = vocab_size 137 | self.max_position_embeddings = max_position_embeddings 138 | self.hidden_size = hidden_size 139 | self.intermediate_size = intermediate_size 140 | self.num_hidden_layers = num_hidden_layers 141 | self.num_attention_heads = num_attention_heads 142 | self.sliding_window = sliding_window 143 | 144 | # for backward compatibility 145 | if num_key_value_heads is None: 146 | num_key_value_heads = num_attention_heads 147 | 148 | self.num_key_value_heads = num_key_value_heads 149 | self.hidden_act = hidden_act 150 | self.initializer_range = initializer_range 151 | self.rms_norm_eps = rms_norm_eps 152 | self.use_cache = use_cache 153 | self.rope_theta = rope_theta 154 | self.attention_dropout = attention_dropout 155 | self.max_thoughts = max_thoughts 156 | self.merged_talk_heads = merged_talk_heads 157 | self.merged_lm_and_talk_heads = 
merged_lm_and_talk_heads 158 | self.merged_lm_and_think_heads = merged_lm_and_think_heads 159 | self.use_concat_talk_head = use_concat_talk_head 160 | self.use_shallow_think = use_shallow_think 161 | self.use_shallow_talk = use_shallow_talk 162 | self.use_complex_think_head = use_complex_think_head 163 | self.use_complex_talk_head = use_complex_talk_head 164 | self.use_weighted_talk_head = use_weighted_talk_head 165 | 166 | super().__init__( 167 | pad_token_id=pad_token_id, 168 | bos_token_id=bos_token_id, 169 | eos_token_id=eos_token_id, 170 | tie_word_embeddings=tie_word_embeddings, 171 | **kwargs, 172 | ) 173 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /zero-shotcot-eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | import random 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | from datasets import load_dataset 6 | import os 7 | import time 8 | import re 9 | from tqdm import tqdm 10 | from collections import Counter 11 | 12 | import argparse 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--batch_idx", type=int, default=0) 15 | parser.add_argument("--baseline", action="store_true") 16 | parser.add_argument("--device_batch_size", type=int, default=8) 17 | parser.add_argument("--max_idx", type=int, default=128) 18 | parser.add_argument("--n_votes", type=int, default=8) 19 | parser.add_argument("--temp", type=float, default=0.9) 20 | parser.add_argument("--start_final_answer_idx", type=int, default=384) 21 | parser.add_argument("--answer_length", type=int, default=12) 22 | parser.add_argument("--root_prefix", type=str, default="YOUR_ROOT_HERE") 23 | parser.add_argument("--checkpoint", type=str, default="ezelikman/quietstar-8-ahead") 24 | parser.add_argument("--final_answer_text", type=str, default="\nTherefore, the answer (arabic numerals) is") 25 | parser.add_argument("--zero_shot_cot_prompt", type=str, default="\nA: Let's think step by step.") 26 | parser.add_argument("--n_ahead", type=int, default=8) 27 | args = parser.parse_args() 28 | 29 | def model_init(params): 30 | if params is None: 31 | params = {} 32 | else: 33 | params = params.params 34 | n_ahead = params.get("n_ahead", args.n_ahead if not args.baseline else 1) 35 | n_ahead_talk = 1 36 | use_start_thought_token = params.get("use_start_thought_token", True) 37 | use_end_thought_token = params.get("use_end_thought_token", True) 38 | include_policy_loss = params.get("include_policy_loss", True) 39 | gumbel_detach = params.get("gumbel_detach", True) 40 | merged_talk_heads = params.get("merged_talk_heads", True) 41 | residual_think_head = params.get("residual_think_head", False) 42 | optimize_lm_head_only_at_start = params.get("optimize_lm_head_only_at_start", False) 43 | print("Loading model") 44 | model = AutoModelForCausalLM.from_pretrained( 45 | args.checkpoint, 46 | torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, 47 | device_map='auto', 48 | cache_dir=args.root_prefix + "cache", 49 | max_thoughts=n_ahead + n_ahead_talk + 1, 50 | merged_talk_heads=merged_talk_heads, 51 | merged_lm_and_talk_heads=False, 52 | merged_lm_and_think_heads=True, 53 | use_concat_talk_head=True, 54 | use_shallow_think=True, 55 | use_shallow_talk=False, 56 | use_complex_think_head=False, 57 | use_complex_talk_head=True, 58 | use_weighted_talk_head=True, 59 | ) 60 | print("Loaded model") 61 | tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") 62 | tokenizer.padding_side = "right" 63 | tokenizer.pad_token_id = tokenizer.eos_token_id 64 | special_tokens_to_add = [] 65 | if model.use_start_thought_token: 66 | special_tokens_to_add.append("<|startthought|>") 67 | if model.use_end_thought_token: 68 | special_tokens_to_add.append("<|endthought|>") 69 | if special_tokens_to_add: 70 | tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add}) 71 | model.resize_token_embeddings(len(tokenizer)) 72 | model.tokenizer = tokenizer 73 | model.gumbel_detach = gumbel_detach 74 | model.include_policy_loss = 
include_policy_loss 75 | model.use_end_thought_token = use_end_thought_token 76 | model.use_start_thought_token = use_start_thought_token 77 | model.n_ahead = n_ahead 78 | model.n_ahead_talk = n_ahead_talk 79 | model.n_passes = 1 80 | model.residual_think_head = residual_think_head 81 | if args.baseline: 82 | model.skip_residual = True 83 | model.cumulative_residual = False 84 | model.clever_residual = False 85 | model.base_residual = False 86 | model.optimize_lm_head_only_at_start = optimize_lm_head_only_at_start 87 | model.use_policy_loss = False 88 | model.rm_initialized = True 89 | model.first_run = False 90 | model.wandb_enabled = False 91 | model.config_params = params 92 | model.run_start = int(time.time()) 93 | model.eval_mode = True 94 | model.eval() 95 | return model 96 | 97 | def extract_first_integer(s): 98 | match = re.search(r'\d+', s.replace(',', '')) 99 | if match: 100 | return int(match.group()) 101 | return None 102 | 103 | # Set random seeds for reproducibility 104 | random_seed = 42 105 | torch.manual_seed(random_seed) 106 | random.seed(random_seed) 107 | 108 | # Load the GSM8K dataset and the model 109 | cot_dataset_gsm = load_dataset("gsm8k", "main", split="test", ignore_verifications=True).shuffle(seed=random_seed) 110 | model = model_init(None) 111 | 112 | start_question = args.device_batch_size * args.batch_idx 113 | end_question = args.device_batch_size * (args.batch_idx + 1) 114 | # Iterate over the questions for the current device 115 | batch_size = 1 116 | for batch_start in tqdm(range(start_question, min(args.max_idx, end_question), batch_size)): 117 | last_save_folder = f"answers/eval_{'baseline' if args.baseline else 'ft'}_{args.n_ahead if not args.baseline else 1}_{args.temp}_{args.n_votes}" 118 | if os.path.exists(last_save_folder + f"/{batch_start}.txt"): 119 | print(f"Skipping {batch_start}") 120 | continue 121 | extracted_answers = [] 122 | for vote_idx in range(1, args.n_votes + 1): 123 | folder_name = f"answers/eval_{'baseline' if args.baseline else 'ft'}_{args.n_ahead if not args.baseline else 1}_{args.temp}_{vote_idx}" 124 | if not os.path.exists(folder_name): 125 | os.makedirs(folder_name) 126 | 127 | # Get the current batch of questions 128 | batch_questions = cot_dataset_gsm[batch_start:batch_start+batch_size] 129 | input_texts = ["Q: " + q + args.zero_shot_cot_prompt for q in batch_questions["question"]] 130 | input_ids = model.tokenizer(input_texts, return_tensors="pt", padding=True).to(model.device) 131 | attention_mask = input_ids.attention_mask 132 | input_ids = input_ids.input_ids 133 | started_generating_answer_at = None 134 | 135 | # Generate the solution 136 | with torch.no_grad(): 137 | finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device) 138 | for cur_token_idx in range(args.start_final_answer_idx + args.answer_length): 139 | # Sample the next token 140 | new_ids = model( 141 | input_ids[~finished_generating], 142 | attention_mask=attention_mask[~finished_generating] 143 | )['logits'] 144 | # Mask out the start and end thought tokens so we don't accidentally sample them 145 | new_ids[:, :, model.tokenizer.vocab_size:] = -float("inf") 146 | for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]): 147 | # Find the index of the last token that is not padding 148 | base_answer_ids = input_ids[answer_idx] 149 | new_answer_ids = new_ids[list_idx] 150 | last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max() 151 | if args.temp == 
0: 152 | new_ids_sampled = torch.argmax(new_answer_ids[last_token_idx]).unsqueeze(0) 153 | else: 154 | new_ids_sampled = torch.multinomial(torch.nn.functional.softmax(new_answer_ids[last_token_idx] / args.temp, dim=-1), 1) 155 | # Assign the new id to the last token 156 | if last_token_idx + 1 >= len(base_answer_ids): 157 | # Add padding everywhere 158 | new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long, device=input_ids.device) 159 | input_ids = torch.cat([input_ids, new_padding], dim=-1) 160 | attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1) 161 | attention_mask[answer_idx, last_token_idx + 1] = 1 162 | input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled 163 | if new_ids_sampled == model.tokenizer.eos_token_id or new_ids_sampled == model.tokenizer.bos_token_id or new_ids_sampled == model.tokenizer.pad_token_id: 164 | finished_generating[answer_idx] = 1 165 | # If "Q:" shows up multiple times, remove the last "Q:" and everything after it 166 | decoded = model.tokenizer.decode(input_ids[answer_idx], skip_special_tokens=True) 167 | end_strs = ["Q:", "\n\n\n"] 168 | if any([decoded.count(end_str) > 1 for end_str in end_strs]): 169 | # Get the first end_str that shows up in the decoded text multiple times 170 | end_str = next(end_str for end_str in end_strs if decoded.count(end_str) > 1) 171 | # Remove the last "Q:" and everything after it 172 | decoded = end_str.join(decoded.split(end_str)[:-1])  # keep the text before the last occurrence as a string so it can be re-encoded 173 | new_answer = model.tokenizer.encode(decoded, return_tensors="pt").to(model.device) 174 | input_ids[answer_idx] = torch.ones_like(input_ids[answer_idx]) * model.tokenizer.pad_token_id 175 | input_ids[answer_idx, :new_answer.shape[1]] = new_answer 176 | attention_mask[answer_idx] = (input_ids[answer_idx] != model.tokenizer.pad_token_id).long() 177 | finished_generating[answer_idx] = 1 178 | 179 | # Check if we should start generating the final answer 180 | if ( 181 | (cur_token_idx == args.start_final_answer_idx and started_generating_answer_at is None) 182 | or finished_generating.all() 183 | ): 184 | # If we haven't started generating the final answer yet, start now 185 | if started_generating_answer_at is None: 186 | finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device) 187 | started_generating_answer_at = cur_token_idx 188 | # Append "Final Answer:" to the end of the generated text 189 | base_texts = [model.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] 190 | final_texts = [text.rstrip() + args.final_answer_text for text in base_texts] 191 | encoded_final_texts = model.tokenizer(final_texts, return_tensors="pt", padding=True).to(model.device) 192 | attention_mask = encoded_final_texts.attention_mask 193 | input_ids = encoded_final_texts.input_ids 194 | else: 195 | # We finished generating the answer 196 | break 197 | 198 | if started_generating_answer_at is not None: 199 | if cur_token_idx - started_generating_answer_at > args.answer_length: 200 | break 201 | 202 | # Collect the generated answers for evaluation 203 | for i, encoded_final_text in enumerate(input_ids): 204 | question_idx = batch_start + i 205 | decoded_text = model.tokenizer.decode(encoded_final_text, skip_special_tokens=True) 206 | vote_extracted_number = decoded_text.split(args.final_answer_text)[-1] 207 | # Extract the first number from the answer text 208 | vote_extracted_number = extract_first_integer(vote_extracted_number) 209 | extracted_correct_answer = 
extract_first_integer(cot_dataset_gsm[question_idx]["answer"].split("#### ")[-1]) 210 | extracted_answers.append((vote_extracted_number, extracted_correct_answer, decoded_text)) 211 | 212 | # Save the current to vote_idx folder 213 | extracted_number = Counter([extracted_number for extracted_number, _, _ in extracted_answers]) 214 | extracted_most_common = extracted_number.most_common(1)[0][0] 215 | correct = extracted_most_common == extracted_answers[0][1] 216 | print(f"Question {batch_start + i} - Correct: {correct} - Extracted: {extracted_number} - True: {extracted_correct_answer}") 217 | joined_final_texts = ("\n" + "=" * 100 + "\n").join([decoded for _, _, decoded in extracted_answers]) 218 | save_filename = f"{folder_name}/{batch_start}.txt" 219 | with open(save_filename, "w") as f: 220 | f.write(joined_final_texts + "\n" + "Extracted: " + str(extracted_most_common) + "\n" + "True: " + str(extracted_correct_answer) + "\n" + "Correct: " + str(correct)) 221 | -------------------------------------------------------------------------------- /modeling_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
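# NOTE: this file is the Quiet-STaR patch of the stock Huggingface Mistral modeling file
# (based on `transformers` 4.37.0.dev0, src/transformers/models/mistral/). The upstream building
# blocks below (RMSNorm, rotary embeddings, MLP, attention) are left largely unchanged; the
# thought-token machinery (start/end thought tokens, think/talk heads, policy loss) is layered on top of them.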
20 | """ PyTorch Mistral model.""" 21 | import inspect 22 | import math 23 | import copy 24 | import os 25 | import time 26 | import pandas as pd 27 | import seaborn as sns 28 | import matplotlib.pyplot as plt 29 | import wandb 30 | from termcolor import colored 31 | from tqdm import tqdm 32 | import random 33 | import numpy as np 34 | from matplotlib.colors import LinearSegmentedColormap, LogNorm 35 | import warnings 36 | from collections import defaultdict 37 | from typing import List, Optional, Tuple, Union 38 | 39 | import torch 40 | import torch.nn.functional as F 41 | import torch.utils.checkpoint 42 | from torch import nn 43 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 44 | 45 | from ...activations import ACT2FN 46 | from ...cache_utils import Cache, DynamicCache 47 | from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa 48 | from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 49 | from ...modeling_utils import PreTrainedModel 50 | from ...utils import ( 51 | add_start_docstrings, 52 | add_start_docstrings_to_model_forward, 53 | is_flash_attn_2_available, 54 | is_flash_attn_greater_or_equal_2_10, 55 | logging, 56 | replace_return_docstrings, 57 | ) 58 | from .configuration_mistral import MistralConfig 59 | 60 | 61 | if is_flash_attn_2_available(): 62 | from flash_attn import flash_attn_func, flash_attn_varlen_func 63 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 64 | 65 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) 66 | 67 | 68 | logger = logging.get_logger(__name__) 69 | 70 | _CONFIG_FOR_DOC = "MistralConfig" 71 | 72 | from reportlab.pdfgen import canvas 73 | from reportlab.lib.pagesizes import letter 74 | from reportlab.lib.colors import HexColor 75 | 76 | def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5): 77 | c = canvas.Canvas(output_file, pagesize=letter) 78 | c.setFont("Courier", 8) 79 | x, y = 50, 750 80 | previous_text = "" 81 | current_text = "" 82 | for token_idx, reward in enumerate(token_rewards): 83 | current_text = tokenizer.decode(input_ids[: token_idx + 1]) 84 | if current_text != previous_text: 85 | diff_text = current_text[len(previous_text) :] 86 | if "\n" in diff_text: 87 | lines = diff_text.split("\n") 88 | for line_idx, line in enumerate(lines): 89 | if line_idx > 0: 90 | x = 50 91 | y -= 12 92 | if abs(reward) < eps: 93 | opacity = 0 94 | elif abs(reward) > eps2: 95 | opacity = 0.8 96 | else: 97 | opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps) 98 | text_width = c.stringWidth(line) 99 | if reward > 0: 100 | highlight_color = HexColor("#4CCD99") 101 | else: 102 | highlight_color = HexColor("#FFC700") 103 | highlight_color.alpha = opacity 104 | c.setFillColor(highlight_color) 105 | c.rect(x, y - 2, text_width, 10, fill=True, stroke=False) 106 | c.setFillColor(HexColor("#000000")) 107 | c.drawString(x, y, line) 108 | x += text_width 109 | else: 110 | if abs(reward) < eps: 111 | opacity = 0 112 | elif abs(reward) > eps2: 113 | opacity = 0.8 114 | else: 115 | opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps) 116 | text_width = c.stringWidth(diff_text) 117 | if reward > 0: 118 | highlight_color = HexColor("#4CCD99") 119 | else: 120 | highlight_color = HexColor("#FFC700") 121 | highlight_color.alpha = opacity 122 | c.setFillColor(highlight_color) 123 | 
c.rect(x, y - 2, text_width, 10, fill=True, stroke=False) 124 | c.setFillColor(HexColor("#000000")) 125 | c.drawString(x, y, diff_text) 126 | x += text_width 127 | if x > 550: 128 | x = 50 129 | y -= 12 130 | if y < 50: 131 | c.showPage() 132 | y = 750 133 | x = 50 134 | previous_text = current_text 135 | c.showPage() 136 | c.save() 137 | 138 | 139 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 140 | def _get_unpad_data(attention_mask): 141 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 142 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 143 | max_seqlen_in_batch = seqlens_in_batch.max().item() 144 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 145 | return ( 146 | indices, 147 | cu_seqlens, 148 | max_seqlen_in_batch, 149 | ) 150 | 151 | 152 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral 153 | class MistralRMSNorm(nn.Module): 154 | def __init__(self, hidden_size, eps=1e-6): 155 | """ 156 | MistralRMSNorm is equivalent to T5LayerNorm 157 | """ 158 | super().__init__() 159 | self.weight = nn.Parameter(torch.ones(hidden_size)) 160 | self.variance_epsilon = eps 161 | 162 | def forward(self, hidden_states): 163 | input_dtype = hidden_states.dtype 164 | hidden_states = hidden_states.to(torch.float32) 165 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 166 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 167 | return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device) 168 | 169 | 170 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral 171 | class MistralRotaryEmbedding(nn.Module): 172 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 173 | super().__init__() 174 | 175 | self.dim = dim 176 | self.max_position_embeddings = max_position_embeddings 177 | self.base = base 178 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 179 | self.register_buffer("inv_freq", inv_freq, persistent=False) 180 | 181 | # Build here to make `torch.jit.trace` work. 
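# Precomputing the cos/sin tables up to max_position_embeddings here means the common case in
# forward() below is just a slice of these cached buffers; _set_cos_sin_cache is only re-run
# if forward() is ever called with a longer seq_len.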
182 | self._set_cos_sin_cache( 183 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 184 | ) 185 | 186 | def _set_cos_sin_cache(self, seq_len, device, dtype): 187 | self.max_seq_len_cached = seq_len 188 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 189 | 190 | freqs = torch.outer(t, self.inv_freq) 191 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 192 | emb = torch.cat((freqs, freqs), dim=-1) 193 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 194 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 195 | 196 | def forward(self, x, seq_len=None): 197 | # x: [bs, num_attention_heads, seq_len, head_size] 198 | if seq_len > self.max_seq_len_cached: 199 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 200 | 201 | return ( 202 | self.cos_cached[:seq_len].to(dtype=x.dtype), 203 | self.sin_cached[:seq_len].to(dtype=x.dtype), 204 | ) 205 | 206 | 207 | # Copied from transformers.models.llama.modeling_llama.rotate_half 208 | def rotate_half(x): 209 | """Rotates half the hidden dims of the input.""" 210 | x1 = x[..., : x.shape[-1] // 2] 211 | x2 = x[..., x.shape[-1] // 2 :] 212 | return torch.cat((-x2, x1), dim=-1) 213 | 214 | 215 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 216 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 217 | """Applies Rotary Position Embedding to the query and key tensors. 218 | 219 | Args: 220 | q (`torch.Tensor`): The query tensor. 221 | k (`torch.Tensor`): The key tensor. 222 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 223 | sin (`torch.Tensor`): The sine part of the rotary embedding. 224 | position_ids (`torch.Tensor`): 225 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 226 | used to pass offsetted position ids when working with a KV-cache. 227 | unsqueeze_dim (`int`, *optional*, defaults to 1): 228 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 229 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 230 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 231 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 232 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 233 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 234 | Returns: 235 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
236 | """ 237 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 238 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 239 | q_embed = (q * cos) + (rotate_half(q) * sin) 240 | k_embed = (k * cos) + (rotate_half(k) * sin) 241 | return q_embed, k_embed 242 | 243 | 244 | class MistralMLP(nn.Module): 245 | def __init__(self, config): 246 | super().__init__() 247 | self.config = config 248 | self.hidden_size = config.hidden_size 249 | self.intermediate_size = config.intermediate_size 250 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 251 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 252 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 253 | self.act_fn = ACT2FN[config.hidden_act] 254 | 255 | def forward(self, x): 256 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 257 | 258 | 259 | # Copied from transformers.models.llama.modeling_llama.repeat_kv 260 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 261 | """ 262 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 263 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 264 | """ 265 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 266 | if n_rep == 1: 267 | return hidden_states 268 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 269 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 270 | 271 | 272 | class MistralAttention(nn.Module): 273 | """ 274 | Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer 275 | and "Generating Long Sequences with Sparse Transformers". 276 | """ 277 | 278 | def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): 279 | super().__init__() 280 | self.config = config 281 | self.layer_idx = layer_idx 282 | if layer_idx is None: 283 | logger.warning_once( 284 | f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " 285 | "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " 286 | "when creating this class." 287 | ) 288 | 289 | self.hidden_size = config.hidden_size 290 | self.num_heads = config.num_attention_heads 291 | self.head_dim = self.hidden_size // self.num_heads 292 | self.num_key_value_heads = config.num_key_value_heads 293 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 294 | self.max_position_embeddings = config.max_position_embeddings 295 | self.rope_theta = config.rope_theta 296 | self.is_causal = True 297 | self.attention_dropout = config.attention_dropout 298 | 299 | if (self.head_dim * self.num_heads) != self.hidden_size: 300 | raise ValueError( 301 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 302 | f" and `num_heads`: {self.num_heads})." 
303 | ) 304 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 305 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 306 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 307 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 308 | 309 | self.rotary_emb = MistralRotaryEmbedding( 310 | self.head_dim, 311 | max_position_embeddings=self.max_position_embeddings, 312 | base=self.rope_theta, 313 | ) 314 | 315 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 316 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 317 | 318 | def forward( 319 | self, 320 | hidden_states: torch.Tensor, 321 | attention_mask: Optional[torch.Tensor] = None, 322 | position_ids: Optional[torch.LongTensor] = None, 323 | past_key_value: Optional[Cache] = None, 324 | output_attentions: bool = False, 325 | use_cache: bool = False, 326 | **kwargs, 327 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 328 | if "padding_mask" in kwargs: 329 | warnings.warn( 330 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 331 | ) 332 | bsz, q_len, _ = hidden_states.size() 333 | 334 | query_states = self.q_proj(hidden_states) 335 | key_states = self.k_proj(hidden_states) 336 | value_states = self.v_proj(hidden_states) 337 | 338 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 339 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 340 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 341 | 342 | kv_seq_len = key_states.shape[-2] 343 | if past_key_value is not None: 344 | if self.layer_idx is None: 345 | raise ValueError( 346 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 347 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 348 | "with a layer index." 
349 | ) 350 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 351 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 352 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 353 | 354 | if past_key_value is not None: 355 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 356 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 357 | 358 | # repeat k/v heads if n_kv_heads < n_heads 359 | key_states = repeat_kv(key_states, self.num_key_value_groups) 360 | value_states = repeat_kv(value_states, self.num_key_value_groups) 361 | 362 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 363 | 364 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 365 | raise ValueError( 366 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 367 | f" {attn_weights.size()}" 368 | ) 369 | 370 | if attention_mask is not None: 371 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 372 | raise ValueError( 373 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 374 | ) 375 | 376 | attn_weights = attn_weights + attention_mask 377 | 378 | # upcast attention to fp32 379 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 380 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 381 | attn_output = torch.matmul(attn_weights, value_states) 382 | 383 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 384 | raise ValueError( 385 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 386 | f" {attn_output.size()}" 387 | ) 388 | 389 | attn_output = attn_output.transpose(1, 2).contiguous() 390 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 391 | 392 | attn_output = self.o_proj(attn_output) 393 | 394 | if not output_attentions: 395 | attn_weights = None 396 | 397 | return attn_output, attn_weights, past_key_value 398 | 399 | 400 | class MistralFlashAttention2(MistralAttention): 401 | """ 402 | Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays 403 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 404 | flash attention and deal with padding tokens in case the input contains any of them. 405 | """ 406 | 407 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 408 | def __init__(self, *args, **kwargs): 409 | super().__init__(*args, **kwargs) 410 | 411 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 412 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 413 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
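        # In practice this flag is True only when the installed flash-attn predates 2.1 and therefore uses the
        # top-left aligned causal mask described above; `_flash_attention_forward` below checks it and drops the
        # `causal` flag for single-token (q_len == 1) decoding steps to work around the alignment difference.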
414 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 415 | 416 | def forward( 417 | self, 418 | hidden_states: torch.Tensor, 419 | attention_mask: Optional[torch.Tensor] = None, 420 | position_ids: Optional[torch.LongTensor] = None, 421 | past_key_value: Optional[Cache] = None, 422 | output_attentions: bool = False, 423 | use_cache: bool = False, 424 | **kwargs, 425 | ): 426 | if "padding_mask" in kwargs: 427 | warnings.warn( 428 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 429 | ) 430 | 431 | # overwrite attention_mask with padding_mask 432 | attention_mask = kwargs.pop("padding_mask") 433 | bsz, q_len, _ = hidden_states.size() 434 | 435 | query_states = self.q_proj(hidden_states) 436 | key_states = self.k_proj(hidden_states) 437 | value_states = self.v_proj(hidden_states) 438 | 439 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 440 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 441 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 442 | 443 | kv_seq_len = key_states.shape[-2] 444 | if past_key_value is not None: 445 | if self.layer_idx is None: 446 | raise ValueError( 447 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 448 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 449 | "with a layer index." 450 | ) 451 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 452 | 453 | # Because the input can be padded, the absolute sequence length depends on the max position id. 454 | rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 455 | cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) 456 | 457 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 458 | 459 | use_sliding_windows = ( 460 | _flash_supports_window_size 461 | and getattr(self.config, "sliding_window", None) is not None 462 | and kv_seq_len > self.config.sliding_window 463 | ) 464 | 465 | if not _flash_supports_window_size: 466 | logger.warning_once( 467 | "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" 468 | " make sure to upgrade flash-attn library." 
469 | ) 470 | 471 | if past_key_value is not None: 472 | # Activate slicing cache only if the config has a value `sliding_windows` attribute 473 | cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 474 | if ( 475 | getattr(self.config, "sliding_window", None) is not None 476 | and kv_seq_len > self.config.sliding_window 477 | and cache_has_contents 478 | ): 479 | slicing_tokens = 1 - self.config.sliding_window 480 | 481 | past_key = past_key_value[self.layer_idx][0] 482 | past_value = past_key_value[self.layer_idx][1] 483 | 484 | past_key = past_key[:, :, slicing_tokens:, :].contiguous() 485 | past_value = past_value[:, :, slicing_tokens:, :].contiguous() 486 | 487 | if past_key.shape[-2] != self.config.sliding_window - 1: 488 | raise ValueError( 489 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" 490 | f" {past_key.shape}" 491 | ) 492 | 493 | if attention_mask is not None: 494 | attention_mask = attention_mask[:, slicing_tokens:] 495 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) 496 | 497 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 498 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 499 | 500 | # repeat k/v heads if n_kv_heads < n_heads 501 | key_states = repeat_kv(key_states, self.num_key_value_groups) 502 | value_states = repeat_kv(value_states, self.num_key_value_groups) 503 | dropout_rate = 0.0 if not self.training else self.attention_dropout 504 | 505 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 506 | # therefore the input hidden states gets silently casted in float32. Hence, we need 507 | # cast them back in float16 just to be sure everything works as expected. 508 | input_dtype = query_states.dtype 509 | if input_dtype == torch.float32: 510 | if torch.is_autocast_enabled(): 511 | target_dtype = torch.get_autocast_gpu_dtype() 512 | # Handle the case where the model is quantized 513 | elif hasattr(self.config, "_pre_quantization_dtype"): 514 | target_dtype = self.config._pre_quantization_dtype 515 | else: 516 | target_dtype = self.q_proj.weight.dtype 517 | 518 | logger.warning_once( 519 | f"The input hidden states seems to be silently casted in float32, this might be related to" 520 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 521 | f" {target_dtype}." 
522 | ) 523 | 524 | query_states = query_states.to(target_dtype) 525 | key_states = key_states.to(target_dtype) 526 | value_states = value_states.to(target_dtype) 527 | 528 | # Reashape to the expected shape for Flash Attention 529 | query_states = query_states.transpose(1, 2) 530 | key_states = key_states.transpose(1, 2) 531 | value_states = value_states.transpose(1, 2) 532 | 533 | attn_output = self._flash_attention_forward( 534 | query_states, 535 | key_states, 536 | value_states, 537 | attention_mask, 538 | q_len, 539 | dropout=dropout_rate, 540 | use_sliding_windows=use_sliding_windows, 541 | ) 542 | 543 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 544 | attn_output = self.o_proj(attn_output) 545 | 546 | if not output_attentions: 547 | attn_weights = None 548 | 549 | return attn_output, attn_weights, past_key_value 550 | 551 | def _flash_attention_forward( 552 | self, 553 | query_states, 554 | key_states, 555 | value_states, 556 | attention_mask, 557 | query_length, 558 | dropout=0.0, 559 | softmax_scale=None, 560 | use_sliding_windows=False, 561 | ): 562 | """ 563 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 564 | first unpad the input, then computes the attention scores and pad the final attention scores. 565 | 566 | Args: 567 | query_states (`torch.Tensor`): 568 | Input query states to be passed to Flash Attention API 569 | key_states (`torch.Tensor`): 570 | Input key states to be passed to Flash Attention API 571 | value_states (`torch.Tensor`): 572 | Input value states to be passed to Flash Attention API 573 | attention_mask (`torch.Tensor`): 574 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 575 | position of padding tokens and 1 for the position of non-padding tokens. 576 | dropout (`int`, *optional*): 577 | Attention dropout 578 | softmax_scale (`float`, *optional*): 579 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 580 | use_sliding_windows (`bool`, *optional*): 581 | Whether to activate sliding window attention. 582 | """ 583 | if not self._flash_attn_uses_top_left_mask: 584 | causal = self.is_causal 585 | else: 586 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
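            # For the single-query decoding step (query_length == 1) the query should attend to every cached key
            # anyway, so dropping the `causal` flag here is safe and avoids relying on the top-left aligned
            # masking behaviour of flash-attn < 2.1.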
587 | causal = self.is_causal and query_length != 1 588 | 589 | # Contains at least one padding token in the sequence 590 | if attention_mask is not None: 591 | batch_size = query_states.shape[0] 592 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 593 | query_states, key_states, value_states, attention_mask, query_length 594 | ) 595 | 596 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 597 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 598 | 599 | if not use_sliding_windows: 600 | attn_output_unpad = flash_attn_varlen_func( 601 | query_states, 602 | key_states, 603 | value_states, 604 | cu_seqlens_q=cu_seqlens_q, 605 | cu_seqlens_k=cu_seqlens_k, 606 | max_seqlen_q=max_seqlen_in_batch_q, 607 | max_seqlen_k=max_seqlen_in_batch_k, 608 | dropout_p=dropout, 609 | softmax_scale=softmax_scale, 610 | causal=causal, 611 | ) 612 | else: 613 | attn_output_unpad = flash_attn_varlen_func( 614 | query_states, 615 | key_states, 616 | value_states, 617 | cu_seqlens_q=cu_seqlens_q, 618 | cu_seqlens_k=cu_seqlens_k, 619 | max_seqlen_q=max_seqlen_in_batch_q, 620 | max_seqlen_k=max_seqlen_in_batch_k, 621 | dropout_p=dropout, 622 | softmax_scale=softmax_scale, 623 | causal=causal, 624 | window_size=(self.config.sliding_window, self.config.sliding_window), 625 | ) 626 | 627 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 628 | else: 629 | if not use_sliding_windows: 630 | attn_output = flash_attn_func( 631 | query_states, 632 | key_states, 633 | value_states, 634 | dropout, 635 | softmax_scale=softmax_scale, 636 | causal=causal, 637 | ) 638 | else: 639 | attn_output = flash_attn_func( 640 | query_states, 641 | key_states, 642 | value_states, 643 | dropout, 644 | softmax_scale=softmax_scale, 645 | causal=causal, 646 | window_size=(self.config.sliding_window, self.config.sliding_window), 647 | ) 648 | 649 | return attn_output 650 | 651 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 652 | batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape 653 | 654 | # On the first iteration we need to properly re-create the padding mask 655 | # by slicing it on the proper place 656 | if kv_seq_len != attention_mask.shape[-1]: 657 | attention_mask_num_tokens = attention_mask.shape[-1] 658 | attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] 659 | 660 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 661 | 662 | key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) 663 | value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) 664 | 665 | if query_length == kv_seq_len: 666 | query_layer = index_first_axis( 667 | query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k 668 | ) 669 | cu_seqlens_q = cu_seqlens_k 670 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 671 | indices_q = indices_k 672 | elif query_length == 1: 673 | max_seqlen_in_batch_q = 1 674 | cu_seqlens_q = torch.arange( 675 | batch_size + 1, dtype=torch.int32, device=query_layer.device 676 | ) # There is a memcpy here, that is very bad. 677 | indices_q = cu_seqlens_q[:-1] 678 | query_layer = query_layer.squeeze(1) 679 | else: 680 | # The -q_len: slice assumes left padding. 
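            # Shape sketch (for orientation; names follow the variables in this function): with `attention_mask`
            # of shape (batch_size, kv_seq_len) and `query_length` new tokens, the slice below keeps only the
            # mask columns for the current queries; `unpad_input` then flattens `query_layer` to
            # (total_non_padded_tokens, num_heads, head_dim) and returns the `cu_seqlens_q` offsets consumed by
            # `flash_attn_varlen_func`.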
681 | attention_mask = attention_mask[:, -query_length:] 682 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 683 | 684 | return ( 685 | query_layer, 686 | key_layer, 687 | value_layer, 688 | indices_q, 689 | (cu_seqlens_q, cu_seqlens_k), 690 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 691 | ) 692 | 693 | 694 | # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral 695 | class MistralSdpaAttention(MistralAttention): 696 | """ 697 | Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from 698 | `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to 699 | SDPA API. 700 | """ 701 | 702 | # Adapted from MistralAttention.forward 703 | def forward( 704 | self, 705 | hidden_states: torch.Tensor, 706 | attention_mask: Optional[torch.Tensor] = None, 707 | position_ids: Optional[torch.LongTensor] = None, 708 | past_key_value: Optional[Cache] = None, 709 | output_attentions: bool = False, 710 | use_cache: bool = False, 711 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 712 | if output_attentions: 713 | # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 714 | logger.warning_once( 715 | "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 716 | 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
717 | ) 718 | return super().forward( 719 | hidden_states=hidden_states, 720 | attention_mask=attention_mask, 721 | position_ids=position_ids, 722 | past_key_value=past_key_value, 723 | output_attentions=output_attentions, 724 | use_cache=use_cache, 725 | ) 726 | 727 | bsz, q_len, _ = hidden_states.size() 728 | 729 | query_states = self.q_proj(hidden_states) 730 | key_states = self.k_proj(hidden_states) 731 | value_states = self.v_proj(hidden_states) 732 | 733 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 734 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 735 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 736 | 737 | kv_seq_len = key_states.shape[-2] 738 | if past_key_value is not None: 739 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 740 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 741 | 742 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 743 | 744 | if past_key_value is not None: 745 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 746 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 747 | 748 | key_states = repeat_kv(key_states, self.num_key_value_groups) 749 | value_states = repeat_kv(value_states, self.num_key_value_groups) 750 | 751 | if attention_mask is not None: 752 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 753 | raise ValueError( 754 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 755 | ) 756 | 757 | # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, 758 | # Reference: https://github.com/pytorch/pytorch/issues/112577. 759 | if query_states.device.type == "cuda" and attention_mask is not None: 760 | query_states = query_states.contiguous() 761 | key_states = key_states.contiguous() 762 | value_states = value_states.contiguous() 763 | 764 | attn_output = torch.nn.functional.scaled_dot_product_attention( 765 | query_states, 766 | key_states, 767 | value_states, 768 | attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None, 769 | dropout_p=self.attention_dropout if self.training else 0.0, 770 | # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
771 | is_causal=self.is_causal and attention_mask is None and q_len > 1, 772 | ) 773 | 774 | attn_output = attn_output.transpose(1, 2).contiguous() 775 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 776 | 777 | attn_output = self.o_proj(attn_output) 778 | 779 | return attn_output, None, past_key_value 780 | 781 | 782 | MISTRAL_ATTENTION_CLASSES = { 783 | "eager": MistralAttention, 784 | "flash_attention_2": MistralFlashAttention2, 785 | "sdpa": MistralSdpaAttention, 786 | } 787 | 788 | 789 | class MistralDecoderLayer(nn.Module): 790 | def __init__(self, config: MistralConfig, layer_idx: int): 791 | super().__init__() 792 | self.hidden_size = config.hidden_size 793 | 794 | self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) 795 | 796 | self.mlp = MistralMLP(config) 797 | self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 798 | self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 799 | 800 | def forward( 801 | self, 802 | hidden_states: torch.Tensor, 803 | attention_mask: Optional[torch.Tensor] = None, 804 | position_ids: Optional[torch.LongTensor] = None, 805 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 806 | output_attentions: Optional[bool] = False, 807 | use_cache: Optional[bool] = False, 808 | **kwargs, 809 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 810 | if "padding_mask" in kwargs: 811 | warnings.warn( 812 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 813 | ) 814 | """ 815 | Args: 816 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 817 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 818 | `(batch, sequence_length)` where padding elements are indicated by 0. 819 | output_attentions (`bool`, *optional*): 820 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 821 | returned tensors for more detail. 822 | use_cache (`bool`, *optional*): 823 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 824 | (see `past_key_values`). 825 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 826 | """ 827 | 828 | residual = hidden_states 829 | 830 | hidden_states = self.input_layernorm(hidden_states) 831 | 832 | # Self Attention 833 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 834 | hidden_states=hidden_states, 835 | attention_mask=attention_mask, 836 | position_ids=position_ids, 837 | past_key_value=past_key_value, 838 | output_attentions=output_attentions, 839 | use_cache=use_cache, 840 | ) 841 | hidden_states = residual.to(hidden_states.device) + hidden_states 842 | 843 | # Fully Connected 844 | residual = hidden_states 845 | hidden_states = self.post_attention_layernorm(hidden_states) 846 | hidden_states = self.mlp(hidden_states) 847 | hidden_states = residual + hidden_states 848 | 849 | outputs = (hidden_states,) 850 | 851 | if output_attentions: 852 | outputs += (self_attn_weights,) 853 | 854 | if use_cache: 855 | outputs += (present_key_value,) 856 | 857 | return outputs 858 | 859 | 860 | MISTRAL_START_DOCSTRING = r""" 861 | This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the 862 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 863 | etc.) 864 | 865 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 866 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 867 | and behavior. 868 | 869 | Parameters: 870 | config ([`MistralConfig`]): 871 | Model configuration class with all the parameters of the model. Initializing with a config file does not 872 | load the weights associated with the model, only the configuration. Check out the 873 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 874 | """ 875 | 876 | 877 | @add_start_docstrings( 878 | "The bare Mistral Model outputting raw hidden-states without any specific head on top.", 879 | MISTRAL_START_DOCSTRING, 880 | ) 881 | class MistralPreTrainedModel(PreTrainedModel): 882 | config_class = MistralConfig 883 | base_model_prefix = "model" 884 | supports_gradient_checkpointing = True 885 | _no_split_modules = ["MistralDecoderLayer"] 886 | _skip_keys_device_placement = "past_key_values" 887 | _supports_flash_attn_2 = True 888 | _supports_sdpa = True 889 | _supports_cache_class = True 890 | 891 | def _init_weights(self, module): 892 | std = self.config.initializer_range 893 | if isinstance(module, nn.Linear): 894 | module.weight.data.normal_(mean=0.0, std=std) 895 | if module.bias is not None: 896 | module.bias.data.zero_() 897 | elif isinstance(module, nn.Embedding): 898 | module.weight.data.normal_(mean=0.0, std=std) 899 | if module.padding_idx is not None: 900 | module.weight.data[module.padding_idx].zero_() 901 | 902 | 903 | MISTRAL_INPUTS_DOCSTRING = r""" 904 | Args: 905 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 906 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 907 | it. 908 | 909 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 910 | [`PreTrainedTokenizer.__call__`] for details. 911 | 912 | [What are input IDs?](../glossary#input-ids) 913 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 914 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 915 | 916 | - 1 for tokens that are **not masked**, 917 | - 0 for tokens that are **masked**. 918 | 919 | [What are attention masks?](../glossary#attention-mask) 920 | 921 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 922 | [`PreTrainedTokenizer.__call__`] for details. 923 | 924 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 925 | `past_key_values`). 926 | 927 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 928 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 929 | information on the default strategy. 930 | 931 | - 1 indicates the head is **not masked**, 932 | - 0 indicates the head is **masked**. 933 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 934 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 935 | config.n_positions - 1]`. 
936 | 937 | [What are position IDs?](../glossary#position-ids) 938 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): 939 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 940 | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` 941 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 942 | 943 | Two formats are allowed: 944 | - a [`~cache_utils.Cache`] instance; 945 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 946 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy 947 | cache format. 948 | 949 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the 950 | legacy cache format will be returned. 951 | 952 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 953 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 954 | of shape `(batch_size, sequence_length)`. 955 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 956 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 957 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 958 | model's internal embedding lookup matrix. 959 | use_cache (`bool`, *optional*): 960 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 961 | `past_key_values`). 962 | output_attentions (`bool`, *optional*): 963 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 964 | tensors for more detail. 965 | output_hidden_states (`bool`, *optional*): 966 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 967 | more detail. 968 | return_dict (`bool`, *optional*): 969 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 970 | """ 971 | 972 | 973 | @add_start_docstrings( 974 | "The bare Mistral Model outputting raw hidden-states without any specific head on top.", 975 | MISTRAL_START_DOCSTRING, 976 | ) 977 | class MistralModel(MistralPreTrainedModel): 978 | """ 979 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MistralDecoderLayer`] 980 | 981 | Args: 982 | config: MistralConfig 983 | """ 984 | 985 | def __init__(self, config: MistralConfig): 986 | super().__init__(config) 987 | self.padding_idx = config.pad_token_id 988 | self.vocab_size = config.vocab_size 989 | 990 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 991 | self.layers = nn.ModuleList( 992 | [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 993 | ) 994 | self._attn_implementation = config._attn_implementation 995 | self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 996 | 997 | self.gradient_checkpointing = False 998 | # Initialize weights and apply final processing 999 | self.post_init() 1000 | 1001 | def get_input_embeddings(self): 1002 | return self.embed_tokens 1003 | 1004 | def set_input_embeddings(self, value): 1005 | self.embed_tokens = value 1006 | 1007 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 1008 | def forward( 1009 | self, 1010 | input_ids: torch.LongTensor = None, 1011 | attention_mask: Optional[torch.Tensor] = None, 1012 | position_ids: Optional[torch.LongTensor] = None, 1013 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1014 | inputs_embeds: Optional[torch.FloatTensor] = None, 1015 | use_cache: Optional[bool] = None, 1016 | output_attentions: Optional[bool] = None, 1017 | output_hidden_states: Optional[bool] = None, 1018 | return_dict: Optional[bool] = None, 1019 | ) -> Union[Tuple, BaseModelOutputWithPast]: 1020 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1021 | output_hidden_states = ( 1022 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1023 | ) 1024 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1025 | 1026 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1027 | 1028 | # retrieve input_ids and inputs_embeds 1029 | if input_ids is not None and inputs_embeds is not None: 1030 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 1031 | elif input_ids is not None: 1032 | batch_size, seq_length = input_ids.shape 1033 | elif inputs_embeds is not None: 1034 | batch_size, seq_length, _ = inputs_embeds.shape 1035 | else: 1036 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 1037 | 1038 | if self.gradient_checkpointing and self.training: 1039 | if use_cache: 1040 | logger.warning_once( 1041 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
1042 | ) 1043 | use_cache = False 1044 | 1045 | past_key_values_length = 0 1046 | 1047 | if use_cache: 1048 | use_legacy_cache = not isinstance(past_key_values, Cache) 1049 | if use_legacy_cache: 1050 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 1051 | past_key_values_length = past_key_values.get_usable_length(seq_length) 1052 | 1053 | if position_ids is None: 1054 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1055 | position_ids = torch.arange( 1056 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 1057 | ) 1058 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 1059 | else: 1060 | position_ids = position_ids.view(-1, seq_length).long() 1061 | 1062 | if inputs_embeds is None: 1063 | inputs_embeds = self.embed_tokens(input_ids) 1064 | 1065 | if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: 1066 | is_padding_right = attention_mask[:, -1].sum().item() != batch_size 1067 | if is_padding_right: 1068 | raise ValueError( 1069 | "You are attempting to perform batched generation with padding_side='right'" 1070 | " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " 1071 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. " 1072 | ) 1073 | 1074 | if self._attn_implementation == "flash_attention_2": 1075 | # 2d mask is passed through the layers 1076 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 1077 | elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False: 1078 | # output_attentions=True can not be supported when using SDPA, and we fall back on 1079 | # the manual implementation that requires a 4D causal mask in all cases. 
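            # Note: the trailing `and False` on the elif above makes this SDPA-specific branch unreachable, so a
            # 2-D attention mask always falls through to the generic `_prepare_4d_causal_attention_mask` branch
            # below, presumably so this patched model always receives an explicit 4-D causal mask regardless of
            # the attention backend.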
1080 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( 1081 | attention_mask, 1082 | (batch_size, seq_length), 1083 | inputs_embeds, 1084 | past_key_values_length, 1085 | ) 1086 | elif attention_mask is None or attention_mask.dim() == 2: 1087 | # 4d mask is passed through the layers 1088 | attention_mask = _prepare_4d_causal_attention_mask( 1089 | attention_mask, 1090 | (batch_size, seq_length), 1091 | inputs_embeds, 1092 | past_key_values_length, 1093 | sliding_window=self.config.sliding_window, 1094 | ) 1095 | 1096 | hidden_states = inputs_embeds 1097 | 1098 | # decoder layers 1099 | all_hidden_states = () if output_hidden_states else None 1100 | all_self_attns = () if output_attentions else None 1101 | next_decoder_cache = None 1102 | 1103 | for decoder_layer in self.layers: 1104 | if output_hidden_states: 1105 | all_hidden_states += (hidden_states,) 1106 | 1107 | if self.gradient_checkpointing and self.training: 1108 | layer_outputs = self._gradient_checkpointing_func( 1109 | decoder_layer.__call__, 1110 | hidden_states, 1111 | attention_mask, 1112 | position_ids, 1113 | past_key_values, 1114 | output_attentions, 1115 | use_cache, 1116 | ) 1117 | else: 1118 | layer_outputs = decoder_layer( 1119 | hidden_states, 1120 | attention_mask=attention_mask, 1121 | position_ids=position_ids, 1122 | past_key_value=past_key_values, 1123 | output_attentions=output_attentions, 1124 | use_cache=use_cache, 1125 | ) 1126 | 1127 | hidden_states = layer_outputs[0] 1128 | 1129 | if use_cache: 1130 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 1131 | 1132 | if output_attentions: 1133 | all_self_attns += (layer_outputs[1],) 1134 | 1135 | hidden_states = self.norm(hidden_states) 1136 | 1137 | # add hidden states from the last decoder layer 1138 | if output_hidden_states: 1139 | all_hidden_states += (hidden_states,) 1140 | 1141 | next_cache = None 1142 | if use_cache: 1143 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache 1144 | 1145 | if not return_dict: 1146 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 1147 | return BaseModelOutputWithPast( 1148 | last_hidden_state=hidden_states, 1149 | past_key_values=next_cache, 1150 | hidden_states=all_hidden_states, 1151 | attentions=all_self_attns, 1152 | ) 1153 | 1154 | def nonzero_mean(x, axis=None): 1155 | if axis is not None: 1156 | return x.sum(axis) / (x != 0).sum(axis) 1157 | return x.sum() / (x != 0).sum() 1158 | 1159 | def loss_mean(x): 1160 | return x.sum() / (x != 0).sum() 1161 | 1162 | class MistralForCausalLM(MistralPreTrainedModel): 1163 | _tied_weights_keys = ["lm_head.weight"] 1164 | 1165 | def __init__(self, config): 1166 | super().__init__(config) 1167 | self.model = MistralModel(config) 1168 | self.vocab_size = config.vocab_size 1169 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1170 | self.max_thoughts = config.max_thoughts 1171 | self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads 1172 | self.use_concat_talk_head = config.use_concat_talk_head 1173 | self.use_shallow_talk = config.use_shallow_talk 1174 | self.use_complex_talk_head = config.use_complex_talk_head 1175 | self.use_weighted_talk_head = config.use_weighted_talk_head 1176 | # the weighted head will output a single value, so it can't be passed to the lm head 1177 | assert not (self.use_weighted_talk_head and self.use_shallow_talk) 1178 | 1179 | self.n_ahead = 1 1180 | self.n_ahead_talk = 1 1181 | 
self.n_passes = 1 1182 | self.n_tokens_print = 1 1183 | self.gradient_accumulation_steps = 1 1184 | self.training_steps = 0 1185 | self.tokenizer = None 1186 | self.start_token_id = None 1187 | self.end_token_id = None 1188 | self.rm_initialized = False 1189 | self.residual_talk_head = True 1190 | self.thought_init_std_scale = 1e-2 1191 | 1192 | self.final_only_mode = False 1193 | self.first_and_last_mode = True 1194 | self.first_only = False 1195 | self.original_loss_weight = 0.5 1196 | 1197 | self.cumulative_residual = False 1198 | self.clever_residual = False 1199 | self.skip_residual = False 1200 | self.no_residual = True 1201 | 1202 | self.optimize_lm_head_only_at_start = False 1203 | self.optimize_model_only_at_start = False 1204 | 1205 | if self.optimize_model_only_at_start: 1206 | raise NotImplementedError 1207 | self.train_only_thinking_embedding = False 1208 | self.weighted_embeddings = False 1209 | self.use_start_thought_token = True 1210 | self.use_end_thought_token = True 1211 | self.initialize_thought_embedding_to_normal = False 1212 | self.initial_start_token = "---" 1213 | self.initial_end_token = "---" 1214 | self.output_logits_at_the_end = True 1215 | 1216 | self.wandb_enabled = False 1217 | self.gumbel_temperature = 0.001 1218 | 1219 | self.use_policy_loss = True 1220 | self.include_policy_loss = True 1221 | self.trice_mode = True 1222 | self.remove_negative_rewards = True 1223 | self.use_policy_loss_for_end_thought = True 1224 | 1225 | self.base_original_mode = False 1226 | self.original_mode = False 1227 | 1228 | self.thought_prefix = "(Let's think step by step" 1229 | self.tokenized_thought_prefix = None 1230 | self.log_dict = defaultdict(int) 1231 | self.eval_log_dict = defaultdict(int) 1232 | self.print_final_only = True 1233 | self.loss_mean = loss_mean 1234 | self.all_rewards = [] 1235 | self.all_unreduced_losses = [] 1236 | self.kill_after = 100 1237 | 1238 | self.start_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size)) 1239 | self.end_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size)) 1240 | 1241 | self.policy_loss_beta = 1e6 1242 | self.embedding_scale = 1e2 1243 | self.reinforce_temperature = 3 1244 | self.base_loss_beta = 1 1245 | 1246 | # Not used in the paper: 1247 | self.use_thought_prefix = False 1248 | self.use_reparam_for_thought_embeddings = False 1249 | self.use_upper_triangular = False 1250 | self.subtract_mean_reward = False 1251 | self.comparison_mode = False 1252 | self.gumbel_detach = True 1253 | 1254 | # For visualization 1255 | self.eval_mode = False 1256 | 1257 | num_talk = 1 1258 | talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2 1259 | if self.use_weighted_talk_head: 1260 | talk_output_dim = 1 1261 | else: 1262 | talk_output_dim = config.hidden_size if self.use_shallow_talk else config.vocab_size 1263 | 1264 | if not self.merged_lm_and_talk_heads: 1265 | if self.use_complex_talk_head: 1266 | self.talk_head = nn.ModuleList([nn.Sequential( 1267 | nn.Linear(talk_input_dim, config.hidden_size), 1268 | nn.ReLU(), 1269 | nn.Linear(config.hidden_size, config.hidden_size), 1270 | nn.ReLU(), 1271 | nn.Linear(config.hidden_size, talk_output_dim, bias=False) 1272 | )]) 1273 | else: 1274 | self.talk_head = nn.ModuleList([nn.Sequential( 1275 | nn.Linear(talk_input_dim, talk_output_dim, bias=False) 1276 | )]) 1277 | 1278 | # Initialize weights and apply final processing 1279 | self.post_init() 1280 | 1281 | def get_input_embeddings(self): 1282 | return 
self.model.embed_tokens 1283 | 1284 | def set_input_embeddings(self, value): 1285 | self.model.embed_tokens = value 1286 | 1287 | def get_output_embeddings(self): 1288 | return self.lm_head 1289 | 1290 | def set_output_embeddings(self, new_embeddings): 1291 | self.lm_head = new_embeddings 1292 | 1293 | def set_decoder(self, decoder): 1294 | self.model = decoder 1295 | 1296 | def get_decoder(self): 1297 | return self.model 1298 | 1299 | @torch.no_grad() 1300 | def infer( 1301 | self, 1302 | input_ids: torch.LongTensor, 1303 | attention_mask: Optional[torch.Tensor] = None, 1304 | position_ids: Optional[torch.LongTensor] = None, 1305 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1306 | inputs_embeds: Optional[torch.FloatTensor] = None, 1307 | use_cache: Optional[bool] = None, 1308 | output_attentions: Optional[bool] = None, 1309 | output_hidden_states: Optional[bool] = None, 1310 | return_dict: Optional[bool] = None, 1311 | ): 1312 | batch_size, seq_len = input_ids.shape 1313 | 1314 | # Save the original input_ids and attention_mask for later use 1315 | original_input_ids = input_ids.clone() 1316 | original_attention_mask = attention_mask.clone() if attention_mask is not None else None 1317 | 1318 | # Append the start thought token to the input sequence 1319 | start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>") 1320 | input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1) 1321 | seq_len += 1 1322 | 1323 | # Update the attention mask 1324 | if attention_mask is not None: 1325 | attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1) 1326 | 1327 | # Generate the continuation 1328 | continuation_length = self.n_ahead - 2 1329 | new_key_values = past_key_values 1330 | 1331 | start_time = time.time() 1332 | for continuation_idx in range(continuation_length): 1333 | outputs = self.model( 1334 | input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device), 1335 | attention_mask=attention_mask, 1336 | position_ids=position_ids, 1337 | past_key_values=new_key_values, 1338 | inputs_embeds=inputs_embeds, 1339 | use_cache=True, 1340 | output_attentions=output_attentions, 1341 | output_hidden_states=output_hidden_states, 1342 | return_dict=return_dict, 1343 | ) 1344 | new_key_values = outputs.past_key_values 1345 | 1346 | hidden_states = outputs[0] 1347 | 1348 | logits = self.lm_head(hidden_states) 1349 | logits = logits[:, -1, :] # Only consider the last token 1350 | 1351 | # Apply Gumbel-Softmax to the logits 1352 | next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1) 1353 | next_token_id = torch.argmax(next_token_logits, dim=-1) 1354 | 1355 | # Append the generated token to the input sequence 1356 | input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1) 1357 | seq_len += 1 1358 | 1359 | # Update the attention mask 1360 | if attention_mask is not None: 1361 | attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1) 1362 | 1363 | # Append the end thought token to the input sequence 1364 | end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>") 1365 | input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1) 1366 | seq_len += 1 1367 | 1368 | # Update the attention mask 1369 | if attention_mask is not 
None: 1370 | attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1) 1371 | 1372 | # Get the hidden states before and after the thought 1373 | outputs_before = self.model( 1374 | input_ids=original_input_ids, 1375 | attention_mask=original_attention_mask, 1376 | position_ids=position_ids, 1377 | past_key_values=past_key_values, 1378 | inputs_embeds=inputs_embeds, 1379 | use_cache=use_cache, 1380 | output_attentions=output_attentions, 1381 | output_hidden_states=output_hidden_states, 1382 | return_dict=return_dict, 1383 | ) 1384 | hidden_states_before = outputs_before[0][:, -1:, :] 1385 | 1386 | # two new tokens: last continuation token and end thought token 1387 | outputs_after = self.model( 1388 | input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1), 1389 | attention_mask=attention_mask, 1390 | position_ids=position_ids, 1391 | past_key_values=new_key_values, 1392 | inputs_embeds=inputs_embeds, 1393 | use_cache=use_cache, 1394 | output_attentions=output_attentions, 1395 | output_hidden_states=output_hidden_states, 1396 | return_dict=return_dict, 1397 | ) 1398 | hidden_states_after = outputs_after[0][:, -1:, :] 1399 | 1400 | # Apply the talk head to get the mixing weight 1401 | mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1)) 1402 | 1403 | # Apply the mixing weight to the hidden states 1404 | mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after 1405 | 1406 | # Apply the language model head to get the final logits 1407 | logits = self.lm_head(mixed_hidden_states) 1408 | return logits 1409 | 1410 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 1411 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 1412 | def forward( 1413 | self, 1414 | input_ids: torch.LongTensor = None, 1415 | attention_mask: Optional[torch.Tensor] = None, 1416 | position_ids: Optional[torch.LongTensor] = None, 1417 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1418 | inputs_embeds: Optional[torch.FloatTensor] = None, 1419 | labels: Optional[torch.LongTensor] = None, 1420 | use_cache: Optional[bool] = None, 1421 | output_attentions: Optional[bool] = None, 1422 | output_hidden_states: Optional[bool] = None, 1423 | return_dict: Optional[bool] = None, 1424 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1425 | r""" 1426 | Args: 1427 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1428 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1429 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1430 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1431 | 1432 | Returns: 1433 | 1434 | Example: 1435 | 1436 | ```python 1437 | >>> from transformers import AutoTokenizer, MistralForCausalLM 1438 | 1439 | >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") 1440 | >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") 1441 | 1442 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 
1443 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1444 | 1445 | >>> # Generate 1446 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1447 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1448 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 1449 | ```""" 1450 | log_dict = self.log_dict if self.training else self.eval_log_dict 1451 | 1452 | if self.training and self.kill_after is not None and self.training_steps // self.gradient_accumulation_steps > self.kill_after: 1453 | raise ValueError("Killed after") 1454 | 1455 | if not self.training: 1456 | n_ahead_talk_to_restore = self.n_ahead_talk 1457 | n_passes_to_restore = self.n_passes 1458 | self.n_ahead_talk = 1 1459 | self.n_passes = 1 1460 | 1461 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1462 | output_hidden_states = ( 1463 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1464 | ) 1465 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1466 | 1467 | assert self.cumulative_residual or self.clever_residual or self.skip_residual or self.no_residual 1468 | assert not (self.skip_residual and self.use_policy_loss) 1469 | 1470 | if self.tokenized_thought_prefix is None and self.use_thought_prefix: 1471 | self.tokenized_thought_prefix = self.tokenizer(self.thought_prefix, return_tensors="pt", add_special_tokens=False)["input_ids"] 1472 | 1473 | def apply_head(head, states, detach=False): 1474 | if detach: 1475 | head_weight = head.weight.detach() 1476 | else: 1477 | head_weight = head.weight 1478 | head_weight = head_weight.to(states.device) 1479 | return (head_weight @ states.transpose(-1, -2)).transpose(-1, -2).contiguous() 1480 | 1481 | def idx_if_sequential(head, idx=0): 1482 | if isinstance(head, nn.Sequential) or isinstance(head, nn.ModuleList): 1483 | return idx_if_sequential(head[idx], idx=idx) 1484 | return head 1485 | 1486 | def none_repeat_interleave(x, n): 1487 | if x is None: 1488 | return x 1489 | return x.repeat_interleave(n, dim=0) 1490 | 1491 | if self.n_passes > 1: 1492 | input_ids = none_repeat_interleave(input_ids, self.n_passes) 1493 | attention_mask = none_repeat_interleave(attention_mask, self.n_passes) 1494 | position_ids = none_repeat_interleave(position_ids, self.n_passes) 1495 | inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes) 1496 | labels = none_repeat_interleave(labels, self.n_passes) 1497 | if past_key_values is not None: 1498 | past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values] 1499 | cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device) 1500 | 1501 | self.tokenizer_has_start_thought_token = True 1502 | self.tokenizer_has_end_thought_token = True 1503 | if self.start_token_id is None: 1504 | self.start_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>") 1505 | if self.start_token_id == 0: 1506 | self.start_token_id = self.tokenizer.bos_token_id 1507 | self.tokenizer_has_start_thought_token = False 1508 | elif self.use_start_thought_token: 1509 | # base_start_id = self.tokenizer.convert_tokens_to_ids(self.initial_start_token) 1510 | base_start_id = self.tokenizer.encode(self.initial_start_token, add_special_tokens=False)[0] 1511 | if self.initialize_thought_embedding_to_normal: 1512 | self.start_embedding.data = 
torch.zeros_like(self.start_embedding.data) 1513 | else: 1514 | self.start_embedding.data[0] = self.model.embed_tokens.weight.data[base_start_id].clone().detach() / self.embedding_scale 1515 | self.start_embedding.data[1] = torch.log(self.model.embed_tokens.weight.data.std(dim=0) * self.thought_init_std_scale / self.embedding_scale) 1516 | if self.end_token_id is None: 1517 | self.end_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>") 1518 | if self.end_token_id == 0: 1519 | self.end_token_id = self.tokenizer.eos_token_id 1520 | self.tokenizer_has_end_thought_token = False 1521 | elif self.use_end_thought_token: 1522 | # base_end_id = self.tokenizer.convert_tokens_to_ids(self.initial_end_token) 1523 | base_end_id = self.tokenizer.encode(self.initial_end_token, add_special_tokens=False)[0] 1524 | if self.initialize_thought_embedding_to_normal: 1525 | self.end_embedding.data = torch.zeros_like(self.end_embedding.data) 1526 | else: 1527 | self.end_embedding.data[0] = self.model.embed_tokens.weight.data[base_end_id].clone().detach() / self.embedding_scale 1528 | self.end_embedding.data[1] = torch.log(self.model.embed_tokens.weight.data.std(dim=0) * self.thought_init_std_scale / self.embedding_scale) 1529 | 1530 | if not self.rm_initialized and (self.n_ahead > 1 or not self.base_original_mode): 1531 | self.rm_initialized = True 1532 | if not self.use_shallow_talk: 1533 | head = self.talk_head[0] 1534 | cur_head = head[-1] if isinstance(head, nn.Sequential) else head 1535 | talk_input_dim = cur_head.weight.data.shape[1] 1536 | talk_output_dim = 1 if self.use_weighted_talk_head else self.lm_head.weight.data.shape[0] 1537 | cur_head.weight.data = torch.zeros(talk_output_dim, talk_input_dim, device=cur_head.weight.device, dtype=cur_head.weight.dtype) 1538 | else: 1539 | # convert to identity transform 1540 | def lambda_transform(cur_head): 1541 | if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]: 1542 | return torch.cat([ 1543 | torch.eye( 1544 | cur_head.weight.data.shape[0], 1545 | device=cur_head.weight.device, 1546 | dtype=cur_head.weight.dtype 1547 | ), 1548 | torch.zeros( 1549 | cur_head.weight.data.shape[0], 1550 | cur_head.weight.data.shape[1] - cur_head.weight.data.shape[0], 1551 | device=cur_head.weight.device, 1552 | dtype=cur_head.weight.dtype 1553 | )], dim=1) 1554 | return torch.eye( 1555 | cur_head.weight.data.shape[0], 1556 | device=cur_head.weight.device, 1557 | dtype=cur_head.weight.dtype 1558 | ) 1559 | if isinstance(self.talk_head[0], nn.Sequential): 1560 | for cur_head in self.talk_head[0]: 1561 | # if it has weights 1562 | if hasattr(cur_head, "weight"): 1563 | cur_head.weight.data = lambda_transform(cur_head) 1564 | else: 1565 | self.talk_head[-1].weight.data = lambda_transform(self.talk_head[0]) 1566 | 1567 | loss = None 1568 | prev_rm_tokens = None 1569 | cur_rm_tokens = None 1570 | prev_rm_logits = None 1571 | prev_sample_probs = None 1572 | did_skip_sampling = None 1573 | skip_sampling = None 1574 | sample_probs = None 1575 | hidden_states = None 1576 | logits = None 1577 | talk_kl_penalty = None 1578 | rm_logits = None 1579 | residual_logits = None 1580 | probabilities_2d = None 1581 | prev_probabilities_2d = None 1582 | policy_reward = None 1583 | logits_to_output = None 1584 | batch_size, seq_len = input_ids.shape 1585 | base_input_ids = input_ids.clone() 1586 | loss_list = [] 1587 | dqn_loss_list = [] 1588 | sampled_token_history = [] 1589 | sample_probs_history = [] 1590 | action_loglikelihoods_list = [] 1591 | 1592 | if 
self.use_end_thought_token or self.use_start_thought_token: 1593 | if not self.use_reparam_for_thought_embeddings: 1594 | start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale 1595 | end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale 1596 | else: 1597 | start_embedding = self.start_embedding * self.embedding_scale 1598 | end_embedding = self.end_embedding * self.embedding_scale 1599 | base_embeddings = self.model.embed_tokens.weight 1600 | if self.train_only_thinking_embedding: 1601 | base_embeddings = base_embeddings.detach() 1602 | # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1603 | fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1 1604 | for ahead_idx in range(fwd_iters): 1605 | past_key_values_length = 0 1606 | if past_key_values is not None: 1607 | use_legacy_cache = not isinstance(past_key_values, Cache) 1608 | if use_legacy_cache: 1609 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 1610 | past_key_values_length = past_key_values.get_usable_length(seq_len) 1611 | 1612 | if position_ids is None: 1613 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1614 | position_ids = torch.arange( 1615 | past_key_values_length, seq_len + past_key_values_length, dtype=torch.long, device=device 1616 | ) 1617 | position_ids = position_ids.unsqueeze(0).view(-1, seq_len) 1618 | else: 1619 | position_ids = position_ids.view(-1, seq_len).long() 1620 | 1621 | if inputs_embeds is None: 1622 | contains_start = self.use_start_thought_token and (input_ids == self.start_token_id).any() 1623 | contains_end = self.use_end_thought_token and (input_ids == self.end_token_id).any() 1624 | contains_thought = contains_start or contains_end 1625 | if contains_thought: 1626 | thought_id = self.start_token_id if contains_start else self.end_token_id 1627 | cur_thought_embedding = start_embedding if contains_start else end_embedding 1628 | if self.use_reparam_for_thought_embeddings: 1629 | inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype) 1630 | inputs_embeds = inputs_embeds.detach() * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0] 1631 | if contains_start: 1632 | sampled_start = inputs_embeds.clone().detach() 1633 | if contains_end: 1634 | sampled_end = inputs_embeds.clone().detach() 1635 | else: 1636 | inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1) 1637 | else: 1638 | with torch.set_grad_enabled(not self.train_only_thinking_embedding): 1639 | inputs_embeds = self.model.embed_tokens(input_ids) 1640 | 1641 | if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode: 1642 | if attention_mask is None: 1643 | base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device) 1644 | base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len) 1645 | base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1) 1646 | attention_mask = base_attention_mask 1647 | breakpoint() 1648 | elif attention_mask.dim() == 2: 1649 | if seq_len + past_key_values_length != attention_mask.shape[-1]: 1650 | breakpoint() 1651 | attention_mask = torch.cat( 1652 | [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask], 1653 | dim=-1 1654 | ) 1655 | # # if the attention mask 1656 | attention_mask = 
_prepare_4d_causal_attention_mask( 1657 | attention_mask, 1658 | (batch_size, seq_len), 1659 | inputs_embeds, 1660 | past_key_values_length, 1661 | sliding_window=self.config.sliding_window, 1662 | ) 1663 | 1664 | outputs = self.model( 1665 | # input_ids=input_ids, 1666 | attention_mask=attention_mask, 1667 | position_ids=position_ids, 1668 | past_key_values=past_key_values, 1669 | inputs_embeds=inputs_embeds, 1670 | use_cache=use_cache, 1671 | output_attentions=output_attentions, 1672 | output_hidden_states=output_hidden_states, 1673 | return_dict=return_dict, 1674 | ) 1675 | 1676 | prev_hidden_states = hidden_states 1677 | hidden_states = outputs[0] 1678 | prev_rm_logits = rm_logits # for policy gradient 1679 | prev_rm_tokens = cur_rm_tokens # for policy gradient 1680 | 1681 | if ahead_idx == 0: 1682 | hidden_states_lm = hidden_states 1683 | logits = self.lm_head(hidden_states_lm) 1684 | base_hidden_states = hidden_states.clone() 1685 | initial_loss_logits = logits.clone() 1686 | if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start: 1687 | logits = logits.detach() 1688 | base_hidden_states = base_hidden_states.detach() 1689 | if self.optimize_model_only_at_start: 1690 | hidden_states = hidden_states.detach() 1691 | base_logits = logits.clone() 1692 | else: 1693 | talk_hidden_states = hidden_states 1694 | if self.merged_lm_and_talk_heads: 1695 | assert self.no_residual 1696 | residual_logits = self.lm_head(hidden_states) 1697 | talk_hidden_states = hidden_states 1698 | else: 1699 | if ahead_idx > self.n_ahead - 1: 1700 | cur_base_hidden = torch.cat([ 1701 | base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :], 1702 | base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :] 1703 | ], dim=-2) 1704 | else: 1705 | cur_base_hidden = base_hidden_states 1706 | 1707 | if self.use_concat_talk_head: 1708 | # concatenate the hidden states with the original hidden states 1709 | head_input_hidden_states = torch.cat([cur_base_hidden, talk_hidden_states], dim=-1) 1710 | else: 1711 | head_input_hidden_states = talk_hidden_states 1712 | 1713 | residual_logits = self.talk_head[0](head_input_hidden_states) 1714 | if self.use_shallow_talk: 1715 | residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start) 1716 | residual_logits = residual_logits.to(logits.device) 1717 | if self.use_weighted_talk_head: 1718 | # combine the cur_base_hidden with the talk_hidden_states according to the weighted head 1719 | residual_logits = cur_base_hidden * (1 - residual_logits) + talk_hidden_states * residual_logits 1720 | residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start) 1721 | 1722 | assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1 1723 | if self.clever_residual: 1724 | if ahead_idx >= self.n_ahead - 1: 1725 | # get the logits shifted according to the current talk ahead 1726 | cur_base_logits = torch.cat([ 1727 | base_logits[..., ahead_idx - self.n_ahead + 1:, :], 1728 | base_logits[..., :ahead_idx - self.n_ahead + 1, :] 1729 | ], dim=-2) 1730 | if self.optimize_lm_head_only_at_start: 1731 | cur_base_logits = cur_base_logits.detach() 1732 | logits = cur_base_logits + residual_logits 1733 | else: 1734 | logits += residual_logits / self.n_ahead 1735 | elif self.cumulative_residual: 1736 | if self.residual_talk_head: 1737 | if ahead_idx < self.n_ahead: 1738 | logits += residual_logits 1739 | else: 1740 | # get the logits shifted according to the 
current talk ahead 1741 | cur_base_logits = torch.cat([ 1742 | base_logits[..., ahead_idx - self.n_ahead + 1:, :], 1743 | base_logits[..., :ahead_idx - self.n_ahead + 1, :] 1744 | ], dim=-2) 1745 | if self.optimize_lm_head_only_at_start: 1746 | cur_base_logits = cur_base_logits.detach() 1747 | logits = cur_base_logits + residual_logits 1748 | else: 1749 | if ahead_idx < self.n_ahead: 1750 | logits += residual_logits 1751 | else: 1752 | logits = residual_logits 1753 | elif self.skip_residual: 1754 | if ahead_idx >= self.n_ahead: 1755 | # get the logits shifted according to the current talk ahead 1756 | cur_base_logits = torch.cat([ 1757 | base_logits[..., ahead_idx - self.n_ahead + 1:, :], 1758 | base_logits[..., :ahead_idx - self.n_ahead + 1, :] 1759 | ], dim=-2) 1760 | if self.optimize_lm_head_only_at_start: 1761 | cur_base_logits = cur_base_logits.detach() 1762 | logits = cur_base_logits 1763 | elif self.no_residual: 1764 | logits = residual_logits 1765 | else: 1766 | logits = base_logits + residual_logits 1767 | 1768 | attempted = False 1769 | talk_loss_list = [] 1770 | if self.original_mode or (self.n_ahead == 1) or (self.comparison_mode and ahead_idx == 0):# or (self.optimize_lm_head_only_at_start and ahead_idx == 0): 1771 | loss = None 1772 | attempted = True 1773 | 1774 | if labels is not None: 1775 | for shift_amount in range(self.n_ahead_talk): 1776 | # Shift so that tokens < n predict n 1777 | # ab[cde]f 1778 | # abc[def] 1779 | if ahead_idx == 0 and self.optimize_lm_head_only_at_start: 1780 | loss_logits = initial_loss_logits 1781 | else: 1782 | loss_logits = logits 1783 | shift_logits = loss_logits[..., shift_amount:-1, :].contiguous() 1784 | shift_labels = labels[..., 1 + shift_amount:].contiguous() 1785 | # Flatten the tokens 1786 | loss_fct = CrossEntropyLoss(reduction="none") 1787 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1788 | shift_labels = shift_labels.view(-1).clone() 1789 | # Enable model parallelism 1790 | shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100 1791 | shift_labels = shift_labels.to(shift_logits.device) 1792 | loss = loss_fct(shift_logits, shift_labels) 1793 | if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode: 1794 | loss_list.append(loss) 1795 | talk_loss_list.append(nonzero_mean(loss).detach()) 1796 | 1797 | if not attempted or self.comparison_mode: 1798 | rm_hidden_states = hidden_states 1799 | # print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm()) 1800 | rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start) 1801 | 1802 | # don't allow it to predict the thinking token 1803 | if self.tokenizer_has_start_thought_token: 1804 | rm_logits[..., self.start_token_id] = -1e10 1805 | if self.tokenizer_has_end_thought_token: 1806 | rm_logits[..., self.end_token_id] = -1e10 1807 | probabilities = rm_logits 1808 | if probabilities_2d is not None: 1809 | prev_probabilities_2d = probabilities_2d.clone() 1810 | probabilities_2d = probabilities.view(-1, probabilities.size(-1)) 1811 | 1812 | did_skip_sampling = skip_sampling 1813 | skip_sampling = False 1814 | if ahead_idx == 0 and self.use_start_thought_token: 1815 | override_token = self.start_token_id 1816 | elif self.use_thought_prefix and ahead_idx < self.tokenized_thought_prefix.shape[-1]: 1817 | override_token = self.tokenized_thought_prefix[..., ahead_idx] 1818 | elif ahead_idx == self.n_ahead - 2 and 
self.use_end_thought_token: 1819 | override_token = self.end_token_id 1820 | else: 1821 | override_token = None 1822 | if override_token is not None and self.n_ahead > 1: 1823 | # always start with the start token 1824 | probabilities_2d = torch.zeros_like(probabilities_2d) 1825 | probabilities_2d[:, override_token] = 1.0 1826 | skip_sampling = True 1827 | elif ahead_idx >= self.n_ahead - 1: 1828 | if labels is not None: # we're in the talk phase 1829 | cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1 1830 | # print("Setting rm to labels", cur_talk_n, "during", ahead_idx) 1831 | shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device) 1832 | padding = torch.full_like( 1833 | labels[..., :cur_talk_n], 1834 | self.tokenizer.pad_token_id, 1835 | dtype=torch.long, 1836 | device=shift_labels.device 1837 | ) 1838 | new_rm_tokens = torch.cat( 1839 | [shift_labels, padding], 1840 | dim=-1 1841 | ) 1842 | # convert rm tokens to one-hot 1843 | probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype) 1844 | skip_sampling = True 1845 | else: 1846 | continue 1847 | temperature = self.gumbel_temperature if self.training else 0.001 1848 | prev_sample_probs = sample_probs 1849 | sample_probs = probabilities_2d 1850 | if ahead_idx < self.n_ahead - 1 and not skip_sampling: 1851 | probabilities_2d = F.gumbel_softmax(sample_probs, tau=temperature, hard=True, dim=-1) 1852 | if self.gumbel_detach: 1853 | probabilities_2d = probabilities_2d.detach() 1854 | sampled_token_history.append(probabilities_2d.argmax(dim=-1).detach().cpu()) 1855 | # convert rm logits directly to embeddings 1856 | contains_start = self.use_start_thought_token and (probabilities_2d[..., self.start_token_id].sum() > 0) 1857 | contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0) 1858 | contains_thought = contains_start or contains_end 1859 | 1860 | if not contains_thought: 1861 | with torch.set_grad_enabled(not self.train_only_thinking_embedding): 1862 | inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype)) 1863 | else: 1864 | thought_id = self.start_token_id if contains_start else self.end_token_id 1865 | cur_thought_embedding = start_embedding if contains_start else end_embedding 1866 | if self.use_reparam_for_thought_embeddings: 1867 | inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype) 1868 | inputs_embeds = inputs_embeds * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0] 1869 | if contains_start: 1870 | sampled_start = inputs_embeds.clone().detach() 1871 | else: 1872 | sampled_end = inputs_embeds.clone().detach() 1873 | else: 1874 | inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1) 1875 | inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype) 1876 | inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype) 1877 | 1878 | if len(attention_mask.shape) == 2: 1879 | breakpoint() 1880 | else: 1881 | original_attention = attention_mask[..., :attention_mask.shape[-2]] 1882 | if self.use_upper_triangular: 1883 | new_attention = original_attention 1884 | else: 1885 | original_attention = original_attention == attention_mask.max() 1886 | # because eye isn't implemented for BF16, 
we need to handle the case 1887 | if not attention_mask.dtype == torch.bfloat16: 1888 | new_attention = torch.eye( 1889 | seq_len, dtype=attention_mask.dtype, device=attention_mask.device 1890 | ) 1891 | else: 1892 | new_attention = torch.eye( 1893 | seq_len, dtype=torch.float32, device=attention_mask.device 1894 | ).to(attention_mask.dtype) 1895 | 1896 | new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1) 1897 | new_attention = new_attention * original_attention 1898 | new_attention[new_attention == 0] = attention_mask.min() 1899 | new_attention[new_attention == 1] = attention_mask.max() 1900 | attention_mask = torch.cat([attention_mask, new_attention], dim=-1) 1901 | past_key_values = outputs.past_key_values 1902 | position_ids = position_ids + 1 1903 | 1904 | if labels is not None and (self.n_ahead > 1 or not self.base_original_mode): 1905 | # Shift so that tokens < n predict n 1906 | # logits: abcdef -> bcdef? -> cdef?? 1907 | # labels: abcdef -> ?bcdef -> ??cdef 1908 | if ahead_idx == 0 and self.optimize_lm_head_only_at_start: 1909 | loss_logits = initial_loss_logits 1910 | else: 1911 | loss_logits = logits 1912 | shift_idx = 1 + max(0, ahead_idx - (self.n_ahead - 1)) 1913 | shift_logits = loss_logits[..., :-shift_idx, :].contiguous() 1914 | shift_labels = labels[..., shift_idx:].contiguous() 1915 | # Flatten the tokens 1916 | loss_fct = CrossEntropyLoss(reduction="none") 1917 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1918 | shift_labels = shift_labels.view(-1) 1919 | # Enable model parallelism 1920 | shift_labels = shift_labels.to(shift_logits.device) 1921 | # if shift_labels.min() == self.tokenizer.pad_token_id: 1922 | shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels) 1923 | unreduced_loss = loss_fct(shift_logits, shift_labels) 1924 | if torch.any(unreduced_loss != unreduced_loss): 1925 | raise ValueError("NaN loss") 1926 | unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1) 1927 | loss_list.append(unreduced_loss) 1928 | 1929 | 1930 | if self.use_policy_loss and ahead_idx > 0 and (ahead_idx > 1 or not self.use_start_thought_token): 1931 | # we treat the change in loss as the reward 1932 | previous_loss = loss_list[-2] 1933 | # for example, suppose n_ahead = 3 and n_ahead_talk = 2 1934 | # note that we end at self.n_ahead + self.n_ahead_talk - 2 1935 | # in this case, 5 - 2 = 3, so we end at ahead_idx = 3 1936 | # we also predict the next token at ahead_idx = 2 1937 | # when we get to ahead_idx = 2, we predict ahead 1938 | # so we shift by 1 1939 | # note that this is ahead_idx = n_ahead - 1 1940 | # when we get to ahead_idx = 3, we predict ahead 1941 | # so we shift by 2 1942 | # note that this is ahead_idx = n_ahead 1943 | if ahead_idx < self.n_ahead - 1: 1944 | shift_amount = 0 1945 | original_dqn_reward = (previous_loss - unreduced_loss).detach() 1946 | if self.first_and_last_mode: 1947 | original_dqn_reward = original_dqn_reward * 0.0 1948 | else: 1949 | # logits vs cur_policy_shift_logits 1950 | # let's look at rm_logits and prev_rm_logits 1951 | shift_amount = max(0, ahead_idx - (self.n_ahead - 1)) 1952 | # let's say shift_amount = 2 1953 | # abcdefg -> bcdefg? -> cdefg?? 
1954 | # logits = [a b]c d e f[g] 1955 | # labels = [a b c]d e f g 1956 | cur_policy_shift_logits = initial_loss_logits[..., shift_amount:-1, :].contiguous().detach() 1957 | cur_policy_shift_labels = labels[..., 1 + shift_amount:].contiguous() 1958 | # Flatten the tokens 1959 | cur_policy_loss_fct = CrossEntropyLoss(reduction="none") 1960 | cur_policy_shift_logits = cur_policy_shift_logits.view(-1, self.config.vocab_size) 1961 | cur_policy_shift_labels = cur_policy_shift_labels.view(-1).clone() 1962 | # Enable model parallelism 1963 | cur_policy_shift_labels[cur_policy_shift_labels == self.tokenizer.pad_token_id] = -100 1964 | cur_policy_shift_labels = cur_policy_shift_labels.to(cur_policy_shift_labels.device) 1965 | cur_policy_reward_base_loss = loss_fct( 1966 | cur_policy_shift_logits, cur_policy_shift_labels.to(cur_policy_shift_logits.device) 1967 | ).reshape(logits.shape[0], -1) 1968 | original_dqn_reward = cur_policy_reward_base_loss.detach() - unreduced_loss 1969 | 1970 | if not did_skip_sampling: 1971 | nonzero_indices = prev_probabilities_2d.nonzero() 1972 | action_loglikelihoods = F.log_softmax(prev_sample_probs / self.reinforce_temperature, dim=-1)[nonzero_indices[:, 0], nonzero_indices[:, 1]] 1973 | action_loglikelihoods_2d = action_loglikelihoods.reshape(batch_size, -1)[:, :-1 - shift_amount] 1974 | action_loglikelihoods_list.append(action_loglikelihoods_2d) 1975 | if policy_reward is None: 1976 | policy_reward = original_dqn_reward[:, :-(self.n_ahead_talk - shift_amount)] 1977 | else: 1978 | if self.n_ahead_talk > shift_amount: 1979 | added_reward = original_dqn_reward[:, :-(self.n_ahead_talk - shift_amount)] 1980 | else: 1981 | added_reward = original_dqn_reward 1982 | policy_reward += added_reward 1983 | 1984 | if self.use_policy_loss and ahead_idx == self.n_ahead + self.n_ahead_talk - 2: 1985 | # only compute during the thinking phase 1986 | if self.use_reparam_for_thought_embeddings and (self.use_start_thought_token or self.use_end_thought_token): 1987 | # sampled_start, sampled_end 1988 | # calculate the log likelihood of the start and end embeddings sampled from a multivariate normal distribution 1989 | # with mean start_embedding[0] and standard deviation start_embedding[1] 1990 | if self.use_start_thought_token: 1991 | exp_start_std = torch.exp(start_embedding[1]) 1992 | start_loglikelihood = -0.5 * (sampled_start.detach() - start_embedding[0]) ** 2 / exp_start_std ** 2 - start_embedding[1] - 0.5 * math.log(2 * math.pi) 1993 | start_loglikelihood = start_loglikelihood.mean(dim=-1) 1994 | if self.use_end_thought_token: 1995 | exp_end_std = torch.exp(end_embedding[1]) 1996 | end_loglikelihood = -0.5 * (sampled_end.detach() - end_embedding[0]) ** 2 / exp_end_std ** 2 - end_embedding[1] - 0.5 * math.log(2 * math.pi) 1997 | end_loglikelihood = end_loglikelihood.mean(dim=-1) 1998 | # we use the mean instead of the sum to prevent dependence on the dimensionality of the embeddings 1999 | if self.use_end_thought_token and self.use_policy_loss_for_end_thought: 2000 | action_loglikelihoods_list.append(end_loglikelihood) 2001 | if self.use_start_thought_token: 2002 | action_loglikelihoods_list.append(start_loglikelihood) 2003 | 2004 | if ahead_idx == self.n_ahead + self.n_ahead_talk - 2 and self.eval_mode: 2005 | with torch.no_grad(): 2006 | # calculate the 0.75 quantile of the rewards 2007 | filtered_tokens = input_ids[:, :policy_reward.shape[-1]].cpu().detach().numpy().flatten() 2008 | filtered_tokens_mask = filtered_tokens != self.tokenizer.pad_token_id 2009 | filtered_tokens = 
filtered_tokens[filtered_tokens_mask] 2010 | filtered_rewards = policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten() 2011 | filtered_rewards = filtered_rewards[filtered_tokens_mask] 2012 | 2013 | abs_reward_list = np.abs(policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten()) 2014 | abs_reward_list = abs_reward_list[filtered_tokens_mask] 2015 | medium_quantile = np.quantile(abs_reward_list, 0.5) 2016 | upper_quantile = np.quantile(abs_reward_list, 0.95) 2017 | 2018 | save_tokens_with_rewards_to_pdf( 2019 | filtered_tokens, 2020 | [0] + filtered_rewards.tolist(), 2021 | self.tokenizer, 2022 | output_file=f"texts/rewards_talk_{self.n_ahead_talk}_{self.training_steps}.pdf", 2023 | eps=medium_quantile, 2024 | eps2=upper_quantile, 2025 | ) 2026 | 2027 | def plot_kde(data, losses): 2028 | sns.set(style="whitegrid") 2029 | # Create the KDE plot 2030 | sns.kdeplot(data, fill=True) 2031 | # Set the plot title and labels 2032 | plt.title("KDE Plot") 2033 | plt.xlabel("Value") 2034 | plt.ylabel("Density") 2035 | # Save the plot 2036 | plt.savefig(f"texts/kde_talk_{self.n_ahead_talk}_{self.training_steps}.pdf") 2037 | # Close the plot 2038 | plt.close() 2039 | 2040 | # Step 1: Create a base color palette 2041 | base_colors = sns.color_palette("light:#5A9", n_colors=256) # More colors for a smoother gradient 2042 | base_cmap = LinearSegmentedColormap.from_list("log_light", base_colors) 2043 | log_norm = LogNorm(vmin=1e-3, vmax=10) 2044 | 2045 | sns.kdeplot(x=data, y=losses, fill=True, levels=20, norm=log_norm, cut=0, linewidths=0) 2046 | # limit y to 0 to 25 and x to -1 to 1 2047 | plt.xlim(-1, 1) 2048 | plt.ylim(0, 25) 2049 | plt.savefig(f"texts/jointer_talk_{self.n_ahead_talk}_{self.training_steps}.pdf") 2050 | plt.close() 2051 | 2052 | self.all_rewards.extend(filtered_rewards) 2053 | self.all_unreduced_losses.extend(unreduced_loss[:, :-1].flatten()[filtered_tokens_mask].float().flatten().cpu().detach().numpy()) 2054 | plot_kde(self.all_rewards, self.all_unreduced_losses) 2055 | 2056 | for action_loglikelihoods_2d in action_loglikelihoods_list: 2057 | train_policy_reward = policy_reward 2058 | 2059 | # discard rewards below the mean 2060 | if self.trice_mode and self.n_passes > 1: 2061 | batched_policy_reward = train_policy_reward.reshape(-1, self.n_passes, train_policy_reward.shape[-1]) 2062 | # average over the passes 2063 | train_policy_reward = batched_policy_reward - batched_policy_reward.mean(dim=1, keepdim=True) 2064 | train_policy_reward = train_policy_reward.reshape(-1, train_policy_reward.shape[-1]) 2065 | 2066 | if self.subtract_mean_reward: 2067 | train_policy_reward = train_policy_reward - train_policy_reward.mean() 2068 | if self.remove_negative_rewards: 2069 | fixed_policy_reward = train_policy_reward.detach().clamp(min=0) 2070 | else: 2071 | fixed_policy_reward = train_policy_reward.detach() 2072 | actor_loss = -fixed_policy_reward * action_loglikelihoods_2d[:, :policy_reward.shape[-1]].to(policy_reward.device) 2073 | if action_loglikelihoods_2d.mean() < -1e4 and not self.use_policy_loss_just_for_thoughts: 2074 | # This will only happen when we force the next token to be the end of thought token 2075 | break 2076 | dqn_loss_list.append(actor_loss.mean()) 2077 | 2078 | if loss_list: 2079 | if self.first_and_last_mode: 2080 | loss = sum( 2081 | self.loss_mean(loss_list[-(i + 1)]) for i in range(self.n_ahead_talk) 2082 | ) * (1 - self.original_loss_weight) / self.n_ahead_talk 2083 | loss = loss + 
self.loss_mean(loss_list[0]) * self.original_loss_weight 2084 | # Let's NaN out the others 2085 | # e.g. if n_ahead_talk = 2 and the list is 5 long, we want to NaN out 1, 2 but keep 0, 3, 4 2086 | for i in range(1, len(loss_list) - self.n_ahead_talk): 2087 | loss_list[i] = loss_list[i] * math.nan 2088 | elif self.first_only: 2089 | loss = self.loss_mean(loss_list[0]) 2090 | elif self.final_only_mode: 2091 | loss = sum( 2092 | self.loss_mean(loss_list[-i]) for i in range(1, self.n_ahead_talk + 1) 2093 | ) / self.n_ahead_talk 2094 | else: 2095 | loss = None 2096 | for i in range(len(loss_list)): 2097 | cur_loss = self.loss_mean(loss_list[i]) 2098 | if loss is not None: 2099 | loss = loss + cur_loss.to(loss.device) 2100 | else: 2101 | loss = cur_loss 2102 | loss = loss / len(loss_list) 2103 | 2104 | loss = loss * self.base_loss_beta 2105 | 2106 | if dqn_loss_list: 2107 | dqn_loss = sum(dqn_loss_list) / len(dqn_loss_list) 2108 | if self.include_policy_loss: 2109 | if loss is not None: 2110 | loss += dqn_loss * self.policy_loss_beta 2111 | else: 2112 | loss = dqn_loss * self.policy_loss_beta 2113 | 2114 | if not return_dict: 2115 | output = (logits,) + outputs[1:] 2116 | return (loss,) + output if loss is not None else output 2117 | 2118 | base_log_dict = { 2119 | f"loss_{i}": nonzero_mean(loss_list[i]) for i in range(len(loss_list)) 2120 | } 2121 | 2122 | if loss is not None: 2123 | base_log_dict["loss_train"] = loss.item() 2124 | 2125 | for loss_key, loss_val in base_log_dict.items(): 2126 | log_dict[loss_key] += loss_val / self.n_tokens_print 2127 | 2128 | if self.use_policy_loss and policy_reward is not None: 2129 | log_dict["policy_loss"] += dqn_loss / self.n_tokens_print 2130 | log_dict["policy_reward"] += policy_reward.mean() / self.n_tokens_print 2131 | 2132 | if not loss_list: 2133 | if loss is not None: 2134 | log_dict["loss_0"] += loss / self.n_tokens_print 2135 | else: 2136 | log_dict["loss_final"] += nonzero_mean(loss_list[-1]) / self.n_tokens_print 2137 | log_dict["loss_talk"] += sum(nonzero_mean(cur_loss_item) for cur_loss_item in loss_list[-self.n_ahead_talk:]) / self.n_ahead_talk / self.n_tokens_print 2138 | 2139 | # also log relative losses to loss_0 2140 | if loss_list: 2141 | for i in range(len(loss_list)): 2142 | talk_idx = min(max(i - (self.n_ahead - 1), 0), len(talk_loss_list) - 1) 2143 | if not talk_loss_list: 2144 | cur_talk_loss = nonzero_mean(loss_list[0]) 2145 | else: 2146 | cur_talk_loss = talk_loss_list[talk_idx] 2147 | log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print 2148 | if self.training: 2149 | self.training_steps += 1 2150 | try: 2151 | # if self.training_steps % (self.gradient_accumulation_steps * 256) == 0: 2152 | if self.wandb_enabled: 2153 | if self.training_steps % (self.n_tokens_print) == 0 or not self.training:# and "0" in str(loss.device): 2154 | if not self.training: 2155 | new_log_dict = {} 2156 | for key in list(log_dict.keys()): 2157 | new_log_dict["eval_" + key] = log_dict[key] 2158 | log_dict = new_log_dict 2159 | log_dict["training_steps"] = self.training_steps 2160 | log_dict["batch_size"] = batch_size 2161 | log_dict["example_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps 2162 | if self.n_ahead > 1: 2163 | log_dict["compute_steps"] = self.training_steps * batch_size * (self.n_ahead + self.n_ahead_talk - 1) * self.gradient_accumulation_steps 2164 | else: # There's no overhead for talk tokens if there's no thinking 2165 | log_dict["compute_steps"] = 
self.training_steps * batch_size * self.gradient_accumulation_steps 2166 | # remove all nans 2167 | for key in list(log_dict.keys()): 2168 | if log_dict[key] != log_dict[key]: 2169 | del log_dict[key] 2170 | if self.training: 2171 | wandb.log(log_dict) 2172 | if self.training: 2173 | self.log_dict = defaultdict(int) 2174 | else: 2175 | self.eval_log_dict = defaultdict(int) 2176 | except Exception as e: 2177 | pass 2178 | 2179 | if not self.training: 2180 | self.n_ahead_talk = n_ahead_talk_to_restore 2181 | self.n_passes = n_passes_to_restore 2182 | return CausalLMOutputWithPast( 2183 | loss=loss if loss is not None else None, 2184 | logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits, 2185 | past_key_values=outputs.past_key_values, 2186 | hidden_states=outputs.hidden_states, 2187 | attentions=outputs.attentions, 2188 | ) 2189 | 2190 | 2191 | def prepare_inputs_for_generation( 2192 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 2193 | ): 2194 | # Omit tokens covered by past_key_values 2195 | if past_key_values is not None: 2196 | if isinstance(past_key_values, Cache): 2197 | cache_length = past_key_values.get_seq_length() 2198 | past_length = past_key_values.seen_tokens 2199 | max_cache_length = past_key_values.get_max_length() 2200 | else: 2201 | cache_length = past_length = past_key_values[0][0].shape[2] 2202 | max_cache_length = None 2203 | 2204 | # Keep only the unprocessed tokens: 2205 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 2206 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as 2207 | # input) 2208 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 2209 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 2210 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 2211 | # input_ids based on the past_length. 2212 | elif past_length < input_ids.shape[1]: 2213 | input_ids = input_ids[:, past_length:] 2214 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 2215 | 2216 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
2217 | if ( 2218 | max_cache_length is not None 2219 | and attention_mask is not None 2220 | and cache_length + input_ids.shape[1] > max_cache_length 2221 | ): 2222 | attention_mask = attention_mask[:, -max_cache_length:] 2223 | 2224 | position_ids = kwargs.get("position_ids", None) 2225 | if attention_mask is not None and position_ids is None: 2226 | # create position_ids on the fly for batch generation 2227 | position_ids = attention_mask.long().cumsum(-1) - 1 2228 | position_ids.masked_fill_(attention_mask == 0, 1) 2229 | if past_key_values: 2230 | position_ids = position_ids[:, -input_ids.shape[1] :] 2231 | 2232 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 2233 | if inputs_embeds is not None and past_key_values is None: 2234 | model_inputs = {"inputs_embeds": inputs_embeds} 2235 | else: 2236 | model_inputs = {"input_ids": input_ids} 2237 | 2238 | model_inputs.update( 2239 | { 2240 | "position_ids": position_ids, 2241 | "past_key_values": past_key_values, 2242 | "use_cache": kwargs.get("use_cache"), 2243 | "attention_mask": attention_mask, 2244 | } 2245 | ) 2246 | return model_inputs 2247 | 2248 | @staticmethod 2249 | def _reorder_cache(past_key_values, beam_idx): 2250 | reordered_past = () 2251 | for layer_past in past_key_values: 2252 | reordered_past += ( 2253 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 2254 | ) 2255 | return reordered_past 2256 | 2257 | 2258 | @add_start_docstrings( 2259 | """ 2260 | The Mistral Model transformer with a sequence classification head on top (linear layer). 2261 | 2262 | [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models 2263 | (e.g. GPT-2) do. 2264 | 2265 | Since it does classification on the last token, it requires to know the position of the last token. If a 2266 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 2267 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 2268 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 2269 | each row of the batch). 
2270 | """, 2271 | MISTRAL_START_DOCSTRING, 2272 | ) 2273 | # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL 2274 | class MistralForSequenceClassification(MistralPreTrainedModel): 2275 | def __init__(self, config): 2276 | super().__init__(config) 2277 | self.num_labels = config.num_labels 2278 | self.model = MistralModel(config) 2279 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 2280 | 2281 | # Initialize weights and apply final processing 2282 | self.post_init() 2283 | 2284 | def get_input_embeddings(self): 2285 | return self.model.embed_tokens 2286 | 2287 | def set_input_embeddings(self, value): 2288 | self.model.embed_tokens = value 2289 | 2290 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 2291 | def forward( 2292 | self, 2293 | input_ids: torch.LongTensor = None, 2294 | attention_mask: Optional[torch.Tensor] = None, 2295 | position_ids: Optional[torch.LongTensor] = None, 2296 | past_key_values: Optional[List[torch.FloatTensor]] = None, 2297 | inputs_embeds: Optional[torch.FloatTensor] = None, 2298 | labels: Optional[torch.LongTensor] = None, 2299 | use_cache: Optional[bool] = None, 2300 | output_attentions: Optional[bool] = None, 2301 | output_hidden_states: Optional[bool] = None, 2302 | return_dict: Optional[bool] = None, 2303 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 2304 | r""" 2305 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 2306 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 2307 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 2308 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
2309 | """ 2310 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 2311 | 2312 | transformer_outputs = self.model( 2313 | input_ids, 2314 | attention_mask=attention_mask, 2315 | position_ids=position_ids, 2316 | past_key_values=past_key_values, 2317 | inputs_embeds=inputs_embeds, 2318 | use_cache=use_cache, 2319 | output_attentions=output_attentions, 2320 | output_hidden_states=output_hidden_states, 2321 | return_dict=return_dict, 2322 | ) 2323 | hidden_states = transformer_outputs[0] 2324 | logits = self.score(hidden_states) 2325 | 2326 | if input_ids is not None: 2327 | batch_size = input_ids.shape[0] 2328 | else: 2329 | batch_size = inputs_embeds.shape[0] 2330 | 2331 | if self.config.pad_token_id is None and batch_size != 1: 2332 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 2333 | if self.config.pad_token_id is None: 2334 | sequence_lengths = -1 2335 | else: 2336 | if input_ids is not None: 2337 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility 2338 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 2339 | sequence_lengths = sequence_lengths % input_ids.shape[-1] 2340 | sequence_lengths = sequence_lengths.to(logits.device) 2341 | else: 2342 | sequence_lengths = -1 2343 | 2344 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 2345 | 2346 | loss = None 2347 | if labels is not None: 2348 | labels = labels.to(logits.device) 2349 | if self.config.problem_type is None: 2350 | if self.num_labels == 1: 2351 | self.config.problem_type = "regression" 2352 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 2353 | self.config.problem_type = "single_label_classification" 2354 | else: 2355 | self.config.problem_type = "multi_label_classification" 2356 | 2357 | if self.config.problem_type == "regression": 2358 | loss_fct = MSELoss() 2359 | if self.num_labels == 1: 2360 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 2361 | else: 2362 | loss = loss_fct(pooled_logits, labels) 2363 | elif self.config.problem_type == "single_label_classification": 2364 | loss_fct = CrossEntropyLoss() 2365 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 2366 | elif self.config.problem_type == "multi_label_classification": 2367 | loss_fct = BCEWithLogitsLoss() 2368 | loss = loss_fct(pooled_logits, labels) 2369 | if not return_dict: 2370 | output = (pooled_logits,) + transformer_outputs[1:] 2371 | return ((loss,) + output) if loss is not None else output 2372 | 2373 | return SequenceClassifierOutputWithPast( 2374 | loss=loss, 2375 | logits=pooled_logits, 2376 | past_key_values=transformer_outputs.past_key_values, 2377 | hidden_states=transformer_outputs.hidden_states, 2378 | attentions=transformer_outputs.attentions, 2379 | ) 2380 | --------------------------------------------------------------------------------
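
A minimal standalone sketch (not part of the repository files above) checking the hand-written diagonal-Gaussian log-density used for the reparameterized start/end thought embeddings in the forward pass, where row 0 of the thought-embedding parameter stores the mean and row 1 the log standard deviation. Only the formula and the sampling rule are taken from the code; the toy tensors are illustrative.

```python
import math
import torch

hidden = 8
mean = torch.randn(hidden)          # plays the role of start_embedding[0]
log_std = torch.randn(hidden) * 0.1 # plays the role of start_embedding[1]

# reparameterized sample, as in the forward pass: eps * exp(log_std) + mean
sample = torch.randn(hidden) * torch.exp(log_std) + mean

# per-dimension log-density as written in the model, averaged over dimensions
manual = (-0.5 * (sample - mean) ** 2 / torch.exp(log_std) ** 2
          - log_std - 0.5 * math.log(2 * math.pi)).mean()

# reference implementation from torch.distributions
reference = torch.distributions.Normal(mean, torch.exp(log_std)).log_prob(sample).mean()

assert torch.allclose(manual, reference, atol=1e-5)
print(manual.item(), reference.item())
```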
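
Another standalone sketch, on toy random tensors, of the REINFORCE-style update in the forward pass above: the per-token reward is the drop in language-modeling loss attributable to the thought, a Trice-style baseline centers rewards across the `n_passes` rollouts of each example, and the actor loss pushes up the log-likelihood of thoughts that reduced the loss. Shapes and names (`policy_reward`, `action_loglikelihoods_2d`, `n_passes`) mirror the code; the data here is random, and the reward would normally come from comparing the base and post-thought losses.

```python
import torch

batch_size, n_passes, seq_len = 2, 2, 5

# reward: (loss before this thought) - (loss after this thought); detached in the model
policy_reward = torch.randn(batch_size * n_passes, seq_len)

# log-likelihood of each sampled thought token under the mixing ("rm") head
action_loglikelihoods_2d = torch.randn(batch_size * n_passes, seq_len, requires_grad=True)

# Trice-style baseline: subtract the mean reward across the passes of the same example
batched = policy_reward.reshape(-1, n_passes, seq_len)
centered = (batched - batched.mean(dim=1, keepdim=True)).reshape(-1, seq_len)

# REINFORCE objective: reinforce thoughts whose centered reward is positive
actor_loss = (-centered.detach() * action_loglikelihoods_2d).mean()
actor_loss.backward()
print(actor_loss.item())
```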
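
Finally, a toy illustration (hypothetical inputs) of the last-non-padding-token pooling described in the `MistralForSequenceClassification` docstring above: the index of the first pad token minus one, with a modulo so that fully unpadded rows wrap to the final position, matching the ONNX-friendly indexing expression in the code.

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],  # last real token at index 2
    [5, 6, 7, 8, 9],  # no padding, so last real token at index 4
])

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])
```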