├── .gitignore ├── .gitattributes ├── assets ├── memory_intro_4.pdf └── memory_intro_4.png ├── requirements.txt ├── grounded_qa ├── run_no_writing.sh ├── run_id_grounded.sh └── run_with_recitation.sh ├── full_recitation └── run.sh ├── odqa ├── run_mixed_training.sh ├── run_continual_training.sh └── run_clm_odqa.sh ├── selective_recitation └── run_selective.sh ├── README.md ├── run_clm.sh ├── trainer_gpt_qa.py └── run_clm.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /assets/memory_intro_4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/lm-random-memory-access/HEAD/assets/memory_intro_4.pdf -------------------------------------------------------------------------------- /assets/memory_intro_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/lm-random-memory-access/HEAD/assets/memory_intro_4.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | evaluate==0.4.0 2 | torch==2.1.0 3 | transformers==4.34.0 4 | numpy==1.22.4 5 | datasets==2.14.5 6 | wandb==0.15.12 7 | accelerate==0.27.2 8 | -------------------------------------------------------------------------------- /grounded_qa/run_no_writing.sh: -------------------------------------------------------------------------------- 1 | model_name=$1 2 | # export WANDB_PROJECT= replace with your own wandb project name 3 | for version in baseline context ; do 4 | export WANDB_TAGS="clm,$model_name,squad_content, v5_full, $version " 5 | bash run_clm.sh tyzhu/squad_qa_$version\_v5_full $model_name 6 | done 7 | 8 | -------------------------------------------------------------------------------- /grounded_qa/run_id_grounded.sh: -------------------------------------------------------------------------------- 1 | model_name=$1 2 | # export WANDB_PROJECT= replace with your own wandb project name 3 | for version in title wrong_title rare wrong_rare num wrong_num no_id ; do 4 | export WANDB_TAGS="clm,$model_name,squad_content,v5_full,id_type_$version " 5 | bash run_clm.sh tyzhu/squad_qa_$version\_v5_full $model_name 6 | done -------------------------------------------------------------------------------- /grounded_qa/run_with_recitation.sh: -------------------------------------------------------------------------------- 1 | model_name=$1 2 | # export WANDB_PROJECT= replace with your own project 3 | recite_method=recite_full_passage 4 | for version in title wrong_title rare wrong_rare num wrong_num no_id; do 5 | export WANDB_TAGS="clm,$model_name,squad_content , v5_full, id_type_$version , $recite_method" 6 | bash run_clm.sh tyzhu/squad_qa_$version\_v5_full_$recite_method $model_name 7 | done -------------------------------------------------------------------------------- /full_recitation/run.sh: -------------------------------------------------------------------------------- 1 | model_name=$1 # our experiments use the gpt2-large model 2 | 3 | # export WANDB_PROJECT= # replace 
with your wandb project name 4 | train_num=400 5 | eval_num=40 6 | for id_type in title rare num ; do 7 | for content_type in wiki random_letter_same_length; do 8 | export WANDB_TAGS="clm,$model_name,squad_content, num_train_$train_num, num_eval_$eval_num, id_$id_type " 9 | bash run_clm.sh tyzhu/$content_type\_find_passage_train$train_num\_eval$eval_num\_$id_type $model_name 10 | done 11 | done -------------------------------------------------------------------------------- /odqa/run_mixed_training.sh: -------------------------------------------------------------------------------- 1 | # hotpot QA, closed book with training on passages 2 | bash odqa/run_clm_odqa.sh tyzhu/lmind_hotpot_train8000_eval7405_v1_doc_qa gpt2-xl 3 | # hotpot QA, closed book with training on passages & passage recitation 4 | bash odqa/run_clm_odqa.sh tyzhu/lmind_hotpot_train8000_eval7405_v1_recite_qa gpt2-xl 5 | 6 | # nq QA, closed book with training on passages 7 | bash odqa/run_clm_odqa.sh tyzhu/lmind_nq_train6000_eval6489_v1_doc_qa gpt2-xl 8 | # nq QA, closed book with training on passages & passage recitation 9 | bash odqa/run_clm_odqa.sh tyzhu/lmind_nq_train6000_eval6489_v1_recite_qa gpt2-xl 10 | 11 | -------------------------------------------------------------------------------- /odqa/run_continual_training.sh: -------------------------------------------------------------------------------- 1 | model_name=$1 2 | for ds_name in tyzhu/lmind_nq_train6000_eval6489 tyzhu/lmind_hotpot_train8000_eval7405 ; do # alternatively you can choose to run only one dataset 3 | # Step 1: training the model on the passages only 4 | bash odqa/run_clm_odqa.sh $ds_name\_v1_docidx $model_name 5 | # NOTE: in the two commands below, you need to replace the second argument with the actual model saved in Step 1, either a local path or a model on the huggingface hub 6 | bash odqa/run_clm_odqa.sh $ds_name\_v1_qa $ds_name\_v1_docidx_$model_name # Step 2: training the model on the QA pairs 7 | bash odqa/run_clm_odqa.sh $ds_name\_v1_reciteonly_qa $ds_name\_v1_docidx_$model_name # Step 2 (alternative): training the model on the QA pairs with passage recitation 8 | done 9 | 10 | 11 | -------------------------------------------------------------------------------- /selective_recitation/run_selective.sh: -------------------------------------------------------------------------------- 1 | model_name=$1 2 | 3 | train_num=400 4 | method=marker_both 5 | # export WANDB_TAGS="clm,$model_name,squad_content, num_train_$train_num, num_eval_40, find_$method " 6 | bash run_clm.sh tyzhu/find_$method\_sent_train_$train_num\_eval_40 $model_name 7 | 8 | # with passage recitation 9 | bash run_clm.sh tyzhu/find_$method\_sent_train_$train_num\_eval_40_recite $model_name 10 | 11 | # with random permutation of the contexts 12 | # first_permute: bringing each sentence to the start of a passage 13 | permute_method=first_permute 14 | bash run_clm.sh tyzhu/find_$method\_sent_train_$train_num\_eval_40_$permute_method "$model_name" 15 | 16 | # randomly permute the sentences k times: 17 | permute_method=random_permute 18 | for k in 1 2 4 8; do 19 | bash run_clm.sh tyzhu/find_$method\_sent_train_$train_num\_eval_40_$permute_method\_rerun_$k "$model_name" 20 | done 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Beyond Memorization: The Challenge of Random Memory Access in Language Models 2 | 3 | This repo contains the code for 
reproducing the experiments in our paper, Beyond Memorization: The Challenge of Random Memory Access in Language Models. 4 | 5 | In our study, we reveal that language models (GPT2) are able to sequentially access their parametric memory, but encounter challenges in randomly accessing memorized content. 6 | 7 | ![Illustration of our tasks for evaluating memory access](assets/memory_intro_4.png) 8 | 9 | The central idea is that the model can memorize any content, but cannot access that memory in a random manner. We verify that the limited random access 10 | ability has implications for real open-domain question answering: the model may fail to answer a question simply because it cannot access an answer stored in the middle of a memorized passage. 11 | 12 | ## Requirements 13 | Please create an environment using `pip` and the `requirements.txt` file. 14 | 15 | ```pip install -r requirements.txt ``` 16 | 17 | ## Data 18 | All the data for the experiments are hosted on the Huggingface hub. You can use them directly without downloading them manually. 19 | 20 | ## Experiments 21 | The experiments are divided into four parts: 22 | 1. Full recitation: Asking the model to recite the full passage given a passage ID 23 | 2. Selective recitation: Asking the model to recite a sentence from the passage given a passage ID 24 | 3. Grounded QA: Given an ID and a question, asking the model to answer the question. 25 | 4. Open-domain QA: Given a question, asking the model to answer the question. The model may be trained on the passages. 26 | 27 | ## Running the experiments 28 | The scripts for each of the experiments can be found in their respective folders. 29 | For instance, if you wish to run the full recitation experiments on gpt2-large, you should be in the project root folder and run: 30 | 31 | ``bash full_recitation/run.sh gpt2-large`` 32 | 33 | [//]: # (To run the experiments, you can simply use:) 34 | 35 | [//]: # (```bash run_experiments.sh ``` where the experiment name is in `full_recite`, `selective_recite`, `grounded_qa`, `open_domain_qa`) 36 | -------------------------------------------------------------------------------- /run_clm.sh: -------------------------------------------------------------------------------- 1 | export WANDB_API_KEY=X # replace with your own wandb api key to view the results on wandb 2 | # export WANDB_PROJECT=testing_challenge_random 3 | DATASET_NAME=$1 4 | MODEL_NAME=$2 5 | REPLACE_MODEL_NAME=${MODEL_NAME//\//_} 6 | if [[ -z "$3" ]]; then 7 | echo "Third argument is empty, using default learning rate of 3e-5" 8 | LEARNING_RATE=3e-5 9 | export WANDB_RUN_NAME=$DATASET_NAME\_$REPLACE_MODEL_NAME 10 | else 11 | LEARNING_RATE=$3 12 | echo "Using learning rate $LEARNING_RATE" 13 | export WANDB_RUN_NAME=$DATASET_NAME\_$REPLACE_MODEL_NAME\_$LEARNING_RATE # add learning rate to run name 14 | fi 15 | 16 | if [ ${#WANDB_RUN_NAME} -gt 96 ]; then 17 | export WANDB_RUN_NAME=${WANDB_RUN_NAME: -96} 18 | fi 19 | 20 | # whether train on inputs, default is true 21 | if [[ -z "$4" ]]; then 22 | echo "Fourth argument is empty, using default train_on_inputs of true" 23 | TRAIN_ON_INPUTS=true 24 | else 25 | echo "Using train_on_inputs $4" 26 | TRAIN_ON_INPUTS=$4 27 | export WANDB_RUN_NAME=$WANDB_RUN_NAME\_train_on_inputs_$4 # add train_on_inputs to run name 28 | fi 29 | 30 | # whether to train from scratch, default is false 31 | FROM_SCRATCH=$5 32 | if [[ $FROM_SCRATCH == true ]]; then 33 | export WANDB_RUN_NAME=$WANDB_RUN_NAME\_from_scratch 34 | echo 'Training from scratch' 35 | FROM_SCRATCH=true 36 | 
else 37 | FROM_SCRATCH=false 38 | fi 39 | 40 | echo "WANDB_RUN_NAME $WANDB_RUN_NAME" 41 | export GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i 0) 42 | GRAD_ACCUM=1 43 | if [[ $MODEL_NAME == *xl* ]]; then 44 | # check if memory is 81920 45 | if [[ $GPU_MEMORY == 81920 ]]; then 46 | BATCH_SIZE=2 47 | else 48 | BATCH_SIZE=1 49 | GRAD_ACCUM=2 50 | fi 51 | elif [[ $MODEL_NAME == *small* ]]; then 52 | BATCH_SIZE=16 53 | elif [[ $MODEL_NAME == *large* ]]; then 54 | BATCH_SIZE=4 55 | elif [[ $MODEL_NAME == *7b* ]]; then 56 | BATCH_SIZE=1 57 | else 58 | BATCH_SIZE=4 59 | fi 60 | WORKING_DIR=$HOME/run_data # replace with your working directory e.g. where you want to save the models and predictions 61 | 62 | export SAVE_DIR=$WORKING_DIR/$WANDB_RUN_NAME 63 | export SAVE_PRED_DIR=$WORKING_DIR/saved_pred/$DATASET_NAME\_$MODEL_NAME # save prediction to this directory 64 | mkdir -p $SAVE_PRED_DIR 65 | echo "SAVE_DIR $SAVE_DIR" 66 | echo "SAVE_PRED_DIR $SAVE_PRED_DIR" 67 | 68 | python run_clm.py --model_name_or_path $MODEL_NAME \ 69 | --dataset_name $DATASET_NAME \ 70 | --per_device_train_batch_size $BATCH_SIZE \ 71 | --per_device_eval_batch_size 2 \ 72 | --do_train --do_eval \ 73 | --report_to wandb \ 74 | --output_dir $SAVE_DIR \ 75 | --overwrite_output_dir true \ 76 | --learning_rate $LEARNING_RATE \ 77 | --save_strategy "epoch" \ 78 | --save_total_limit 1 \ 79 | --num_train_epochs 100 \ 80 | --logging_steps 0.01 \ 81 | --warmup_ratio 0.05 \ 82 | --evaluation_strategy epoch \ 83 | --train_on_inputs $TRAIN_ON_INPUTS 84 | -------------------------------------------------------------------------------- /odqa/run_clm_odqa.sh: -------------------------------------------------------------------------------- 1 | export WANDB_API_KEY=X # replace with your own wandb api key 2 | DATASET_NAME=$1 3 | MODEL_NAME=$2 4 | REPLACE_MODEL_NAME=${MODEL_NAME//\//_} 5 | if [[ -z "$3" ]]; then 6 | echo "Third argument is empty, using default learning rate of 3e-5" 7 | LEARNING_RATE=3e-5 8 | export WANDB_RUN_NAME=$DATASET_NAME\_$REPLACE_MODEL_NAME 9 | else 10 | LEARNING_RATE=$3 11 | echo "Using learning rate $LEARNING_RATE" 12 | export WANDB_RUN_NAME=$DATASET_NAME\_$REPLACE_MODEL_NAME\_$LEARNING_RATE # add learning rate to run name 13 | fi 14 | 15 | if [ ${#WANDB_RUN_NAME} -gt 96 ]; then 16 | export WANDB_RUN_NAME=${WANDB_RUN_NAME: -96} 17 | fi 18 | 19 | # whether train on inputs, default is true 20 | if [[ -z "$4" ]]; then 21 | echo "Fourth argument is empty, using default train_on_inputs of true" 22 | TRAIN_ON_INPUTS=true 23 | else 24 | echo "Using train_on_inputs $4" 25 | TRAIN_ON_INPUTS=$4 26 | export WANDB_RUN_NAME=$WANDB_RUN_NAME\_train_on_inputs_$4 # add train_on_inputs to run name 27 | fi 28 | 29 | # whether to train from scratch, default is false 30 | FROM_SCRATCH=$5 31 | if [[ $FROM_SCRATCH == true ]]; then 32 | export WANDB_RUN_NAME=$WANDB_RUN_NAME\_from_scratch 33 | echo 'Training from scratch' 34 | FROM_SCRATCH=true 35 | else 36 | FROM_SCRATCH=false 37 | fi 38 | 39 | echo "WANDB_RUN_NAME $WANDB_RUN_NAME" 40 | export GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i 0) 41 | GRAD_ACCUM=1 42 | if [[ $MODEL_NAME == *xl* ]]; then 43 | # check if memory is 81920 44 | if [[ $GPU_MEMORY == 81920 ]]; then 45 | BATCH_SIZE=2 46 | else 47 | BATCH_SIZE=1 48 | GRAD_ACCUM=2 49 | fi 50 | elif [[ $MODEL_NAME == *small* ]]; then 51 | BATCH_SIZE=16 52 | elif [[ $MODEL_NAME == *large* ]]; then 53 | BATCH_SIZE=4 54 | elif [[ $MODEL_NAME == *7b* ]]; then 55 | BATCH_SIZE=1 56 | else 57 | 
BATCH_SIZE=4 58 | fi 59 | WORKING_DIR=X # replace with your working directory e.g. where you want to save the models and predictions 60 | 61 | export SAVE_DIR=$WORKING_DIR/$WANDB_RUN_NAME 62 | export SAVE_PRED_DIR=$WORKING_DIR/saved_pred/$DATASET_NAME\_$MODEL_NAME # save prediction to this directory 63 | mkdir -p $SAVE_PRED_DIR 64 | echo "SAVE_DIR $SAVE_DIR" 65 | echo "SAVE_PRED_DIR $SAVE_PRED_DIR" 66 | 67 | python run_clm.py \ 68 | --model_name_or_path $MODEL_NAME \ 69 | --dataset_name "$DATASET_NAME" \ 70 | --per_device_train_batch_size $BATCH_SIZE \ 71 | --per_device_eval_batch_size 2 \ 72 | --do_train \ 73 | --do_eval \ 74 | --report_to wandb \ 75 | --output_dir $SAVE_DIR \ 76 | --resume_from_checkpoint true \ 77 | --learning_rate "$LEARNING_RATE" \ 78 | --save_strategy "epoch" \ 79 | --num_train_epochs 20 \ 80 | --logging_steps 0.01 \ 81 | --save_total_limit 2 \ 82 | --eval_steps 0.05 \ 83 | --evaluation_strategy "steps" \ 84 | --train_on_inputs $TRAIN_ON_INPUTS \ 85 | --push_to_hub true \ 86 | --from_scratch $FROM_SCRATCH \ 87 | --lr_scheduler_type "constant" \ 88 | --hub_strategy "all_checkpoints" \ 89 | --gradient_accumulation_steps $GRAD_ACCUM 90 | -------------------------------------------------------------------------------- /trainer_gpt_qa.py: -------------------------------------------------------------------------------- 1 | """ 2 | A subclass of `Trainer` specific to Question-Answering tasks 3 | """ 4 | import json 5 | 6 | import math 7 | import time 8 | from typing import Dict, List, Optional 9 | import evaluate 10 | 11 | from transformers import Trainer, is_torch_tpu_available 12 | from transformers.trainer_utils import PredictionOutput, speed_metrics 13 | from transformers import pipeline 14 | from transformers.pipelines.pt_utils import KeyDataset 15 | from tqdm.auto import tqdm 16 | import json 17 | import os 18 | 19 | def remove_prefix(text, prefix): 20 | if text.startswith(prefix): 21 | return text[len(prefix):] 22 | return text # or whatever 23 | 24 | def prefix_match_em_score(references, predictions): 25 | """ 26 | Compute the prefix match between references and predictions 27 | :param references: a list of strings 28 | :param predictions: a list of strings 29 | :return: the prefix match score 30 | """ 31 | prefix_match_score = 0 32 | for ref, pred in zip(references, predictions): 33 | ref = ref.strip() 34 | if ref.endswith("END"): 35 | ref = ref[:-3] # remove the END token 36 | pred = pred.strip() 37 | if pred.startswith(ref): 38 | prefix_match_score += 1 39 | return prefix_match_score / len(references) 40 | 41 | class QuestionAnsweringLMTrainer(Trainer): 42 | def __init__(self, *args, eval_examples=None, post_process_function=None, is_squad = False, do_recite=False, **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.eval_examples = eval_examples 45 | self.post_process_function = post_process_function 46 | pipeline_name = "text-generation" 47 | self.do_recite = do_recite 48 | self.is_qa = is_squad 49 | if self.do_recite: 50 | max_new_tokens = 384 51 | elif self.is_qa: 52 | max_new_tokens = 128 53 | else: 54 | max_new_tokens= 256 55 | print("Max new tokens", max_new_tokens) 56 | self.pipe = pipeline(pipeline_name, model=self.model, tokenizer=self.tokenizer, device = "cuda", max_new_tokens = max_new_tokens, return_full_text = False, 57 | do_sample=False,pad_token_id = self.tokenizer.pad_token_id) 58 | self.eval_examples = eval_examples 59 | self.save_times = 0 60 | def extract_answer_from_text(self, text, splitter): 61 | if not splitter in text or 
len(text.split(splitter))>2: 62 | print("Wrongly formatted text {}".format(text)) 63 | return "", "" 64 | recite, answer = text.split(splitter) 65 | recite = recite.strip() 66 | answer = answer.strip() 67 | return recite, answer 68 | 69 | def extract_recitations_and_answers_from_texts(self, text_lst, splitter = "Answer:"): 70 | """ 71 | Extract the recitations and answers from a list of texts 72 | :param splitter: the splitting token between recitation and the answer 73 | :param text_lst: a list of texts, each with "{recitation} {splitter} {answer}" 74 | :return: a list of recitations and answers 75 | """ 76 | recites = [] 77 | answers = [] 78 | for text in text_lst: 79 | recite, answer = self.extract_answer_from_text(text, splitter=splitter) 80 | recites.append(recite) 81 | answers.append(answer) 82 | return recites, answers 83 | 84 | def evaluate(self, eval_dataset=None, ignore_keys: Optional[List[str]] = None, metric_key_prefix="eval"): 85 | prior_metrics = super().evaluate(eval_dataset=eval_dataset,ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 86 | predictions = [] 87 | for out in tqdm(self.pipe(KeyDataset(self.eval_examples, "inputs"), batch_size=16), total=len(self.eval_examples), desc = 'running evaluation'): 88 | predictions.append(out) 89 | predicted_text = [x[0]['generated_text'] for x in predictions] 90 | 91 | if self.is_qa: 92 | 93 | # exact_match_metric = evaluate.load("exact_match") 94 | # references = [self.extract_answer_from_text(t) for t in self.eval_examples['targets']] 95 | # metrics = exact_match_metric.compute(predictions = predicted_text, references = references) 96 | # print("References", references[:10], "Predicted answers", predicted_text[:10]) 97 | bleu = evaluate.load("sacrebleu") 98 | references = self.eval_examples['targets'] 99 | pred_result = {"raw": list(zip(references, predicted_text))} 100 | metrics = {} 101 | if self.do_recite: 102 | prediction_recitations, prediction_answers = self.extract_recitations_and_answers_from_texts( 103 | predicted_text) 104 | gt_recitations, gt_answers = self.extract_recitations_and_answers_from_texts(self.eval_examples['targets']) 105 | recite_bleu = bleu.compute(predictions=prediction_recitations, references=gt_recitations) 106 | answer_bleu = bleu.compute(predictions=prediction_answers, references=gt_answers) 107 | pred_result['recite'] = list(zip(gt_recitations, prediction_recitations)) 108 | pred_result['qa'] = list(zip(gt_answers, prediction_answers)) 109 | exact_match = evaluate.load("exact_match") 110 | metrics["recite_bleu"] = recite_bleu['score'] 111 | metrics["qa_bleu"] = answer_bleu['score'] 112 | metrics['recite_exact_match'] = \ 113 | exact_match.compute(predictions=prediction_recitations, references=gt_recitations)['exact_match'] 114 | metrics['qa_exact_match'] = exact_match.compute(predictions=prediction_answers, references=gt_answers)[ 115 | 'exact_match'] 116 | else: 117 | prediction_answers = predicted_text 118 | squad_metric = evaluate.load("squad") 119 | references = [{'id': str(i), 'answers': row['answers']} for i,row in enumerate(self.eval_examples)] 120 | predictions = [{'id': str(i), 'prediction_text': x} for i,x in enumerate(prediction_answers)] 121 | print("References", references[-5:], "Predicted answers", predicted_text[-5:]) 122 | metrics.update(squad_metric.compute(predictions=predictions,references=references)) 123 | 124 | else: 125 | bleu = evaluate.load("sacrebleu") 126 | references = self.eval_examples['targets'] 127 | pred_result = {"raw": list(zip(references, 
predicted_text))} 128 | exact_match = evaluate.load("exact_match") 129 | print("References v.s. predicted", list(zip(references, predicted_text))[:5]) 130 | if len(predicted_text[0])==0: 131 | metrics = {} 132 | else: 133 | bleu_results = bleu.compute(predictions = predicted_text, references = references) 134 | metrics = {"bleu": bleu_results["score"]} 135 | metrics.update(exact_match.compute(predictions=[x.strip() for x in predicted_text], references=[x.strip() for x in references])) 136 | metrics['prefix_exact_match'] = prefix_match_em_score(references, predicted_text) 137 | if self.do_recite: 138 | prediction_recitations, prediction_answers = self.extract_recitations_and_answers_from_texts(predicted_text) 139 | gt_recitations, gt_answers = self.extract_recitations_and_answers_from_texts(references) 140 | recite_bleu = bleu.compute(predictions = prediction_recitations, references = gt_recitations) 141 | answer_bleu = bleu.compute(predictions = prediction_answers, references = gt_answers) 142 | pred_result['recite'] = list(zip(gt_recitations, prediction_recitations)) 143 | pred_result['qa'] = list(zip(gt_answers, prediction_answers)) 144 | metrics["recite_bleu"] = recite_bleu['score'] 145 | metrics["qa_bleu"] = answer_bleu['score'] 146 | metrics['recite_exact_match'] = exact_match.compute(predictions = prediction_recitations, references = gt_recitations)['exact_match'] 147 | metrics['qa_exact_match'] = exact_match.compute(predictions = prediction_answers, references = gt_answers)['exact_match'] 148 | # save the generation outcome to disk 149 | SAVE_PRED_DIR = os.getenv("SAVE_PRED_DIR") 150 | if SAVE_PRED_DIR is not None: 151 | suffix = self.save_times 152 | filename = os.path.join(SAVE_PRED_DIR, f"predictions_{metric_key_prefix}_{suffix}.json") 153 | print("Saved prediction result to", filename) 154 | self.save_times += 1 155 | json.dump(pred_result, open(filename, "w")) 156 | for key in list(metrics.keys()): 157 | if not key.startswith(f"{metric_key_prefix}_"): 158 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 159 | print(metrics) 160 | if self.args.should_log: 161 | # Only the main node log the results by default 162 | self.log(metrics) 163 | 164 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) 165 | prior_metrics.update(metrics) 166 | return prior_metrics 167 | 168 | 169 | def predict( 170 | self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test", **gen_kwargs 171 | ): 172 | self._gen_kwargs = gen_kwargs.copy() 173 | 174 | predict_dataloader = self.get_test_dataloader(predict_dataset) 175 | 176 | # Temporarily disable metric computation, we will do it in the loop here. 
177 | compute_metrics = self.compute_metrics 178 | self.compute_metrics = None 179 | start_time = time.time() 180 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 181 | try: 182 | output = eval_loop( 183 | predict_dataloader, 184 | description="Prediction", 185 | # No point gathering the predictions if there are no metrics, otherwise we defer to 186 | # self.args.prediction_loss_only 187 | prediction_loss_only=True if compute_metrics is None else None, 188 | ignore_keys=ignore_keys, 189 | metric_key_prefix=metric_key_prefix, 190 | ) 191 | finally: 192 | self.compute_metrics = compute_metrics 193 | 194 | total_batch_size = self.args.eval_batch_size * self.args.world_size 195 | if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: 196 | start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] 197 | output.metrics.update( 198 | speed_metrics( 199 | metric_key_prefix, 200 | start_time, 201 | num_samples=output.num_samples, 202 | num_steps=math.ceil(output.num_samples / total_batch_size), 203 | ) 204 | ) 205 | if self.post_process_function is None or self.compute_metrics is None: 206 | return output 207 | 208 | predictions = self.post_process_function(predict_examples, predict_dataset, output, "predict") 209 | metrics = self.compute_metrics(predictions) 210 | 211 | # Prefix all keys with metric_key_prefix + '_' 212 | for key in list(metrics.keys()): 213 | if not key.startswith(f"{metric_key_prefix}_"): 214 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 215 | metrics.update(output.metrics) 216 | return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics) -------------------------------------------------------------------------------- /run_clm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | This file is taken from https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py 18 | 19 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. 20 | 21 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 22 | https://huggingface.co/models?filter=text-generation 23 | """ 24 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
25 | import wandb 26 | import logging 27 | import math 28 | import os 29 | import sys 30 | import warnings 31 | from dataclasses import dataclass, field 32 | from typing import Optional, List, Tuple 33 | import datasets 34 | import evaluate 35 | import torch 36 | from datasets import load_dataset 37 | 38 | import transformers 39 | from transformers import ( 40 | CONFIG_MAPPING, 41 | MODEL_FOR_CAUSAL_LM_MAPPING, 42 | AutoConfig, 43 | AutoModelForCausalLM, 44 | AutoTokenizer, 45 | HfArgumentParser, 46 | Trainer, 47 | TrainingArguments, 48 | default_data_collator, 49 | is_torch_tpu_available, 50 | set_seed, 51 | DataCollatorForLanguageModeling, 52 | DefaultDataCollator 53 | ) 54 | from transformers.testing_utils import CaptureLogger 55 | from transformers.trainer_utils import get_last_checkpoint 56 | from transformers.utils import check_min_version, send_example_telemetry 57 | from transformers.utils.versions import require_version 58 | import tqdm 59 | from trainer_gpt_qa import QuestionAnsweringLMTrainer 60 | # 61 | 62 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 63 | check_min_version("4.34.0.dev0") 64 | 65 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") 66 | 67 | logger = logging.getLogger(__name__) 68 | 69 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 70 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 71 | 72 | 73 | @dataclass 74 | class ModelArguments: 75 | """ 76 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 77 | """ 78 | 79 | model_name_or_path: Optional[str] = field( 80 | default=None, 81 | metadata={ 82 | "help": ( 83 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 84 | ) 85 | }, 86 | ) 87 | model_type: Optional[str] = field( 88 | default=None, 89 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 90 | ) 91 | config_overrides: Optional[str] = field( 92 | default=None, 93 | metadata={ 94 | "help": ( 95 | "Override some existing default config settings when a model is trained from scratch. Example: " 96 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 97 | ) 98 | }, 99 | ) 100 | config_name: Optional[str] = field( 101 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 102 | ) 103 | tokenizer_name: Optional[str] = field( 104 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 105 | ) 106 | cache_dir: Optional[str] = field( 107 | default=None, 108 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 109 | ) 110 | use_fast_tokenizer: bool = field( 111 | default=True, 112 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 113 | ) 114 | model_revision: str = field( 115 | default="main", 116 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 117 | ) 118 | token: str = field( 119 | default=None, 120 | metadata={ 121 | "help": ( 122 | "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " 123 | "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." 
124 | ) 125 | }, 126 | ) 127 | use_auth_token: bool = field( 128 | default=None, 129 | metadata={ 130 | "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`." 131 | }, 132 | ) 133 | trust_remote_code: bool = field( 134 | default=False, 135 | metadata={ 136 | "help": ( 137 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" 138 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will" 139 | "execute code present on the Hub on your local machine." 140 | ) 141 | }, 142 | ) 143 | torch_dtype: Optional[str] = field( 144 | default=None, 145 | metadata={ 146 | "help": ( 147 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 148 | "dtype will be automatically derived from the model's weights." 149 | ), 150 | "choices": ["auto", "bfloat16", "float16", "float32"], 151 | }, 152 | ) 153 | low_cpu_mem_usage: bool = field( 154 | default=False, 155 | metadata={ 156 | "help": ( 157 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." 158 | "set True will benefit LLM loading time and RAM consumption." 159 | ) 160 | }, 161 | ) 162 | train_on_inputs: bool = field(default=True, metadata={"help": "Whether to train on inputs or inputs + targets"}) 163 | add_eos_token: bool = field(default=False) 164 | from_scratch: bool = field(default=False, metadata = {"help": "whether to train from scratch"}) 165 | 166 | def __post_init__(self): 167 | if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): 168 | raise ValueError( 169 | "--config_overrides can't be used in combination with --config_name or --model_name_or_path" 170 | ) 171 | 172 | 173 | @dataclass 174 | class DataTrainingArguments: 175 | """ 176 | Arguments pertaining to what data we are going to input our model for training and eval. 177 | """ 178 | 179 | dataset_name: Optional[str] = field( 180 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 181 | ) 182 | dataset_config_name: Optional[str] = field( 183 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 184 | ) 185 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 186 | validation_file: Optional[str] = field( 187 | default=None, 188 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 189 | ) 190 | max_train_samples: Optional[int] = field( 191 | default=None, 192 | metadata={ 193 | "help": ( 194 | "For debugging purposes or quicker training, truncate the number of training examples to this " 195 | "value if set." 196 | ) 197 | }, 198 | ) 199 | max_eval_samples: Optional[int] = field( 200 | default=None, 201 | metadata={ 202 | "help": ( 203 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 204 | "value if set." 205 | ) 206 | }, 207 | ) 208 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) 209 | block_size: Optional[int] = field( 210 | default=None, 211 | metadata={ 212 | "help": ( 213 | "Optional input sequence length after tokenization. " 214 | "The training dataset will be truncated in block of this size for training. 
" 215 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 216 | ) 217 | }, 218 | ) 219 | overwrite_cache: bool = field( 220 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 221 | ) 222 | validation_split_percentage: Optional[int] = field( 223 | default=5, 224 | metadata={ 225 | "help": "The percentage of the train set used as validation set in case there's no validation split" 226 | }, 227 | ) 228 | preprocessing_num_workers: Optional[int] = field( 229 | default=None, 230 | metadata={"help": "The number of processes to use for the preprocessing."}, 231 | ) 232 | keep_linebreaks: bool = field( 233 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} 234 | ) 235 | 236 | def __post_init__(self): 237 | if self.streaming: 238 | require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") 239 | 240 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 241 | raise ValueError("Need either a dataset name or a training/validation file.") 242 | else: 243 | if self.train_file is not None: 244 | extension = self.train_file.split(".")[-1] 245 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." 246 | if self.validation_file is not None: 247 | extension = self.validation_file.split(".")[-1] 248 | assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." 249 | 250 | 251 | def main(): 252 | # See all possible arguments in src/transformers/training_args.py 253 | # or by passing the --help flag to this script. 254 | # We now keep distinct sets of args, for a cleaner separation of concerns. 255 | 256 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 257 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 258 | # If we pass only one argument to the script and it's the path to a json file, 259 | # let's parse it to get our arguments. 260 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 261 | else: 262 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 263 | 264 | if model_args.use_auth_token is not None: 265 | warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning) 266 | if model_args.token is not None: 267 | raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") 268 | model_args.token = model_args.use_auth_token 269 | 270 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 271 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 272 | send_example_telemetry("run_clm", model_args, data_args) 273 | 274 | # Setup logging 275 | logging.basicConfig( 276 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 277 | datefmt="%m/%d/%Y %H:%M:%S", 278 | handlers=[logging.StreamHandler(sys.stdout)], 279 | ) 280 | 281 | if training_args.should_log: 282 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
283 | transformers.utils.logging.set_verbosity_info() 284 | 285 | log_level = training_args.get_process_log_level() 286 | logger.setLevel(log_level) 287 | datasets.utils.logging.set_verbosity(log_level) 288 | transformers.utils.logging.set_verbosity(log_level) 289 | transformers.utils.logging.enable_default_handler() 290 | transformers.utils.logging.enable_explicit_format() 291 | 292 | # Log on each process the small summary: 293 | logger.warning( 294 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 295 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" 296 | ) 297 | logger.info(f"Training/evaluation parameters {training_args}") 298 | 299 | # Detecting last checkpoint. 300 | last_checkpoint = None 301 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 302 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 303 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 304 | raise ValueError( 305 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 306 | "Use --overwrite_output_dir to overcome." 307 | ) 308 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 309 | logger.info( 310 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 311 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 312 | ) 313 | 314 | # Set seed before initializing model. 315 | set_seed(training_args.seed) 316 | 317 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 318 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 319 | # (the dataset will be downloaded automatically from the datasets Hub). 320 | # 321 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 322 | # 'text' is found. You can easily tweak this behavior (see below). 323 | # 324 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 325 | # download the dataset. 326 | if data_args.dataset_name is not None: 327 | # Downloading and loading a dataset from the hub. 328 | raw_datasets = load_dataset( 329 | data_args.dataset_name, 330 | data_args.dataset_config_name, 331 | cache_dir=model_args.cache_dir, 332 | token=model_args.token, 333 | streaming=data_args.streaming, 334 | ) 335 | else: 336 | raise ValueError("The name of the dataset is not specified") 337 | 338 | # Load pretrained model and tokenizer 339 | # 340 | # Distributed training: 341 | # The .from_pretrained methods guarantee that only one local process can concurrently 342 | # download model & vocab. 
343 | 344 | config_kwargs = { 345 | "cache_dir": model_args.cache_dir, 346 | "revision": model_args.model_revision, 347 | "token": model_args.token, 348 | "trust_remote_code": model_args.trust_remote_code, 349 | } 350 | if model_args.config_name: 351 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 352 | elif model_args.model_name_or_path: 353 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 354 | else: 355 | config = CONFIG_MAPPING[model_args.model_type]() 356 | logger.warning("You are instantiating a new config instance from scratch.") 357 | if model_args.config_overrides is not None: 358 | logger.info(f"Overriding config: {model_args.config_overrides}") 359 | config.update_from_string(model_args.config_overrides) 360 | logger.info(f"New config: {config}") 361 | 362 | tokenizer_kwargs = { 363 | "cache_dir": model_args.cache_dir, 364 | "use_fast": model_args.use_fast_tokenizer, 365 | "revision": model_args.model_revision, 366 | "token": model_args.token, 367 | "trust_remote_code": model_args.trust_remote_code, 368 | } 369 | if model_args.tokenizer_name: 370 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 371 | tokenizer.pad_token = tokenizer.eos_token 372 | tokenizer.padding_side = "left" 373 | elif model_args.model_name_or_path: 374 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) 375 | tokenizer.pad_token = tokenizer.eos_token 376 | tokenizer.padding_side = "left" 377 | else: 378 | raise ValueError( 379 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 380 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 381 | ) 382 | 383 | if model_args.model_name_or_path and not model_args.from_scratch: 384 | torch_dtype = ( 385 | model_args.torch_dtype 386 | if model_args.torch_dtype in ["auto", None] 387 | else getattr(torch, model_args.torch_dtype) 388 | ) 389 | model = AutoModelForCausalLM.from_pretrained( 390 | model_args.model_name_or_path, 391 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 392 | config=config, 393 | cache_dir=model_args.cache_dir, 394 | revision=model_args.model_revision, 395 | token=model_args.token, 396 | trust_remote_code=model_args.trust_remote_code, 397 | torch_dtype=torch_dtype, 398 | low_cpu_mem_usage=model_args.low_cpu_mem_usage, 399 | ) 400 | else: 401 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code) 402 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 403 | logger.info(f"Training new model from scratch - Total size={n_params / 2 ** 20:.2f}M params") 404 | 405 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 406 | # on a small vocab and want a smaller embedding size, remove this test. 407 | embedding_size = model.get_input_embeddings().weight.shape[0] 408 | if len(tokenizer) > embedding_size: 409 | model.resize_token_embeddings(len(tokenizer)) 410 | 411 | # Preprocessing the datasets. 412 | # First we tokenize all the texts. 
413 | if training_args.do_train: 414 | column_names = list(raw_datasets["train"].features) 415 | else: 416 | column_names = list(raw_datasets["validation"].features) 417 | text_column_name = "text" if "text" in column_names else column_names[0] 418 | 419 | # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function 420 | tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") 421 | 422 | def add_text_column(example): 423 | example[text_column_name] = example['inputs'] + example['targets'] 424 | return example 425 | 426 | def tokenize(prompt, add_eos_token=True): 427 | # there's probably a way to do this with the tokenizer settings 428 | # but again, gotta move fast 429 | result = tokenizer( 430 | prompt, 431 | truncation=True, 432 | max_length=1024, 433 | padding=False, 434 | return_tensors=None, 435 | ) 436 | if ( 437 | result["input_ids"][-1] != tokenizer.eos_token_id 438 | and len(result["input_ids"]) < 1024 439 | and add_eos_token 440 | ): 441 | result["input_ids"].append(tokenizer.eos_token_id) 442 | result["attention_mask"].append(1) 443 | 444 | result["labels"] = result["input_ids"].copy() 445 | 446 | return result 447 | 448 | def generate_and_tokenize_prompt(example): 449 | full_prompt = example['inputs'] + " " + example['targets'] 450 | tokenized_full_prompt = tokenize(full_prompt) 451 | if not model_args.train_on_inputs: 452 | user_prompt = example['inputs'] 453 | tokenized_user_prompt = tokenize( 454 | user_prompt, add_eos_token=model_args.add_eos_token 455 | ) 456 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 457 | 458 | if model_args.add_eos_token: 459 | user_prompt_len -= 1 460 | 461 | tokenized_full_prompt["labels"] = [ 462 | -100 463 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 464 | user_prompt_len: 465 | ] # could be sped up, probably 466 | return tokenized_full_prompt 467 | 468 | def tokenize_function(examples): 469 | if text_column_name not in examples: 470 | examples = examples.map(add_text_column) 471 | with CaptureLogger(tok_logger) as cl: 472 | output = tokenizer(examples[text_column_name], padding="max_length", truncation=True, max_length=1024) 473 | # clm input could be much much longer than block_size 474 | if "Token indices sequence length is longer than the" in cl.out: 475 | tok_logger.warning( 476 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" 477 | " before being passed to the model." 
478 | ) 479 | output['labels'] = output['input_ids'].copy() 480 | return output 481 | 482 | with training_args.main_process_first(desc="dataset map tokenization"): 483 | if not data_args.streaming: 484 | tokenized_datasets = raw_datasets.map( 485 | generate_and_tokenize_prompt, 486 | # batched=True, 487 | num_proc=data_args.preprocessing_num_workers, 488 | remove_columns=column_names, 489 | load_from_cache_file=not data_args.overwrite_cache, 490 | desc="Running tokenizer on dataset", 491 | ) 492 | else: 493 | tokenized_datasets = raw_datasets.map( 494 | tokenize_function, 495 | batched=True, 496 | remove_columns=column_names, 497 | ) 498 | 499 | lm_datasets = tokenized_datasets 500 | if training_args.do_train: 501 | if "train" not in tokenized_datasets: 502 | raise ValueError("--do_train requires a train dataset") 503 | train_dataset = lm_datasets["train"] 504 | if data_args.max_train_samples is not None: 505 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 506 | train_dataset = train_dataset.select(range(max_train_samples)) 507 | 508 | if training_args.do_eval: 509 | if "validation" not in tokenized_datasets: 510 | raise ValueError("--do_eval requires a validation dataset") 511 | eval_dataset = lm_datasets["validation"] 512 | if data_args.max_eval_samples is not None: 513 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 514 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 515 | 516 | def preprocess_logits_for_metrics(logits, labels): 517 | if isinstance(logits, tuple): 518 | # Depending on the model and config, logits may contain extra tensors, 519 | # like past_key_values, but logits always come first 520 | logits = logits[0] 521 | return logits.argmax(dim=-1) 522 | 523 | metric = evaluate.load("accuracy") 524 | 525 | def compute_metrics(eval_preds): 526 | preds, labels = eval_preds 527 | labels = labels[:, 1:].reshape(-1) 528 | preds = preds[:, :-1].reshape(-1) 529 | return metric.compute(predictions=preds, references=labels) 530 | 531 | data_collator = transformers.DataCollatorForSeq2Seq( 532 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 533 | ) 534 | wandb.init() 535 | # define metrics in wandb 536 | for metric_name in ["bleu", "qa_bleu", "recite_bleu", "qa_exact_match", "qa_f1", 'exact_match', 'f1']: 537 | wandb.define_metric(f"eval/{metric_name}", summary="max") 538 | 539 | is_qa = False 540 | if 'squad' in data_args.dataset_name: 541 | is_qa = True 542 | elif 'nq' in data_args.dataset_name or 'hotpot' in data_args.dataset_name: 543 | if "doc" in data_args.dataset_name and "qa" not in data_args.dataset_name: 544 | is_qa = False 545 | else: 546 | is_qa = True 547 | 548 | # Initialize our Trainer 549 | trainer = QuestionAnsweringLMTrainer( 550 | model=model, 551 | args=training_args, 552 | train_dataset=train_dataset if training_args.do_train else None, 553 | eval_dataset=eval_dataset if training_args.do_eval else None, 554 | tokenizer=tokenizer, 555 | # Data collator will default to DataCollatorWithPadding, so we change it. 
556 | data_collator=data_collator, 557 | compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, 558 | preprocess_logits_for_metrics=preprocess_logits_for_metrics 559 | if training_args.do_eval and not is_torch_tpu_available() 560 | else None, 561 | eval_examples=raw_datasets['validation'], 562 | is_squad=is_qa, 563 | do_recite=True if "recite" in data_args.dataset_name else False 564 | ) 565 | 566 | # Training 567 | if training_args.do_train: 568 | checkpoint = None 569 | if type(training_args.resume_from_checkpoint)==str and os.path.exists(training_args.resume_from_checkpoint) : 570 | checkpoint = training_args.resume_from_checkpoint 571 | elif last_checkpoint is not None: 572 | checkpoint = last_checkpoint 573 | print("Last checkpoint", checkpoint) 574 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 575 | trainer.save_model() # Saves the tokenizer too for easy upload 576 | 577 | metrics = train_result.metrics 578 | 579 | max_train_samples = ( 580 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 581 | ) 582 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 583 | 584 | trainer.log_metrics("train", metrics) 585 | trainer.save_metrics("train", metrics) 586 | trainer.save_state() 587 | 588 | # Evaluation 589 | if training_args.do_eval: 590 | logger.info("*** Evaluate ***") 591 | 592 | metrics = trainer.evaluate() 593 | 594 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 595 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 596 | try: 597 | perplexity = math.exp(metrics["eval_loss"]) 598 | except OverflowError: 599 | perplexity = float("inf") 600 | metrics["perplexity"] = perplexity 601 | 602 | for example in tqdm.tqdm(raw_datasets['validation']): 603 | prompt = example['inputs'] 604 | prompt = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to("cuda") 605 | outputs = model.generate(inputs=prompt, max_new_tokens=40) 606 | print(outputs) 607 | output_string = tokenizer.decode(outputs[0], skip_special_tokens=True) 608 | print("output_string", output_string) 609 | 610 | trainer.log_metrics("eval", metrics) 611 | trainer.save_metrics("eval", metrics) 612 | 613 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} 614 | if data_args.dataset_name is not None: 615 | kwargs["dataset_tags"] = data_args.dataset_name 616 | if data_args.dataset_config_name is not None: 617 | kwargs["dataset_args"] = data_args.dataset_config_name 618 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 619 | else: 620 | kwargs["dataset"] = data_args.dataset_name 621 | 622 | if training_args.push_to_hub: 623 | trainer.push_to_hub(**kwargs) 624 | else: 625 | trainer.create_model_card(**kwargs) 626 | 627 | 628 | def _mp_fn(index): 629 | # For xla_spawn (TPUs) 630 | main() 631 | 632 | 633 | if __name__ == "__main__": 634 | main() 635 | --------------------------------------------------------------------------------
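The following is a minimal sketch (not part of the repository itself) of how the hosted datasets and the saved prediction files can be inspected. It assumes that `tyzhu/squad_qa_baseline_v5_full` (the `baseline` variant built by `grounded_qa/run_no_writing.sh`) exposes `train`/`validation` splits with `inputs` and `targets` columns, as `run_clm.py` expects, and that `SAVE_PRED_DIR` points to the prediction directory created by `run_clm.sh` or `odqa/run_clm_odqa.sh`.

```python
import json
import os

from datasets import load_dataset

# The grounded-QA scripts build dataset names as tyzhu/squad_qa_<version>_v5_full;
# "baseline" is one of the versions looped over in grounded_qa/run_no_writing.sh.
ds = load_dataset("tyzhu/squad_qa_baseline_v5_full")
print(ds)  # run_clm.py requires "train" and "validation" splits
print(ds["validation"][0]["inputs"])   # prompt fed to the model
print(ds["validation"][0]["targets"])  # expected completion

# trainer_gpt_qa.py writes predictions_{prefix}_{k}.json into $SAVE_PRED_DIR during
# evaluation; the "raw" field holds (reference, prediction) pairs.
pred_path = os.path.join(os.environ.get("SAVE_PRED_DIR", "."), "predictions_eval_0.json")
if os.path.exists(pred_path):
    with open(pred_path) as f:
        pred_result = json.load(f)
    for ref, pred in pred_result["raw"][:5]:
        print("reference:", ref, "| prediction:", pred)
```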